{
"cells": [
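{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2.4 Bayesian Analysis\n",
"Bayesian analysis is a method for analyzing data with the tools of probability and statistics, and belongs to the field of statistics."
]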
},
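{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.4.1 Bayes' Formula\n",
"\n",
"- **Frequentist school**: computes the probability of an event from historical data, holding that with enough samples the observed frequency of the event approaches its true probability arbitrarily closely.\n",
"- **Bayesian school**: holds that the probability of an event depends not only on how probable the event was believed to be beforehand (the **prior probability**) but also on the new information observed when the probability is re-evaluated (the **likelihood**).\n",
"\n",
"In words, Bayes' rule states:\n",
"posterior probability ∝ prior probability × likelihood, where the product is divided by a normalizing constant so that the posterior probabilities sum to 1.\n",
"\n",
"**Conditional probability**:\n",
"\n",
"$P(A|B)$: the probability that event $A$ occurs given that event $B$ has occurred\n",
"\n",
"$$P(A|B)=\\frac{P(A ∩ B)}{P(B)}$$\n",
"\n",
"$P(B|A)$: the probability that event $B$ occurs given that event $A$ has occurred\n",
"\n",
"$$P(B|A)=\\frac{P(A ∩ B)}{P(A)}$$\n",
"\n",
"Multiplying each definition by its denominator gives\n",
"\n",
"$$P(B|A){P(A)}=P(A|B){P(B)}={P(A ∩ B)}$$\n",
"\n",
"from which we obtain **Bayes' formula**:\n",
"\n",
"$$P(A|B) = \\frac{P(B|A)P(A)}{P(B)}$$\n",
"\n",
"where\n",
"- $P(A)$ is the prior probability of event $A$, which does not take into account whether event $B$ occurs;\n",
"- $P(B|A)$ is the probability that event $B$ occurs given that event $A$ has occurred, also called the **likelihood**;\n",
"- $P(B)$ is the prior (marginal) probability of event $B$, also called the **normalizing constant**;\n",
"- $P(A|B)$ is the probability that event $A$ occurs given that event $B$ has occurred, i.e. the posterior probability of $A$."
]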
},
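{
"cell_type": "markdown",
"metadata": {},
"source": [
"The derivation above can be checked numerically. The sketch below uses a small table of **made-up counts** for two yes/no events $A$ and $B$ (hypothetical numbers chosen only for illustration), computes $P(A|B)$ directly from the definition of conditional probability, and confirms that Bayes' formula gives the same value. `fractions.Fraction` keeps the arithmetic exact so the two results can be compared with `==`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fractions import Fraction\n",
"\n",
"# Hypothetical counts for two yes/no events: (A occurred, B occurred) -> count\n",
"counts = {\n",
"    (True, True): 30,    # A and B\n",
"    (True, False): 10,   # A but not B\n",
"    (False, True): 20,   # B but not A\n",
"    (False, False): 40,  # neither\n",
"}\n",
"total = sum(counts.values())\n",
"\n",
"# Marginal and joint probabilities estimated from the counts\n",
"P_A = Fraction(counts[(True, True)] + counts[(True, False)], total)\n",
"P_B = Fraction(counts[(True, True)] + counts[(False, True)], total)\n",
"P_A_and_B = Fraction(counts[(True, True)], total)\n",
"\n",
"# Conditional probabilities from the definition P(X|Y) = P(X ∩ Y) / P(Y)\n",
"P_A_given_B = P_A_and_B / P_B\n",
"P_B_given_A = P_A_and_B / P_A\n",
"\n",
"# Bayes' formula: P(A|B) = P(B|A) P(A) / P(B)\n",
"P_A_given_B_bayes = P_B_given_A * P_A / P_B\n",
"\n",
"print('P(A|B) from the definition:', P_A_given_B)\n",
"print(\"P(A|B) from Bayes' formula:\", P_A_given_B_bayes)"
]
},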
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.4.2 Bayesian Inference\n",
"Bayesian inference is a statistical method of analysis based on Bayes' formula."
]
},
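{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small example: decide whether an email is spam (an advertising email) based on whether it contains the keyword “红包” (*hongbao*, “red envelope”)."
]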
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of advertising (spam) emails\n",
"ad_number = 4000\n",
"# Number of normal emails\n",
"normal_number = 6000\n",
"\n",
"# Number of advertising emails that contain the keyword “hongbao”\n",
"ad_hongbao_number = 1000\n",
"# Number of normal emails that contain the keyword “hongbao”\n",
"normal_hongbao_number = 6\n",
"\n",
"# Prior probability that a user receives an advertising email\n",
"P_ad = ad_number / (ad_number + normal_number)\n",
"print(\"Prior probability of an advertising email: \" + str(P_ad))\n",
"\n",
"# Prior probability that a user receives a normal email\n",
"P_normal = normal_number / (ad_number + normal_number)\n",
"print(\"Prior probability of a normal email: \" + str(P_normal))\n",
"\n",
"# Probability that an email contains “hongbao”, estimated from all emails\n",
"P_hongbao = (normal_hongbao_number + ad_hongbao_number) / (\n",
"    ad_number + normal_number)\n",
"print(\"Prior probability that an email contains “hongbao”: \" + str(P_hongbao))\n",
"\n",
"# Conditional probability of “hongbao” appearing in an advertising email\n",
"P_hongbao_ad = ad_hongbao_number / ad_number\n",
"print(\"P(“hongbao” | advertising email): \" + str(P_hongbao_ad))\n",
"\n",
"# Conditional probability of “hongbao” appearing in a normal email\n",
"P_hongbao_normal = normal_hongbao_number / normal_number\n",
"print(\"P(“hongbao” | normal email): \" + str(P_hongbao_normal))\n",
"\n",
"# By Bayes' theorem:\n",
"# posterior probability that an email containing “hongbao” is an advertising email\n",
"P_ad_hongbao = P_ad * P_hongbao_ad / P_hongbao\n",
"print(\"Posterior probability that an email containing “hongbao” is an advertising email: \" + str(P_ad_hongbao))\n",
"\n",
"# Posterior probability that an email containing “hongbao” is a normal email\n",
"P_normal_hongbao = P_normal * P_hongbao_normal / P_hongbao\n",
"print(\"Posterior probability that an email containing “hongbao” is a normal email: \" + str(P_normal_hongbao))"
]
},
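{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the cell above, `P_hongbao` is estimated directly as the share of all emails that contain “hongbao”. The same value also follows from the law of total probability, $P(\\text{hongbao})=P(\\text{hongbao}|\\text{ad})P(\\text{ad})+P(\\text{hongbao}|\\text{normal})P(\\text{normal})$, and because this is exactly the normalizing constant, the two posteriors must sum to 1. The minimal sketch below re-declares the counts so it runs on its own; the helper name `posterior` is hypothetical and not part of any library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Same counts as in the spam example above\n",
"ad_number, normal_number = 4000, 6000\n",
"ad_hongbao_number, normal_hongbao_number = 1000, 6\n",
"total = ad_number + normal_number\n",
"\n",
"# Priors and likelihoods\n",
"P_ad = ad_number / total\n",
"P_normal = normal_number / total\n",
"P_hongbao_ad = ad_hongbao_number / ad_number\n",
"P_hongbao_normal = normal_hongbao_number / normal_number\n",
"\n",
"# Law of total probability: sum of likelihood * prior over both classes\n",
"P_hongbao_total = P_hongbao_ad * P_ad + P_hongbao_normal * P_normal\n",
"\n",
"\n",
"def posterior(prior, likelihood, evidence):\n",
"    \"\"\"Hypothetical helper implementing Bayes' formula: likelihood * prior / evidence.\"\"\"\n",
"    return likelihood * prior / evidence\n",
"\n",
"\n",
"P_ad_hongbao = posterior(P_ad, P_hongbao_ad, P_hongbao_total)\n",
"P_normal_hongbao = posterior(P_normal, P_hongbao_normal, P_hongbao_total)\n",
"\n",
"print('P(hongbao) via total probability:', P_hongbao_total)\n",
"print('P(ad | hongbao):', round(P_ad_hongbao, 4))\n",
"print('P(normal | hongbao):', round(P_normal_hongbao, 4))\n",
"print('Sum of the two posteriors:', P_ad_hongbao + P_normal_hongbao)"
]
},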
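{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2.4.3 Naive Bayes Classifier\n",
"A commonly used classification algorithm that assumes **the features of a sample are mutually independent given the class, i.e. they do not influence one another**."
]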
},
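{
"cell_type": "markdown",
"metadata": {},
"source": [
"Written as formulas (a standard statement of the naive Bayes model, added here for reference), for a sample with features $x_1, \\dots, x_n$ and class $y$ the independence assumption and the resulting decision rule are:\n",
"\n",
"$$P(x_1,\\dots,x_n|y)=\\prod_{i=1}^{n}P(x_i|y)$$\n",
"\n",
"$$\\hat{y}=\\arg\\max_{y}\\;P(y)\\prod_{i=1}^{n}P(x_i|y)$$"
]
},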
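{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small example: predict whether a student will order food from a given restaurant.\n",
"\n",
"**Goal**: given the student's order history, if we recommend a restaurant that is “low-priced, sweet, and far away”, will the student place an order?\n",
"\n",
"**Data**: the student's order history is as follows\n",
"\n",
"|Price|Taste|Distance|Ordered?|\n",
"|:--:|:--:|:--:|:--:|\n",
"|High|Sweet|Near|Yes|\n",
"|High|Mild|Near|No|\n",
"|High|Spicy|Far|No|\n",
"|High|Sweet|Far|No|\n",
"|Low|Sweet|Near|Yes|\n",
"|Low|Sweet|Near|Yes|\n",
"|Low|Mild|Far|No|\n",
"|Low|Spicy|Far|Yes|\n"
]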
},
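{
"cell_type": "markdown",
"metadata": {},
"source": [
"The student received 8 recommendations, ordered 4 times and did not order 4 times, so the probabilities of “order” and “no order” are:  \n",
"$$P(\\text{order}) = \\frac{4}{8}=0.5$$  \n",
"$$P(\\text{no order}) = \\frac{4}{8}=0.5$$"
]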
},
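{
"cell_type": "markdown",
"metadata": {},
"source": [
"Under the basic assumption that price, taste and distance are independent of one another (given whether the student orders), the likelihoods of “order” and “no order” for the recommendation “low-priced, sweet, far away” are:\n",
"\n",
"$$\n",
"\\begin{align}\n",
"&P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far}|\\text{order})\\\\\n",
"=&P(\\text{price}=\\text{low}|\\text{order})×P(\\text{taste}=\\text{sweet}|\\text{order})×P(\\text{distance}=\\text{far}|\\text{order})\\\\\n",
"=&\\frac{3}{4}×\\frac{3}{4}×\\frac{1}{4}=\\frac{9}{64}\\\\\n",
"≈&0.141\n",
"\\end{align}\n",
"$$\n",
"\n",
"$$\n",
"\\begin{align}\n",
"&P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far}|\\text{no order})\\\\\n",
"=&P(\\text{price}=\\text{low}|\\text{no order})×P(\\text{taste}=\\text{sweet}|\\text{no order})×P(\\text{distance}=\\text{far}|\\text{no order})\\\\\n",
"=&\\frac{1}{4}×\\frac{1}{4}×\\frac{3}{4}=\\frac{3}{64}\\\\\n",
"≈&0.047\n",
"\\end{align}\n",
"$$\n"
]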
},
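{
"cell_type": "markdown",
"metadata": {},
"source": [
"These fractions are relative frequencies read off the order table. The sketch below re-enters the eight records (with English labels translated from the table) and recomputes both likelihoods as products of per-feature relative frequencies, i.e. under the naive independence assumption. `records` and the helper `likelihood` are hypothetical names chosen for this illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fractions import Fraction\n",
"\n",
"# The student's order history: (price, taste, distance, ordered)\n",
"records = [\n",
"    ('high', 'sweet', 'near', 'yes'),\n",
"    ('high', 'mild', 'near', 'no'),\n",
"    ('high', 'spicy', 'far', 'no'),\n",
"    ('high', 'sweet', 'far', 'no'),\n",
"    ('low', 'sweet', 'near', 'yes'),\n",
"    ('low', 'sweet', 'near', 'yes'),\n",
"    ('low', 'mild', 'far', 'no'),\n",
"    ('low', 'spicy', 'far', 'yes'),\n",
"]\n",
"\n",
"\n",
"def likelihood(sample, label):\n",
"    \"\"\"Hypothetical helper: product over features of P(feature_i = value_i | label),\n",
"    each estimated as a relative frequency (naive independence assumption).\"\"\"\n",
"    rows = [r for r in records if r[3] == label]\n",
"    result = Fraction(1)\n",
"    for i, value in enumerate(sample):\n",
"        matches = sum(1 for r in rows if r[i] == value)\n",
"        result *= Fraction(matches, len(rows))\n",
"    return result\n",
"\n",
"\n",
"sample = ('low', 'sweet', 'far')\n",
"L_yes = likelihood(sample, 'yes')\n",
"L_no = likelihood(sample, 'no')\n",
"print('P(sample | order)    =', L_yes, '≈', round(float(L_yes), 3))\n",
"print('P(sample | no order) =', L_no, '≈', round(float(L_no), 3))"
]
},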
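{
"cell_type": "markdown",
"metadata": {},
"source": [
"By Bayes' formula, for a restaurant that is “low-priced, sweet, far away”, the posterior probability that the student orders is proportional to:\n",
"\n",
"$$\n",
"\\begin{align}\n",
"&P(\\text{order}|\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far})\\\\\n",
"\\propto&P(\\text{order})×P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far}|\\text{order})\\\\\n",
"=&0.5×0.141\\\\\n",
"=&0.0705\n",
"\\end{align}\n",
"$$\n",
"\n",
"and the posterior probability that the student does not order is proportional to:\n",
"$$\n",
"\\begin{align}\n",
"&P(\\text{no order}|\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far})\\\\\n",
"\\propto&P(\\text{no order})×P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far}|\\text{no order})\\\\\n",
"=&0.5×0.047\\\\\n",
"=&0.0235\n",
"\\end{align}\n",
"$$\n",
"\n",
"Since $0.0705 > 0.0235$, the student is more likely to order than not to order this time.\n",
"\n",
"The calculation above is simplified. Strictly, Bayes' formula gives the following two expressions:\n",
"\n",
"$$\n",
"\\begin{align}\n",
"&P(\\text{order}|\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far})\\\\\n",
"=&\\frac{P(\\text{order})×P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far}|\\text{order})}{P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far})}\n",
"\\end{align}\n",
"$$\n",
"\n",
"$$\n",
"\\begin{align}\n",
"&P(\\text{no order}|\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far})\\\\\n",
"=&\\frac{P(\\text{no order})×P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far}|\\text{no order})}{P(\\text{price}=\\text{low},\\text{taste}=\\text{sweet},\\text{distance}=\\text{far})}\n",
"\\end{align}\n",
"$$\n",
"\n",
"Both expressions have the same denominator, so it does not affect which posterior is larger; it was therefore omitted from the calculation."
]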
},
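{
"cell_type": "markdown",
"metadata": {},
"source": [
"Restoring the omitted denominator turns the two scores into probabilities that sum to 1. Under the naive independence assumption, the denominator can be obtained with the law of total probability as the sum of the two numerators (the table contains no record that is exactly “low, sweet, far”, so it cannot be counted directly). A minimal sketch with exact fractions, starting from $\\frac{3}{4}×\\frac{3}{4}×\\frac{1}{4}=\\frac{9}{64}$ and $\\frac{1}{4}×\\frac{1}{4}×\\frac{3}{4}=\\frac{3}{64}$. The exact scores are $\\frac{9}{128}≈0.0703$ and $\\frac{3}{128}≈0.0234$; the values 0.0705 and 0.0235 in the text come from rounding the likelihoods first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from fractions import Fraction\n",
"\n",
"# Priors and likelihoods from the worked example above\n",
"P_order, P_no_order = Fraction(1, 2), Fraction(1, 2)\n",
"L_order, L_no_order = Fraction(9, 64), Fraction(3, 64)\n",
"\n",
"# Unnormalized posterior scores (the numerators of Bayes' formula)\n",
"score_order = P_order * L_order\n",
"score_no_order = P_no_order * L_no_order\n",
"\n",
"# Denominator via the law of total probability: sum over both classes\n",
"evidence = score_order + score_no_order\n",
"\n",
"post_order = score_order / evidence\n",
"post_no_order = score_no_order / evidence\n",
"print('scores:', float(score_order), float(score_no_order))\n",
"print('P(order | sample)    =', post_order)\n",
"print('P(no order | sample) =', post_no_order)"
]
},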
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Hands-on Practice"
]
},
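{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Using a naive Bayes classifier to recognize MNIST handwritten digits\n",
"\n",
"**MNIST** is a dataset of handwritten digits: it contains a wide variety of handwritten digit images together with their digit labels. Each image is **28×28** pixels, i.e. **784** pixels in total, so it can be represented as a **784**-dimensional vector, and each such vector has a corresponding label."
]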
},
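{
"cell_type": "markdown",
"metadata": {},
"source": [
"Flattening a 28×28 image into a 784-dimensional vector is a plain NumPy `reshape` (row by row), and reshaping the vector back to 28×28 recovers the image exactly, which is what the plotting code later relies on. A self-contained sketch on a few **synthetic** random images (not real MNIST data):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# A synthetic batch of 3 images of size 28x28 with 8-bit pixel values (not real MNIST data)\n",
"rng = np.random.RandomState(0)\n",
"images = rng.randint(0, 256, size=(3, 28, 28)).astype(np.uint8)\n",
"\n",
"# Flatten each image row by row into a 784-dimensional vector\n",
"vectors = images.reshape(images.shape[0], 784)\n",
"print(images.shape, '->', vectors.shape)\n",
"\n",
"# Reshaping a vector back to 28x28 recovers the original image exactly\n",
"restored = vectors[0].reshape(28, 28)\n",
"print('round trip exact:', np.array_equal(restored, images[0]))"
]
},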
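{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this experiment we use the **tensorflow** library to parse and load the raw dataset (the code below loads MNIST with `tensorflow.keras.datasets.mnist`), and the **sklearn** library for feature extraction and classification. For more details, see the [datasets section](https://www.tensorflow.org/datasets/) of **tensorflow** and the [naive Bayes section](https://scikit-learn.org/stable/modules/naive_bayes.html) of sklearn.\n",
" \n",
"1. Import the required libraries in **Python**."
]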
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"import numpy as np\n",
"from tensorflow.keras.datasets import mnist\n",
"from sklearn.naive_bayes import BernoulliNB\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2. Load the **MNIST** training and test sets."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Loading data ...\")\n",
"(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
"# Flatten each 28x28 image into a 784-dimensional vector\n",
"train_images = train_images.reshape(train_images.shape[0], 784)\n",
"test_images = test_images.reshape(test_images.shape[0], 784)\n",
"print('Done!')\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We use the function below to view a few of the images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_images(imgs):\n",
"    \"\"\"Plot up to 9 sample images in a 3x3 grid.\n",
"    :param imgs: array of flattened 784-dimensional images\n",
"    :return: None\n",
"    \"\"\"\n",
"    sample_num = min(9, len(imgs))\n",
"    plt.figure(figsize=(5, 5))\n",
"    for index in range(0, sample_num):\n",
"        ax = plt.subplot(3, 3, index + 1)\n",
"        # Restore the 28x28 shape before displaying\n",
"        ax.imshow(imgs[index].reshape(28, 28), cmap='gray')\n",
"        ax.grid(False)\n",
"    plt.margins(0, 0)\n",
"    plt.show()\n",
"\n",
"\n",
"plot_images(train_images)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3. Train a Bernoulli naive Bayes classifier (`BernoulliNB`) on the **MNIST** training set."
]
},
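{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Initializing and training the naive Bayes model...\")\n",
"# BernoulliNB treats each pixel as a binary feature; with the default binarize=0.0,\n",
"# every pixel value greater than 0 is mapped to 1 before fitting\n",
"classifier_BNB = BernoulliNB()\n",
"classifier_BNB.fit(train_images, train_labels)\n",
"print('Training complete!')"
]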
},
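{
"cell_type": "markdown",
"metadata": {},
"source": [
"`BernoulliNB` models every feature as binary (pixel on or off). Its `binarize` parameter defaults to `0.0`, so every pixel value greater than 0 is treated as 1 before fitting and predicting. The sketch below illustrates this on a tiny made-up dataset (not MNIST) with `sklearn.preprocessing.binarize`, which applies the same threshold rule: a model fitted on raw values with `binarize=0.0` makes the same predictions as one fitted on pre-binarized values with `binarize=None`. The stricter threshold of 127 is only a hypothetical alternative shown for comparison."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.naive_bayes import BernoulliNB\n",
"from sklearn.preprocessing import binarize\n",
"\n",
"# Tiny made-up \"images\" of 4 pixels each (values 0-255) in two classes; not MNIST data\n",
"X = np.array([[0, 10, 200, 255],\n",
"              [0, 30, 220, 240],\n",
"              [250, 200, 5, 0],\n",
"              [240, 180, 20, 0]])\n",
"y = np.array([0, 0, 1, 1])\n",
"\n",
"# Default BernoulliNB rule: every value greater than 0.0 becomes 1\n",
"print(binarize(X, threshold=0.0))\n",
"# A stricter threshold (127 is a hypothetical choice): only bright pixels become 1\n",
"print(binarize(X, threshold=127))\n",
"\n",
"# Fitting on raw pixels with binarize=0.0 ...\n",
"clf_raw = BernoulliNB(binarize=0.0).fit(X, y)\n",
"# ... matches fitting on pre-binarized pixels with binarization turned off\n",
"clf_pre = BernoulliNB(binarize=None).fit(binarize(X, threshold=0.0), y)\n",
"\n",
"X_test = np.array([[0, 0, 255, 255], [255, 90, 0, 0]])\n",
"pred_raw = clf_raw.predict(X_test)\n",
"pred_pre = clf_pre.predict(binarize(X_test, threshold=0.0))\n",
"print(pred_raw, pred_pre)"
]
},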
{
"cell_type": "markdown",
"metadata": {},
"source": [
"4. Use the trained classifier to recognize the images in the **MNIST** test set and obtain the predicted labels.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(\"Testing the trained naive Bayes model...\")\n",
"test_predict_BNB = classifier_BNB.predict(test_images)\n",
"print(\"Testing complete!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5. Compare the predicted labels of the test images with the true labels, then compute and print the classifier's accuracy."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fraction of test images whose predicted label equals the true label\n",
"accuracy = np.mean(test_predict_BNB == test_labels)\n",
"print('Accuracy of the naive Bayes model on the test set:', accuracy)"
]
},
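{
"cell_type": "markdown",
"metadata": {},
"source": [
"The accuracy in step 5 is simply the fraction of positions where the prediction equals the true label; `sklearn.metrics.accuracy_score` computes the same quantity. A quick check on made-up label arrays (not MNIST results):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import accuracy_score\n",
"\n",
"# Made-up true and predicted labels (not MNIST results)\n",
"y_true = np.array([0, 1, 2, 3, 4, 9, 9, 4])\n",
"y_pred = np.array([0, 1, 2, 5, 4, 4, 9, 9])\n",
"\n",
"# Fraction of positions where the prediction matches the true label\n",
"manual = np.mean(y_pred == y_true)\n",
"library = accuracy_score(y_true, y_pred)\n",
"print(manual, library)"
]
},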
{
"cell_type": "markdown",
"metadata": {},
"source": [
"6. Analyze the results: list the recognition accuracy for each digit **0-9** and compare the differences."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Number of test samples in each class, e.g. {0: 100} means there are 100 images of the digit 0\n",
"class_num = {}\n",
"# Row i: how many samples of true class i were predicted as each of the classes 0-9\n",
"predict_num = []\n",
"# Prediction accuracy of each class, in percent\n",
"class_accuracy = {}\n",
"\n",
"for i in range(10):\n",
"    # Indices of the samples whose true class is i\n",
"    class_is_i_index = np.where(test_labels == i)[0]\n",
"    # Number of samples of class i\n",
"    class_num[i] = len(class_is_i_index)\n",
"\n",
"    # Count how many samples of class i were predicted as each of the classes 0-9\n",
"    predict_num.append(\n",
"        [int(np.sum(test_predict_BNB[class_is_i_index] == e)) for e in range(10)])\n",
"\n",
"    # Accuracy of class i, as a percentage\n",
"    class_accuracy[i] = predict_num[i][i] / class_num[i] * 100\n",
"\n",
"    print(\"Digit %s: samples: %4s, correctly predicted: %4s, accuracy: %.1f%%\" % (\n",
"        i, class_num[i], predict_num[i][i], class_accuracy[i]))"
]
},
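{
"cell_type": "markdown",
"metadata": {},
"source": [
"The nested lists in `predict_num` form a confusion matrix: row $i$ is the true digit and column $e$ is the predicted digit. scikit-learn builds the same table with `sklearn.metrics.confusion_matrix`, and per-class accuracy is its diagonal divided by the row sums. A sketch on small made-up label arrays (3 classes, not MNIST results) showing that the counting loop above and the library call agree:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.metrics import confusion_matrix\n",
"\n",
"# Made-up true and predicted labels for 3 classes (not MNIST results)\n",
"y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])\n",
"y_pred = np.array([0, 1, 1, 1, 2, 2, 2, 0, 2])\n",
"\n",
"# Same counting as the per-class loop: entry [i][e] = number of true-i samples predicted as e\n",
"loop_counts = []\n",
"for i in range(3):\n",
"    idx = np.where(y_true == i)[0]\n",
"    loop_counts.append([int(np.sum(y_pred[idx] == e)) for e in range(3)])\n",
"\n",
"library_counts = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])\n",
"print(np.array(loop_counts))\n",
"print(library_counts)\n",
"\n",
"# Per-class accuracy: correct predictions (diagonal) divided by the class sizes (row sums)\n",
"print(np.diag(library_counts) / library_counts.sum(axis=1))"
]
},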
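{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"\n",
"sns.set(rc={'figure.figsize': (12, 8)})\n",
"# Row i of predict_num is the true digit i and column e is the predicted digit e,\n",
"# so seaborn puts true values on the y-axis and predicted values on the x-axis.\n",
"# vmax=150 caps the color scale so that off-diagonal errors stay visible;\n",
"# all counts above 150 share the darkest color.\n",
"ax = sns.heatmap(predict_num, cmap='YlGnBu', vmin=0, vmax=150)\n",
"ax.set_xlabel('Predicted label')\n",
"ax.set_ylabel('True label')\n",
"plt.show()"
]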
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"From the heatmap we can see that 3 is often misrecognized as 5 or 8, and that 4 and 9 are often confused with each other."
]
},
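{
"cell_type": "markdown",
"metadata": {},
"source": [
"The largest errors can also be read off the confusion matrix in code: set the diagonal (correct predictions) to zero and sort the remaining entries. The sketch below runs on a small made-up 3×3 matrix; passing `predict_num` from the cells above instead would list the most frequent (true, predicted) pairs for MNIST. The function name `top_confusions` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def top_confusions(matrix, k=3):\n",
"    \"\"\"Hypothetical helper: return the k largest off-diagonal entries of a\n",
"    confusion matrix as (true_label, predicted_label, count) tuples.\"\"\"\n",
"    m = np.array(matrix, dtype=int)  # copy, so the input is not modified\n",
"    np.fill_diagonal(m, 0)           # ignore correct predictions\n",
"    # Flat indices of the k largest entries, largest first\n",
"    flat_order = np.argsort(m, axis=None)[::-1][:k]\n",
"    rows, cols = np.unravel_index(flat_order, m.shape)\n",
"    return [(int(r), int(c), int(m[r, c])) for r, c in zip(rows, cols)]\n",
"\n",
"\n",
"# Made-up confusion matrix: rows are true labels, columns are predicted labels\n",
"example = [[50, 2, 8],\n",
"           [1, 45, 4],\n",
"           [12, 0, 38]]\n",
"print(top_confusions(example, k=2))"
]
},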
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's look at the misrecognized images whose true label is 9 but which were predicted as 4.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_imgs(images, true_labels, predict_labels, true_label,\n",
"             predict_label):\n",
"    \"\"\"\n",
"    Select images by their true label and predicted label.\n",
"    :param images: a set of images\n",
"    :param true_labels: the true label of each image\n",
"    :param predict_labels: the label predicted by the model for each image\n",
"    :param true_label: the true label of the images to select\n",
"    :param predict_label: the predicted label of the images to select\n",
"    :return: the images whose true label is true_label and whose predicted label is predict_label\n",
"    \"\"\"\n",
"    # Boolean mask: True where the true class is true_label and the prediction is predict_label\n",
"    mask = (true_labels == true_label) & (predict_labels == predict_label)\n",
"    return images[mask]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"imgs = get_imgs(test_images, test_labels, test_predict_BNB, 9, 4)\n",
"plot_images(imgs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}