{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "# 2.2 Decision Trees"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "A decision tree classifies samples with a **tree structure**: each internal node tests one attribute, each branch corresponds to one outcome of that test, and each leaf node represents a classification result."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## 2.2.1 The Concept of Decision Tree Classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "**Data**: the amusement-park operator provides daily **weather** (sunny, overcast, rainy), **temperature**, **humidity**, and **wind** conditions, together with whether visitors went to the park that day.\n",
     "\n",
     "**Goal**: predict whether visitors will come to the park.\n",
     "\n",
     "The table keeps the Chinese labels used in the code below: 晴 = sunny, 多云 = overcast, 雨 = rainy; 是 = yes, 否 = no.\n",
     "\n",
     "|No.|Weather|Temperature (℃)|Humidity|Windy|Go to the park: yes (1) / no (0)|\n",
     "|:--:|:--:|:--:|:--:|:--:|:--:|\n",
     "|1|晴|29|85|否|0|\n",
     "|2|晴|26|88|是|0|\n",
     "|3|多云|28|78|否|1|\n",
     "|4|雨|21|96|否|1|\n",
     "|5|雨|20|80|否|1|\n",
     "|6|雨|18|70|是|0|\n",
     "|7|多云|18|65|是|1|\n",
     "|8|晴|22|90|否|0|\n",
     "|9|晴|21|68|否|1|\n",
     "|10|雨|24|80|否|1|\n",
     "|11|晴|24|63|是|1|\n",
     "|12|多云|22|90|是|1|\n",
     "|13|多云|27|75|否|1|\n",
     "|14|雨|21|80|是|0|\n",
     "\n",
     "From this table we can draw the decision tree shown in the figure:\n",
     "\n",
     "<img src=\"http://imgbed.momodel.cn/决策树11.png\" width=500px>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "The first level splits on weather, which takes three values: rainy, overcast, and sunny.\n",
     "+ Overcast (多云): the sample subset is { 3, 7, 12, 13 }, which contains only the class \"go to the park\", so the answer is always yes.  \n",
     "  \n",
     "  \n",
     "+ Sunny (晴): the sample subset is { 1, 2, 8, 9, 11 }\n",
     "    + Humidity greater than 75: subset { 1, 2, 8 }, do not go to the park.\n",
     "    + Humidity at most 75: subset { 9, 11 }, go to the park.\n",
     "    \n",
     "    \n",
     "+ Rainy (雨): the sample subset is { 4, 5, 6, 10, 14 }\n",
     "    + Windy: subset { 6, 14 }, do not go to the park.\n",
     "    + No wind: subset { 4, 5, 10 }, go to the park."
   ]
  },
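  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tree above can be written directly as a nested if/else rule. The following is a minimal sketch (the function name predict_play is our own, not part of the original material):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_play(weather, humidity, windy):\n",
    "    \"\"\"Predict with the hand-built tree; weather is one of '晴'/'多云'/'雨', windy is '是'/'否'.\"\"\"\n",
    "    if weather == '多云':               # overcast: always go\n",
    "        return 1\n",
    "    if weather == '晴':                 # sunny: split on humidity\n",
    "        return 0 if humidity > 75 else 1\n",
    "    return 0 if windy == '是' else 1    # rainy: split on wind\n",
    "\n",
    "# Sample 1 (sunny, humidity 85, no wind) and sample 3 (overcast, humidity 78, no wind)\n",
    "print(predict_play('晴', 85, '否'), predict_play('多云', 78, '否'))"
   ]
  },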
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "As the example shows, building a decision tree boils down to:\n",
     "1. choosing an attribute;\n",
     "2. partitioning the sample set by that attribute;\n",
     "3. repeating steps 1 and 2 until every resulting subset contains samples of a single class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import math\n",
    "from math import log\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Raw data\n",
    "datasets = [\n",
    "    ['晴',29,85,'否','0'],\n",
    "    ['晴',26,88,'是','0'],\n",
    "    ['多云',28,78,'否','1'],\n",
    "    ['雨',21,96,'否','1'],\n",
    "    ['雨',20,80,'否','1'],\n",
    "    ['雨',18,70,'是','0'],\n",
    "    ['多云',18,65,'是','1'],\n",
    "    ['晴',22,90,'否','0'],\n",
    "    ['晴',21,68,'否','1'],\n",
    "    ['雨',24,80,'否','1'],\n",
    "    ['晴',24,63,'是','1'],\n",
    "    ['多云',22,90,'是','1'],\n",
    "    ['多云',27,75,'否','1'],\n",
    "    ['雨',21,80,'是','0']\n",
    "]\n",
    "\n",
     "# Column names\n",
    "labels = ['天气','温度','湿度','是否有风','是否前往游乐场']\n",
    "\n",
     "# Discretize humidity into the two values >75 and <=75,\n",
     "# and temperature into the two values >26 and <=26\n",
    "for i in range(len(datasets)):\n",
    "    if datasets[i][2] > 75:\n",
    "        datasets[i][2] = '>75'\n",
    "    else:\n",
    "        datasets[i][2] = '<=75'\n",
    "    if datasets[i][1] > 26:\n",
    "        datasets[i][1] = '>26'\n",
    "    else:\n",
    "        datasets[i][1] = '<=26'\n",
    "\n",
     "# Build a dataframe and inspect the data\n",
    "df = pd.DataFrame(datasets, columns=labels)\n",
    "df\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## 2.2.2 Building a Decision Tree  \n",
     "\n",
     "**Information gain** measures how much a split reduces the complexity (uncertainty) of a sample set.  \n",
     "\n",
     "**Information entropy** measures the amount of information. From an information-theoretic point of view, measuring information amounts to quantifying its uncertainty.  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Suppose a sample set $D$ is made up of $K$ kinds of outcomes, and the $k$-th outcome occurs with probability $p_k$ $(1 \\le k \\le K)$.  \n",
     "The information entropy of these $K$ outcomes is:  \n",
     "$$E(D)=-\\sum_{k=1}^{K}p_k \\log_{2} p_k$$\n",
     "\n",
     "Note that **the $p_k$ sum to 1**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "def calc_entropy(total_num, count_dict):\n",
     "    \"\"\"\n",
     "    Compute information entropy\n",
     "    :param total_num: total number of samples\n",
     "    :param count_dict: dict mapping each class to its sample count\n",
     "    :return: information entropy\n",
     "    \"\"\"\n",
     "    # Apply the entropy formula, skipping classes with zero count\n",
     "    ent = -sum([(p / total_num) * log(p / total_num, 2) for p in count_dict.values() if p != 0])\n",
     "    # Normalize -0.0 to 0 so it prints cleanly\n",
     "    if ent == 0:\n",
     "        ent = 0\n",
     "    # Return the entropy rounded to 4 decimal places\n",
     "    return round(ent, 4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Now let us build the tree using **entropy**. The 14 samples fall into two classes, \"visitors come to the park\" (9 samples) and \"visitors do not come\" (5 samples), so $K = 2$.\n",
     "\n",
     "Writing the probabilities of \"come\" and \"not come\" as $p_1$ and $p_2$, clearly $p_1=\\frac{9}{14}$ and $p_2=\\frac{5}{14}$, so the entropy of these 14 samples is:\n",
     "\n",
     "$$E(D)=-\\sum_{k=1}^{2}p_{k}\\log_{2}{p_k}=-(\\frac{9}{14}×\\log_{2}{\\frac{9}{14}}+\\frac{5}{14}×\\log_{2}{\\frac{5}{14}})=0.940$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We can filter rows of the dataframe by a condition as follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Example: select the rows where 是否前往游乐场 == 0\n",
     "df[df['是否前往游乐场']=='0']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Using this filtering, we can obtain the total number of samples and the dict of per-class counts needed for the entropy computation, and then compute the entropy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Total number of samples\n",
     "total_num = df.shape[0]\n",
     "\n",
     "# Dict mapping each class to its sample count\n",
     "count_dict = {'前往':df[df['是否前往游乐场']=='1'].shape[0], '不前往':df[df['是否前往游乐场']=='0'].shape[0]}\n",
     "\n",
     "# Compute the entropy\n",
     "entropy = calc_entropy(total_num, count_dict)\n",
     "entropy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "**Entropy of each value of the weather attribute**:  \n",
     "The three values of the weather attribute are $a_0=“晴”$ (sunny), $a_1=“多云”$ (overcast), and $a_2=“雨”$ (rainy).  \n",
     "Let $D_i$ denote the sample subset on the branch for value $a_i$, and $|D_i|$ the number of samples it contains."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "|Weather value $a_i$|“晴”|“多云”|“雨”|\n",
     "|:--:|:--:|:--:|:--:|\n",
     "|Samples $\\vert D_i\\vert$|5|4|5|\n",
     "|Positive / negative counts|(2+, 3−)|(4+, 0−)|(3+, 2−)|\n",
     "\n",
     "Entropy of each weather value:\n",
     "\n",
     "Sunny (晴): $E(D_0)=-(\\frac{2}{5}×\\log_{2}{\\frac{2}{5}}+\\frac{3}{5}×\\log_{2}{\\frac{3}{5}})=0.971$\n",
     "\n",
     "Overcast (多云): $E(D_1)=-(\\frac{4}{4}×\\log_{2}{\\frac{4}{4}})=0$\n",
     "\n",
     "Rainy (雨): $E(D_2)=-(\\frac{3}{5}×\\log_{2}{\\frac{3}{5}}+\\frac{2}{5}×\\log_{2}{\\frac{2}{5}})=0.971$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We can filter a dataframe on multiple conditions as shown below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Select the samples where the weather is sunny and the visitors went to the park\n",
    "df[(df['天气']=='晴') & (df['是否前往游乐场']=='1')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Number of days with 天气 == 晴 (sunny)\n",
     "total_num_sun = df[df['天气']=='晴'].shape[0]\n",
     "\n",
     "# On sunny days, the number of days visitors went / did not go to the park\n",
     "count_dict_sun = {'前往':df[(df['天气']=='晴') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "              '不前往':df[(df['天气']=='晴') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_sun)\n",
     "\n",
     "# Entropy for 天气 == 晴\n",
     "ent_sun = calc_entropy(total_num_sun, count_dict_sun)\n",
     "print('Entropy for 天气=晴 (sunny): %s' % ent_sun)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Number of days with 天气 == 多云 (overcast)\n",
     "total_num_cloud = df[df['天气']=='多云'].shape[0]\n",
     "\n",
     "# On overcast days, the number of days visitors went / did not go to the park\n",
     "count_dict_cloud = {'前往':df[(df['天气']=='多云') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "              '不前往':df[(df['天气']=='多云') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_cloud)\n",
     "\n",
     "# Entropy for 天气 == 多云\n",
     "ent_cloud = calc_entropy(total_num_cloud, count_dict_cloud)\n",
     "print('Entropy for 天气=多云 (overcast): %s' % ent_cloud)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Number of days with 天气 == 雨 (rainy)\n",
     "total_num_rain = df[df['天气']=='雨'].shape[0]\n",
     "\n",
     "# On rainy days, the number of days visitors went / did not go to the park\n",
     "count_dict_rain = {'前往':df[(df['天气']=='雨') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "              '不前往':df[(df['天气']=='雨') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_rain)\n",
     "\n",
     "# Entropy for 天气 == 雨\n",
     "ent_rain = calc_entropy(total_num_rain, count_dict_rain)\n",
     "print('Entropy for 天气=雨 (rainy): %s' % ent_rain)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Compute the information gain of the weather attribute:  \n",
     "$$Gain(D,A)=E(D)-\\sum_{i=1}^{n}\\frac{|D_i|}{|D|}E(D_i)$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Here $A$ is the weather attribute, so the information gain of weather is:\n",
     "$$Gain(D,天气)=0.940-(\\frac{5}{14}×0.971+\\frac{4}{14}×0+\\frac{5}{14}×0.971)=0.246$$\n",
     "\n",
     "The information gains of temperature, humidity, and wind can be computed in the same way.  \n",
     "Generally, the larger the information gain of a split, the greater the \"purity\" of the resulting partition, i.e. the more the uncertainty about the class is reduced."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Compute the information gain with the formula above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Information gain when splitting on weather\n",
    "gain = entropy - (total_num_sun/total_num*ent_sun + \n",
    "                  total_num_cloud/total_num*ent_cloud + \n",
    "                  total_num_rain/total_num*ent_rain)\n",
    "gain"
   ]
  },
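  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same computation can be wrapped in a generic helper that works for any attribute column. This is our own sketch (the name info_gain is not from the original text); it reuses calc_entropy and the df built above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def info_gain(df, feature, target='是否前往游乐场'):\n",
    "    \"\"\"Information gain of splitting df on `feature`, w.r.t. the `target` column.\"\"\"\n",
    "    total = df.shape[0]\n",
    "    # Entropy of the whole sample set\n",
    "    ent_total = calc_entropy(total, dict(df[target].value_counts()))\n",
    "    # Weighted entropy of the subsets, one subset per attribute value\n",
    "    cond = sum(sub.shape[0] / total * calc_entropy(sub.shape[0], dict(sub[target].value_counts()))\n",
    "               for _, sub in df.groupby(feature))\n",
    "    return round(ent_total - cond, 4)\n",
    "\n",
    "# Gains of all four attributes; the attribute with the largest gain is chosen as the split\n",
    "for col in ['天气', '温度', '湿度', '是否有风']:\n",
    "    print(col, info_gain(df, col))"
   ]
  },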
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "### Exercises"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "1. Build a split on each of weather, temperature, humidity, and wind in turn, and inspect the resulting information gains.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Entropy for 温度 > 26\n",
     "total_num_temp_high = df[df['温度']=='>26'].shape[0]\n",
     "count_dict_temp_high = {'前往':df[(df['温度']=='>26') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "                   '不前往':df[(df['温度']=='>26') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_temp_high)\n",
     "ent_temp_high = calc_entropy(total_num_temp_high, count_dict_temp_high)\n",
     "print('Entropy for 温度 >26: %s' % ent_temp_high)\n",
     "\n",
     "# Entropy for 温度 <= 26\n",
     "total_num_temp_low = df[df['温度']=='<=26'].shape[0]\n",
     "count_dict_temp_low = {'前往':df[(df['温度']=='<=26') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "                     '不前往':df[(df['温度']=='<=26') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_temp_low)\n",
     "ent_temp_low = calc_entropy(total_num_temp_low, count_dict_temp_low)\n",
     "print('Entropy for 温度 <=26: %s' % ent_temp_low)\n",
     "\n",
     "# Information gain when splitting on temperature\n",
     "gain = entropy - (total_num_temp_high/total_num*ent_temp_high + \n",
     "                  total_num_temp_low/total_num*ent_temp_low)\n",
     "print('Information gain of splitting on temperature: %s' % gain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Entropy for 湿度 > 75\n",
     "total_num_hum_high = df[df['湿度']=='>75'].shape[0]\n",
     "count_dict_hum_high = {'前往':df[(df['湿度']=='>75') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "                   '不前往':df[(df['湿度']=='>75') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_hum_high)\n",
     "ent_hum_high = calc_entropy(total_num_hum_high, count_dict_hum_high)\n",
     "print('Entropy for 湿度 >75: %s' % ent_hum_high)\n",
     "\n",
     "# Entropy for 湿度 <= 75\n",
     "total_num_hum_low = df[df['湿度']=='<=75'].shape[0]\n",
     "count_dict_hum_low = {'前往':df[(df['湿度']=='<=75') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "                     '不前往':df[(df['湿度']=='<=75') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_hum_low)\n",
     "ent_hum_low = calc_entropy(total_num_hum_low, count_dict_hum_low)\n",
     "print('Entropy for 湿度 <=75: %s' % ent_hum_low)\n",
     "\n",
     "# Information gain when splitting on humidity\n",
     "gain = entropy - (total_num_hum_high/total_num*ent_hum_high + \n",
     "                  total_num_hum_low/total_num*ent_hum_low)\n",
     "print('Information gain of splitting on humidity: %s' % gain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Entropy for windy days (是否有风 == 是)\n",
     "total_num_wind = df[df['是否有风']=='是'].shape[0]\n",
     "count_dict_wind = {'前往':df[(df['是否有风']=='是') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "              '不前往':df[(df['是否有风']=='是') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_wind)\n",
     "ent_wind = calc_entropy(total_num_wind, count_dict_wind)\n",
     "print('Entropy for windy days: %s' % ent_wind)\n",
     "\n",
     "# Entropy for windless days (是否有风 == 否)\n",
     "total_num_nowind = df[df['是否有风']=='否'].shape[0]\n",
     "count_dict_nowind = {'前往':df[(df['是否有风']=='否') & (df['是否前往游乐场']=='1')].shape[0], \n",
     "              '不前往':df[(df['是否有风']=='否') & (df['是否前往游乐场']=='0')].shape[0]}\n",
     "print(count_dict_nowind)\n",
     "ent_nowind = calc_entropy(total_num_nowind, count_dict_nowind)\n",
     "print('Entropy for windless days: %s' % ent_nowind)\n",
     "\n",
     "# Information gain when splitting on wind\n",
     "gain = entropy - (total_num_wind/total_num*ent_wind + \n",
     "                  total_num_nowind/total_num*ent_nowind)\n",
     "print('Information gain of splitting on wind: %s' % gain)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "2. Each iris has four features: sepal length, sepal width, petal length, and petal width. Based on these four features, irises are to be classified into three species: 杂色鸢尾 (versicolor), 维吉尼亚鸢尾 (virginica), and 山鸢尾 (setosa). Try to build a decision tree for this classification.\n",
     "\n",
     "|No.|Sepal length|Sepal width|Petal length|Petal width|Species|\n",
     "|:--:|:--:|:--:|:--:|:--:|:--:|\n",
     "|1|5.0|2.0|3.5|1.0|杂色鸢尾|\n",
     "|2|6.0|2.2|5.0|1.5|维吉尼亚鸢尾|\n",
     "|3|6.0|2.2|4.0|1.0|杂色鸢尾|\n",
     "|4|6.2|2.2|4.5|1.5|杂色鸢尾|\n",
     "|5|4.5|2.3|1.3|0.3|山鸢尾|"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Looking at the five rows above, the petal width of 杂色鸢尾 and 维吉尼亚鸢尾 is clearly larger than that of 山鸢尾, so we can separate 山鸢尾 from the other two species by checking whether the petal width exceeds 0.7.\n",
     "\n",
     "Likewise, the petal length of 杂色鸢尾 and 维吉尼亚鸢尾 is clearly larger than that of 山鸢尾, so checking whether the petal length exceeds 2.4 would also separate 山鸢尾 from the other two.\n",
     "\n",
     "Finally, the petal length of 维吉尼亚鸢尾 is clearly larger than that of 杂色鸢尾, so we can separate these two species by checking whether the petal length exceeds 4.75."
   ]
  },
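  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "These hand-picked thresholds can be tried out as a tiny classifier on the five rows above. A minimal sketch (the function name classify_iris is our own, not part of the original material):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_iris(petal_length, petal_width):\n",
    "    \"\"\"Classify an iris with the hand-picked thresholds observed above.\"\"\"\n",
    "    if petal_width <= 0.7:      # 山鸢尾 (setosa) has much smaller petals\n",
    "        return '山鸢尾'\n",
    "    # 杂色鸢尾 (versicolor) vs. 维吉尼亚鸢尾 (virginica): split on petal length\n",
    "    return '维吉尼亚鸢尾' if petal_length > 4.75 else '杂色鸢尾'\n",
    "\n",
    "# Check the rule against the five rows of the table: (petal length, petal width, species)\n",
    "for pl, pw, label in [(3.5, 1.0, '杂色鸢尾'), (5.0, 1.5, '维吉尼亚鸢尾'),\n",
    "                      (4.0, 1.0, '杂色鸢尾'), (4.5, 1.5, '杂色鸢尾'), (1.3, 0.3, '山鸢尾')]:\n",
    "    print(classify_iris(pl, pw) == label)"
   ]
  },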
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Is this actually the case?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "The table above is only a small part of the Iris dataset. The full dataset contains 150 samples in 3 classes, 50 per class, each described by the same four attributes: sepal length, sepal width, petal length, and petal width.\n",
     "\n",
     "We will build a decision tree with the sklearn toolkit; first, load the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "from sklearn.datasets import load_iris\n",
     "# Load the dataset\n",
     "iris = load_iris()\n",
     "# Inspect the labels\n",
     "print(list(iris.target_names))\n",
     "# Inspect the features\n",
     "print(iris.feature_names)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "setosa is 山鸢尾, versicolor is 杂色鸢尾, and virginica is 维吉尼亚鸢尾.\n",
     "\n",
     "sepal length, sepal width, petal length, and petal width are the four features listed above."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Next, split the data into a training set and a test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "from sklearn.model_selection import train_test_split\n",
     "# Load the data\n",
     "X, y = load_iris(return_X_y=True)\n",
     "# Split into training and test sets\n",
     "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Next, train a decision tree model on the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "from sklearn import tree\n",
     "# Initialize the model; adjust max_depth to see how the model's behavior changes\n",
     "clf = tree.DecisionTreeClassifier(random_state=42, max_depth=2)\n",
     "# Train the model\n",
     "clf = clf.fit(X_train, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We can use the graphviz package to display the trained decision tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import graphviz\n",
    "feature_names = ['萼片长度','萼片宽度','花瓣长度','花瓣宽度']\n",
    "target_names = ['山鸢尾', '杂色鸢尾', '维吉尼亚鸢尾']\n",
     "# Visualize the trained decision tree\n",
    "dot_data = tree.export_graphviz(clf, out_file=None, \n",
    "                     feature_names=feature_names,  \n",
    "                     class_names=target_names,  \n",
    "                     filled=True, rounded=True,  \n",
    "                     special_characters=True)  \n",
    "graph = graphviz.Source(dot_data)  \n",
    "graph "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Let us look at the model's performance on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "y_test_predict = clf.predict(X_test)\n",
    "accuracy_score(y_test,y_test_predict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "### Hands-on Practice\n",
     "\n",
     "**Compute the information entropy of an essay**\n",
     "\n",
     "Collect a short essay in parallel Chinese and English versions. Estimate the probability of each Chinese word and each English word in the texts, compute the entropy of each version, and compare the entropy of the Chinese essay with that of the English essay."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "First, define a helper function that reads the contents of a file."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "def read_file(path):\n",
     "    \"\"\"\n",
     "    Read the contents of a file\n",
     "    :param path: path to the file\n",
     "    :return: contents of the file\n",
     "    \"\"\"\n",
     "    contents = \"\"\n",
     "    with open(path) as f:\n",
     "        # Accumulate the contents line by line\n",
     "        for line in f.readlines():\n",
     "            contents += line\n",
     "    return contents"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Use the helper defined above to read the English essay and its Chinese counterpart."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Read the English essay\n",
     "en_essay = read_file('essay3_en.txt')\n",
     "# Read the Chinese essay\n",
     "ch_essay = read_file('essay3_ch.txt')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Preprocess the text, estimate the word frequencies, and compute the entropy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "import re\n",
    "\n",
    "\n",
     "def cal_essay_entropy(essay, split_by=None):\n",
     "    \"\"\"\n",
     "    Compute the information entropy of an essay\n",
     "    :param essay: text of the essay\n",
     "    :param split_by: how to split the text; leave as None for Chinese\n",
     "                     (split into characters), pass a space character for English\n",
     "    :return: entropy of the essay\n",
     "    \"\"\"\n",
     "    # Lowercase the English text\n",
     "    essay = essay.lower()\n",
     "    # Strip punctuation (both ASCII and Chinese)\n",
     "    essay = re.sub(\n",
     "        \"[\\f+\\n+\\r+\\t+\\v+\\?\\.\\!\\/_,$%^*(+\\\"\\']+|[+——！，。？、~@#《》￥%……&*（）]\", \"\",\n",
     "        essay)\n",
     "    # print(essay)\n",
     "    # Split the text into words\n",
     "    if split_by:\n",
     "        word_list = essay.split(split_by)\n",
     "    else:\n",
     "        word_list = list(essay)\n",
     "    # Total number of words\n",
     "    word_number = len(word_list)\n",
     "    print('The essay contains %s words' % word_number)\n",
     "    # Occurrences of each word\n",
     "    word_counter = Counter(word_list)\n",
     "    # print('Occurrences of each word: %s' % word_counter)\n",
     "    # Apply the entropy formula\n",
     "    ent = -sum([(p / word_number) * log(p / word_number, 2) for p in\n",
     "                word_counter.values()])\n",
     "    print('Entropy: %.2f' % ent)\n",
     "    return ent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ent = cal_essay_entropy(ch_essay)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ent = cal_essay_entropy(en_essay, split_by = ' ')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
