cf9f12d
joyvan 6 years ago
12 changed file(s) with 2409 addition(s) and 747 deletion(s).
1010 "cell_type": "markdown",
1111 "metadata": {},
1212 "source": [
13 "<center><video src=\"http://files.momodel.cn/search_problem.mp4\" controls=\"controls\" width=800px></center>"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
1320 "现实世界中许多问题都可以通过搜索的方法来求解,例如设计最佳出行路线或是制订合理的课程表。当给定一个待求解问题后,搜索算法会按照事先设定的逻辑来自动寻找符合求解问题的答案,因此一般可将搜索算法称为问题求解智能体。"
1421 ]
1522 },
2532 "metadata": {},
2633 "source": [
2734 "## 2.1.1 搜索算法基本概念"
35 ]
36 },
37 {
38 "cell_type": "markdown",
39 "metadata": {},
40 "source": [
41 "<center><video src=\"http://files.momodel.cn/search_basic_concept.mp4\" controls=\"controls\" width=800px></center>"
2842 ]
2943 },
3044 {
8599 "cell_type": "markdown",
86100 "metadata": {},
87101 "source": [
88 "观察上图,可以看到从起点 A 到目标点 G 距离最短的路径是 A -> B -> D -> G,其距离是 38,我们可以设计一个计算机程序,按照既定的规则,从起点 A 出发,不断尝试从一个节点移动到下一个节点,直到抵达目标点 G。\n",
89 "\n",
90 "在详细描述搜索算法之前,先了解下面四个重要的概念。"
91 ]
92 },
93 {
94 "cell_type": "markdown",
95 "metadata": {},
96 "source": [
97 "+ **状态**。 在上面的例子中状态是什么?"
98 ]
99 },
100 {
101 "cell_type": "markdown",
102 "metadata": {},
103 "source": [
104 "+ **测试目标**。 在上面的例子中测试目标是什么?"
105 ]
106 },
107 {
108 "cell_type": "markdown",
109 "metadata": {},
110 "source": [
111 "+ **动作**。在上面的例子中动作是什么?"
112 ]
113 },
114 {
115 "cell_type": "markdown",
116 "metadata": {},
117 "source": [
118 "+ **路径**。在上面的例子中路径是什么?\n",
119 "\n"
120 ]
121 },
122 {
123 "cell_type": "markdown",
124 "metadata": {},
125 "source": [
126 "更改下面的代码,设置一条路径并查看。"
102 "观察上图,可以看到从起点 A 到目标点 G 距离最短的路径是 A -> B -> D -> G,其距离是 38,我们可以设计一个计算机程序,按照既定的规则,从起点 A 出发,不断尝试从一个节点移动到下一个节点,直到抵达目标点 G。"
103 ]
104 },
105 {
106 "cell_type": "markdown",
107 "metadata": {},
108 "source": [
109 "**想一想**"
110 ]
111 },
112 {
113 "cell_type": "markdown",
114 "metadata": {},
115 "source": [
116 "搜索算法四个重要的概念分别是什么?"
117 ]
118 },
119 {
120 "cell_type": "markdown",
121 "metadata": {},
122 "source": [
123 "**动手练**"
124 ]
125 },
126 {
127 "cell_type": "markdown",
128 "metadata": {},
129 "source": [
130 "更改下面的代码,设置一条路径并查看,并结合此图说明搜索算法的四个重要概念。"
127131 ]
128132 },
129133 {
146150 "cell_type": "markdown",
147151 "metadata": {},
148152 "source": [
149 "搜索算法就是不断从某一状态转移到下一状态,直至到达终止状态为止。\n",
150 "\n",
151 "**搜索树**是什么?\n",
152 "\n",
153 "如何构造搜索树 ?\n",
153 "<center><video src=\"http://files.momodel.cn/search_tree.mp4\" controls=\"controls\" width=800px></center>"
154 ]
155 },
156 {
157 "cell_type": "markdown",
158 "metadata": {},
159 "source": [
160 "搜索算法就是不断从某一状态转移到下一状态,直至到达终止状态为止。"
161 ]
162 },
163 {
164 "cell_type": "markdown",
165 "metadata": {},
166 "source": [
167 "**想一想**"
168 ]
169 },
170 {
171 "cell_type": "markdown",
172 "metadata": {},
173 "source": [
174 "搜索树是什么?\n",
175 "\n",
176 "如何构造搜索树?\n",
154177 "\n",
155178 "思考一下,路径搜索能出现回路吗?\n",
156179 "\n",
178201 "cell_type": "markdown",
179202 "metadata": {},
180203 "source": [
204 "<center><video src=\"http://files.momodel.cn/search_dfs_bfs.mp4\" controls=\"controls\" width=800px></center>"
205 ]
206 },
207 {
208 "cell_type": "markdown",
209 "metadata": {},
210 "source": [
181211 "<img src=\"http://imgbed.momodel.cn//20200110151426.png\" width=600>"
182212 ]
183213 },
217247 "cell_type": "markdown",
218248 "metadata": {},
219249 "source": [
250 "**想一想**"
251 ]
252 },
253 {
254 "cell_type": "markdown",
255 "metadata": {},
256 "source": [
220257 "对于一个搜索问题,只要存在答案(即从初始节点到终止节点存在满足条件的一条路径),那么深度优先搜索和广度优先搜索都能找到一个答案吗?找到的答案一定是路径最短的吗?"
221258 ]
222259 },
224261 "cell_type": "markdown",
225262 "metadata": {},
226263 "source": [
264 "### 扩展内容\n",
265 "**深度优先搜索 DFS** 基础代码解读\n",
266 "\n",
267 "```python\n",
268 "def iter_dfs(G, start, target):\n",
269 " '''\n",
270 " 深度优先搜索\n",
271 " :param G: 字典,存储每个点的相邻点\n",
272 " :param start: 初始点\n",
273 " :param target: 目标点\n",
274 " :return:\n",
275 " '''\n",
276 "\n",
277 " # 定义已访问的点的集合\n",
278 " S = set()\n",
279 " # 定义一个待访问点的列表\n",
280 " Q = []\n",
281 " # 把初始点放进列表中\n",
282 " Q.append(start)\n",
283 " while Q:\n",
284 " # 只要待访问的列表不为空,就从列表中取出最后一个元素,也就是一个点,记作 u\n",
285 " u = Q.pop()\n",
286 " # 如果当前点是目标点,则结束查找\n",
287 " if u == target:\n",
288 " break\n",
289 " # 如果该点已经被访问了,则跳过此点\n",
290 " if u in S:\n",
291 " continue\n",
292 " # 访问此点,将点加入已访问点的集合 S 中\n",
293 " S.add(u)\n",
294 " # 将点 u 相邻的点放入待访问的列表中\n",
295 " Q.extend(G[u])\n",
296 "```"
297 ]
298 },
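上面的 iter_dfs 把列表尾部当作栈使用;只要改成从队列头部取点,就得到广度优先搜索。下面是一个可运行的示意(其中的图 G 是为演示假设的数据,与正文插图无关):

```python
from collections import deque

def iter_bfs(G, start, target):
    """广度优先搜索:按离起点的步数由近到远逐层访问"""
    S = set()            # 已访问的点
    Q = deque([start])   # 待访问队列,先进先出
    while Q:
        u = Q.popleft()  # 与 DFS 的唯一区别:从队列头部取点
        if u == target:
            return True
        if u in S:
            continue
        S.add(u)
        Q.extend(G[u])
    return False

# 假设的演示图(与正文插图无关)
G = {'A': ['B', 'C'], 'B': ['A', 'D'], 'C': ['A', 'D'], 'D': ['B', 'C', 'G'], 'G': ['D']}
print(iter_bfs(G, 'A', 'G'))  # True
```

可以看到,DFS 与 BFS 的差别只在于待访问点的取用顺序:栈还是队列。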
299 {
300 "cell_type": "markdown",
301 "metadata": {},
302 "source": [
227303 "## 2.1.4 启发式搜索"
228304 ]
229305 },
231307 "cell_type": "markdown",
232308 "metadata": {},
233309 "source": [
234 "能否在搜索过程中利用问题的定义以外的**辅助信息**?"
310 "<center><video src=\"http://files.momodel.cn/search_greedy.mp4\" controls=\"controls\" width=800px></center>"
311 ]
312 },
313 {
314 "cell_type": "markdown",
315 "metadata": {},
316 "source": [
317 "**想一想**"
318 ]
319 },
320 {
321 "cell_type": "markdown",
322 "metadata": {},
323 "source": [
324 "能否在搜索过程中利用问题的定义以外的辅助信息?"
235325 ]
236326 },
237327 {
265355 "cell_type": "markdown",
266356 "metadata": {},
267357 "source": [
358 "**想一想**"
359 ]
360 },
361 {
362 "cell_type": "markdown",
363 "metadata": {},
364 "source": [
268365 "“贪婪”机制下找到的最佳路径是什么呢?它是最短路径吗?为什么会产生这样的搜索结果?\n"
269366 ]
270367 },
272369 "cell_type": "markdown",
273370 "metadata": {},
274371 "source": [
275 "如何克服贪婪算法的不足?\n",
276 " A\\*算法\n",
277 "\n",
278 " A\\*算法搜索过程:\n"
372 "另一种启发式搜索算法—— A* 算法克服了贪婪算法的不足。"
373 ]
374 },
375 {
376 "cell_type": "markdown",
377 "metadata": {},
378 "source": [
379 "<center><video src=\"http://files.momodel.cn/search_a_star.mp4\" controls=\"controls\" width=800px></center>"
380 ]
381 },
382 {
383 "cell_type": "markdown",
384 "metadata": {},
385 "source": [
386 "A\\*算法搜索过程:"
279387 ]
280388 },
281389 {
299407 "# 可以调整辅助信息的比重\n",
300408 "# 当只考虑额外信息时,即 origin_info_weight 设置为 0 的时候,A* 算法退化为贪婪算法。\n",
301409 "g.animation_search_tree('a_star',help_info_weight=1, origin_info_weight=0)"
410 ]
411 },
412 {
413 "cell_type": "markdown",
414 "metadata": {},
415 "source": [
416 "**想一想**"
302417 ]
303418 },
304419 {
1010 "cell_type": "markdown",
1111 "metadata": {},
1212 "source": [
13 "<center><video src=\"http://files.momodel.cn/search_problem.mp4\" controls=\"controls\" width=800px></center>"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "<br>"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
1327 "现实世界中许多问题都可以通过搜索的方法来求解,例如设计最佳出行路线或是制订合理的课程表。当给定一个待求解问题后,搜索算法会按照事先设定的逻辑来自动寻找符合求解问题的答案,因此一般可将搜索算法称为问题求解智能体。"
1428 ]
1529 },
2438 "cell_type": "markdown",
2539 "metadata": {},
2640 "source": [
41 "<br>"
42 ]
43 },
44 {
45 "cell_type": "markdown",
46 "metadata": {},
47 "source": [
2748 "## 2.1.1 搜索算法基本概念"
49 ]
50 },
51 {
52 "cell_type": "markdown",
53 "metadata": {},
54 "source": [
55 "<center><video src=\"http://files.momodel.cn/search_basic_concept.mp4\" controls=\"controls\" width=800px></center>"
56 ]
57 },
58 {
59 "cell_type": "markdown",
60 "metadata": {},
61 "source": [
62 "<br>"
2863 ]
2964 },
3065 {
85120 "cell_type": "markdown",
86121 "metadata": {},
87122 "source": [
88 "观察上图,可以看到从起点 A 到目标点 G 距离最短的路径是 A -> B -> D -> G,其距离是 38,我们可以设计一个计算机程序,按照既定的规则,从起点 A 出发,不断尝试从一个节点移动到下一个节点,直到抵达目标点 G。\n",
89 "\n",
123 "观察上图,可以看到从起点 A 到目标点 G 距离最短的路径是 A -> B -> D -> G,其距离是 38,我们可以设计一个计算机程序,按照既定的规则,从起点 A 出发,不断尝试从一个节点移动到下一个节点,直到抵达目标点 G。"
124 ]
125 },
126 {
127 "cell_type": "markdown",
128 "metadata": {},
129 "source": [
130 "<br>"
131 ]
132 },
133 {
134 "cell_type": "markdown",
135 "metadata": {},
136 "source": [
90137 "在详细描述搜索算法之前,先看看下面四个重要的概念。"
91138 ]
92139 },
109156 "metadata": {},
110157 "source": [
111158 "+ **动作**。动作指的是搜索算法从一个状态转变到另外一个状态所采取的行为。一般假设在每个状态下所能够采取的行为数量都是有限的。例如:在起点 A,只有 B 和 C 两个节点与之相连,所以只有转移到 B 或者转移到 C 这两种选择。一般情况从一个状态到另外一个状态的过程叫做**状态转移**。"
112 ]
113 },
114 {
115 "cell_type": "markdown",
116 "metadata": {},
117 "source": [
118 "下图中,我们在初始状态采取了转移到 B 这个动作。"
119 ]
120 },
121 {
122 "cell_type": "code",
123 "execution_count": null,
124 "metadata": {},
125 "outputs": [],
126 "source": [
127 "g.show_graph(this_path=\"AB\")"
128159 ]
129160 },
130161 {
136167 ]
137168 },
138169 {
170 "cell_type": "markdown",
171 "metadata": {},
172 "source": [
173 "看下图,理解上面的四个概念。"
174 ]
175 },
176 {
139177 "cell_type": "code",
140178 "execution_count": null,
141179 "metadata": {},
148186 "cell_type": "markdown",
149187 "metadata": {},
150188 "source": [
189 "<br>"
190 ]
191 },
192 {
193 "cell_type": "markdown",
194 "metadata": {},
195 "source": [
151196 "## 2.1.2 搜索算法"
152197 ]
153198 },
155200 "cell_type": "markdown",
156201 "metadata": {},
157202 "source": [
203 "<center><video src=\"http://files.momodel.cn/search_tree.mp4\" controls=\"controls\" width=800px></center>"
204 ]
205 },
206 {
207 "cell_type": "markdown",
208 "metadata": {},
209 "source": [
210 "<br>"
211 ]
212 },
213 {
214 "cell_type": "markdown",
215 "metadata": {},
216 "source": [
158217 "搜索算法就是不断从某一状态转移到下一状态,直至到达终止状态为止。\n",
159218 "\n",
160219 "\n",
168227 "outputs": [],
169228 "source": [
170229 "g.show_search_tree()"
230 ]
231 },
232 {
233 "cell_type": "markdown",
234 "metadata": {},
235 "source": [
236 "<br>"
171237 ]
172238 },
173239 {
194260 "cell_type": "markdown",
195261 "metadata": {},
196262 "source": [
263 "<br>"
264 ]
265 },
266 {
267 "cell_type": "markdown",
268 "metadata": {},
269 "source": [
197270 "## 2.1.3 深度优先搜索和广度优先搜索"
198271 ]
199272 },
201274 "cell_type": "markdown",
202275 "metadata": {},
203276 "source": [
277 "<center><video src=\"http://files.momodel.cn/search_dfs_bfs.mp4\" controls=\"controls\" width=800px></center>"
278 ]
279 },
280 {
281 "cell_type": "markdown",
282 "metadata": {},
283 "source": [
284 "<br>"
285 ]
286 },
287 {
288 "cell_type": "markdown",
289 "metadata": {},
290 "source": [
204291 "<img src=\"http://imgbed.momodel.cn//20200110151426.png\" width=600>"
205292 ]
206293 },
228315 ]
229316 },
230317 {
318 "cell_type": "markdown",
319 "metadata": {},
320 "source": [
321 "<br>"
322 ]
323 },
324 {
231325 "cell_type": "code",
232326 "execution_count": null,
233327 "metadata": {},
241335 "metadata": {},
242336 "source": [
243337 "需要强调的是,对于一个搜索问题,只要存在答案(即从初始节点到终止节点存在满足条件的一条路径),那么排除了回路的深度优先搜索和广度优先搜索均能找到一个答案,但是这个找到的答案不一定是最优的,例如距离最短。"
338 ]
339 },
340 {
341 "cell_type": "markdown",
342 "metadata": {},
343 "source": [
344 "<br>"
244345 ]
245346 },
246347 {
286387 "cell_type": "markdown",
287388 "metadata": {},
288389 "source": [
390 "<br>"
391 ]
392 },
393 {
394 "cell_type": "markdown",
395 "metadata": {},
396 "source": [
289397 "## 2.1.4 启发式搜索"
398 ]
399 },
400 {
401 "cell_type": "markdown",
402 "metadata": {},
403 "source": [
404 "<center><video src=\"http://files.momodel.cn/search_greedy.mp4\" controls=\"controls\" width=800px></center>"
405 ]
406 },
407 {
408 "cell_type": "markdown",
409 "metadata": {},
410 "source": [
411 "<br>"
290412 ]
291413 },
292414 {
327449 "cell_type": "markdown",
328450 "metadata": {},
329451 "source": [
452 "<br>"
453 ]
454 },
455 {
456 "cell_type": "markdown",
457 "metadata": {},
458 "source": [
330459 "但是在“贪婪”机制下找到的路径 A -> C -> D -> G 并非最短路径。产生这样的搜索结果,其原因是:贪婪最佳优先算法在每个节点上,都只从与当前节点相邻的节点中选择**与目标节点直线距离最近的节点**作为后续节点。这样就会造成它**过于重视当前的最优,而忽视了全局最优**。\n"
331460 ]
332461 },
334463 "cell_type": "markdown",
335464 "metadata": {},
336465 "source": [
337 "另一种启发式搜索算法—— A\\* 算法克服了这一不足。\n",
338 "\n",
466 "另一种启发式搜索算法—— A\\* 算法克服了这一不足。"
467 ]
468 },
469 {
470 "cell_type": "markdown",
471 "metadata": {},
472 "source": [
473 "<center><video src=\"http://files.momodel.cn/search_a_star.mp4\" controls=\"controls\" width=800px></center>"
474 ]
475 },
476 {
477 "cell_type": "markdown",
478 "metadata": {},
479 "source": [
480 "<br>"
481 ]
482 },
483 {
484 "cell_type": "markdown",
485 "metadata": {},
486 "source": [
339487 "其算法思路是:将初始节点到目标节点的距离分成两部分,\n",
340488 "- 初始节点到当前节点的路径代价;\n",
341489 "- 当前节点到目标节点之间的直线距离。将两者之和作为评价函数的取值大小。\n",
345493 "+ 函数 h(n): 表示当前节点 n 到目标节点的直线距离。函数 h(n) 也称为**启发函数**。\n",
346494 "\n",
347495 "\n",
348 " A\\* 算法搜索过程:\n"
496 " A\\* 算法搜索过程:"
349497 ]
350498 },
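按照上述评价函数 f(n) = g(n) + h(n),A\* 的核心循环可以草拟如下(示意代码;其中的图 G 与直线距离 h 均为假设的演示数据,并非正文插图中的取值):

```python
import heapq

def a_star(G, h, start, target):
    """G[u] 是 {相邻点: 边长} 字典,h[u] 是 u 到目标点的直线距离估计(启发函数)"""
    # 优先队列中的元素按 f(n) = g(n) + h(n) 从小到大弹出
    pq = [(h[start], 0, start, [start])]
    visited = set()
    while pq:
        f, g, u, path = heapq.heappop(pq)
        if u == target:
            return path, g
        if u in visited:
            continue
        visited.add(u)
        for v, w in G[u].items():
            if v not in visited:
                # g(v) = g(u) + 边长,f(v) = g(v) + h(v)
                heapq.heappush(pq, (g + w + h[v], g + w, v, path + [v]))
    return None, float('inf')

# 假设的演示图与直线距离(并非正文插图中的数据)
G = {'A': {'B': 10, 'C': 8}, 'B': {'D': 12}, 'C': {'D': 20}, 'D': {'G': 16}, 'G': {}}
h = {'A': 30, 'B': 24, 'C': 26, 'D': 14, 'G': 0}
print(a_star(G, h, 'A', 'G'))  # (['A', 'B', 'D', 'G'], 38)
```

与贪婪算法只看 h(n) 不同,这里每一步都同时考虑已经走过的代价 g(n) 和剩余的估计 h(n)。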
351499 {
361509 ]
362510 },
363511 {
512 "cell_type": "markdown",
513 "metadata": {},
514 "source": [
515 "<br>"
516 ]
517 },
518 {
364519 "cell_type": "code",
365520 "execution_count": null,
366521 "metadata": {},
378533 "与贪婪最佳优先算法不同,只要启发函数不会高估节点到目标的实际距离,A\\* 算法找到的路径就一定是最短路径;另一方面,由于 A\\* 算法能够利用辅助信息,它往往能用更少的步骤找到答案。\n",
379534 "\n",
380535 "在实际中,A\\* 算法的性能表现取决于启发函数的设计,只要定义一个合适的启发函数,A\\* 算法就能够大幅缩减搜索所需的时间。"
536 ]
537 },
538 {
539 "cell_type": "markdown",
540 "metadata": {},
541 "source": [
542 "<br>"
381543 ]
382544 },
383545 {
447609 "source": [
448610 "# 查看 dfs 的搜索过程\n",
449611 "h_graph.animation_search_tree('dfs')"
612 ]
613 },
614 {
615 "cell_type": "markdown",
616 "metadata": {},
617 "source": [
618 "<br>"
450619 ]
451620 },
452621 {
1010 "cell_type": "markdown",
1111 "metadata": {},
1212 "source": [
13 "决策树是一种通过**树形结构**进行分类的方法,使用层层推理来实现最终的分类。决策树由下面几种元素构成:\n",
14 "\n",
15 "<img src=\"http://imgbed.momodel.cn//20200110170450.png\" width=500>\n"
16 ]
17 },
18 {
19 "cell_type": "markdown",
20 "metadata": {},
21 "source": [
22 "决策树的组成元素有哪些?"
13 "决策树是一种通过**树形结构**进行分类的方法,使用层层推理来实现最终的分类。\n",
14 "\n",
15 "决策树由下面几种元素构成:\n",
16 "+ 根节点:最顶层的分类条件。\n",
17 "+ 决策节点(中间节点):中间分类条件。\n",
18 "+ 叶子节点:代表标签类别。\n",
19 "\n",
20 "<img src=\"http://imgbed.momodel.cn//20200110170450.png\" width=400>\n"
2321 ]
2422 },
2523 {
3129 "贷款用户主要具备三个属性:**是否拥有房产**,**是否结婚**,**平均月收入**。\n",
3230 "\n",
3331 "每一个内部节点都表示一个属性条件判断,叶子节点表示贷款用户是否具有偿还能力。\n",
34 "<img src=\"http://imgbed.momodel.cn//20200110171836.png\" width=500>\n"
32 "\n",
33 "<img src=\"http://imgbed.momodel.cn//20200110171836.png\" width=400>\n"
3534 ]
3635 },
3736 {
4544 "cell_type": "markdown",
4645 "metadata": {},
4746 "source": [
47 "**想一想**"
48 ]
49 },
50 {
51 "cell_type": "markdown",
52 "metadata": {},
53 "source": [
4854 "决策树的流程是什么?\n",
4955 "\n",
5056 "在有一个贷款用户A,其情况是月收入 3K、已经结婚、没有房产,那么他是否具有偿还贷款的能力呢? \n",
5157 "\n",
52 "上图中我们为啥要用“是否拥有房产”作根节点呢?可不可以用“是否结婚”和“平均月收入”做根节点呢?"
58 "上图中我们为什么要用“是否拥有房产”作根节点呢?可不可以用“是否结婚”和“平均月收入”做根节点呢?"
5359 ]
5460 },
5561 {
5763 "metadata": {},
5864 "source": [
5965 "## 2.2.1 决策树分类概念"
66 ]
67 },
68 {
69 "cell_type": "markdown",
70 "metadata": {},
71 "source": [
72 "<center><video src=\"http://files.momodel.cn/decision_tree_playground_demo.mp4\" controls=\"controls\" width=800px></center>"
6073 ]
6174 },
6275 {
94107 "cell_type": "markdown",
95108 "metadata": {},
96109 "source": [
97 "根节点是天气状况,具有雨、多云和晴三种属性取值。\n",
98 "+ 多云: 样本子集是 { 3, 7, 12, 13 } ,仅有“前往游乐场游玩”一个类别,即肯定去游乐场。 \n",
110 "根节点是天气状况,具有 **雨**、 **多云** 和 **晴** 三种属性取值。\n",
111 "+ **多云**: 样本子集是 { 3, 7, 12, 13 } ,仅有“前往游乐场游玩”一个类别,即肯定去游乐场。 \n",
99112 " \n",
100113 " \n",
101 "+ 晴: 样本子集是 { 1, 2, 8, 9, 11 }\n",
114 "+ **晴**: 样本子集是 { 1, 2, 8, 9, 11 }\n",
102115 " + 湿度大于 75:样本子集为 { 1, 2, 8 },不前往游乐场。\n",
103116 " + 湿度不大于 75:样本子集 { 9, 11 },前往游乐场。\n",
104117 " \n",
105118 " \n",
106 "+ 雨:样本子集为 { 4, 5, 6, 10, 14 }\n",
119 "+ **雨**:样本子集为 { 4, 5, 6, 10, 14 }\n",
107120 " + 有风:样本子集 { 6, 14 },不去游乐场。\n",
108121 " + 无风:样本子集 { 4, 5, 10 },前往游乐场。\n",
109122 " "
113126 "cell_type": "markdown",
114127 "metadata": {},
115128 "source": [
116 "由上面的例子可以看到,构建决策树的过程就是:\n",
117 "1. 选择一个属性值;\n",
118 "2. 基于该属性对样本集进行划分;\n",
119 "3. 重复步骤 1 和 2 直到最后所得划分结果中每个样本为同一类别。"
120 ]
121 },
122 {
123 "cell_type": "markdown",
124 "metadata": {},
125 "source": [
126 "把数据导入 DataFrame 数据结构:"
127 ]
128 },
129 {
130 "cell_type": "code",
131 "execution_count": 1,
129 "**想一想**"
130 ]
131 },
132 {
133 "cell_type": "markdown",
134 "metadata": {},
135 "source": [
136 "通过上面的例子,你观察到构建决策树的过程是哪几步?"
137 ]
138 },
139 {
140 "cell_type": "markdown",
141 "metadata": {},
142 "source": [
143 "下面,我们创建数据并进行一些预处理。"
144 ]
145 },
146 {
147 "cell_type": "code",
148 "execution_count": null,
132149 "metadata": {},
133150 "outputs": [],
134151 "source": [
145162 },
146163 {
147164 "cell_type": "code",
148 "execution_count": 2,
149 "metadata": {},
150 "outputs": [
151 {
152 "data": {
153 "text/html": [
154 "<div>\n",
155 "<style scoped>\n",
156 " .dataframe tbody tr th:only-of-type {\n",
157 " vertical-align: middle;\n",
158 " }\n",
159 "\n",
160 " .dataframe tbody tr th {\n",
161 " vertical-align: top;\n",
162 " }\n",
163 "\n",
164 " .dataframe thead th {\n",
165 " text-align: right;\n",
166 " }\n",
167 "</style>\n",
168 "<table border=\"1\" class=\"dataframe\">\n",
169 " <thead>\n",
170 " <tr style=\"text-align: right;\">\n",
171 " <th></th>\n",
172 " <th>天气</th>\n",
173 " <th>温度</th>\n",
174 " <th>湿度</th>\n",
175 " <th>是否有风</th>\n",
176 " <th>是否前往游乐场</th>\n",
177 " </tr>\n",
178 " </thead>\n",
179 " <tbody>\n",
180 " <tr>\n",
181 " <th>0</th>\n",
182 " <td>晴</td>\n",
183 " <td>&gt;26</td>\n",
184 " <td>&gt;75</td>\n",
185 " <td>否</td>\n",
186 " <td>0</td>\n",
187 " </tr>\n",
188 " <tr>\n",
189 " <th>1</th>\n",
190 " <td>晴</td>\n",
191 " <td>&lt;=26</td>\n",
192 " <td>&gt;75</td>\n",
193 " <td>是</td>\n",
194 " <td>0</td>\n",
195 " </tr>\n",
196 " <tr>\n",
197 " <th>2</th>\n",
198 " <td>多云</td>\n",
199 " <td>&gt;26</td>\n",
200 " <td>&gt;75</td>\n",
201 " <td>否</td>\n",
202 " <td>1</td>\n",
203 " </tr>\n",
204 " <tr>\n",
205 " <th>3</th>\n",
206 " <td>雨</td>\n",
207 " <td>&lt;=26</td>\n",
208 " <td>&gt;75</td>\n",
209 " <td>否</td>\n",
210 " <td>1</td>\n",
211 " </tr>\n",
212 " <tr>\n",
213 " <th>4</th>\n",
214 " <td>雨</td>\n",
215 " <td>&lt;=26</td>\n",
216 " <td>&gt;75</td>\n",
217 " <td>否</td>\n",
218 " <td>1</td>\n",
219 " </tr>\n",
220 " <tr>\n",
221 " <th>5</th>\n",
222 " <td>雨</td>\n",
223 " <td>&lt;=26</td>\n",
224 " <td>&lt;=75</td>\n",
225 " <td>是</td>\n",
226 " <td>0</td>\n",
227 " </tr>\n",
228 " <tr>\n",
229 " <th>6</th>\n",
230 " <td>多云</td>\n",
231 " <td>&lt;=26</td>\n",
232 " <td>&lt;=75</td>\n",
233 " <td>是</td>\n",
234 " <td>1</td>\n",
235 " </tr>\n",
236 " <tr>\n",
237 " <th>7</th>\n",
238 " <td>晴</td>\n",
239 " <td>&lt;=26</td>\n",
240 " <td>&gt;75</td>\n",
241 " <td>否</td>\n",
242 " <td>0</td>\n",
243 " </tr>\n",
244 " <tr>\n",
245 " <th>8</th>\n",
246 " <td>晴</td>\n",
247 " <td>&lt;=26</td>\n",
248 " <td>&lt;=75</td>\n",
249 " <td>否</td>\n",
250 " <td>1</td>\n",
251 " </tr>\n",
252 " <tr>\n",
253 " <th>9</th>\n",
254 " <td>雨</td>\n",
255 " <td>&lt;=26</td>\n",
256 " <td>&gt;75</td>\n",
257 " <td>否</td>\n",
258 " <td>1</td>\n",
259 " </tr>\n",
260 " <tr>\n",
261 " <th>10</th>\n",
262 " <td>晴</td>\n",
263 " <td>&lt;=26</td>\n",
264 " <td>&lt;=75</td>\n",
265 " <td>是</td>\n",
266 " <td>1</td>\n",
267 " </tr>\n",
268 " <tr>\n",
269 " <th>11</th>\n",
270 " <td>多云</td>\n",
271 " <td>&lt;=26</td>\n",
272 " <td>&gt;75</td>\n",
273 " <td>是</td>\n",
274 " <td>1</td>\n",
275 " </tr>\n",
276 " <tr>\n",
277 " <th>12</th>\n",
278 " <td>多云</td>\n",
279 " <td>&gt;26</td>\n",
280 " <td>&lt;=75</td>\n",
281 " <td>否</td>\n",
282 " <td>1</td>\n",
283 " </tr>\n",
284 " <tr>\n",
285 " <th>13</th>\n",
286 " <td>雨</td>\n",
287 " <td>&lt;=26</td>\n",
288 " <td>&gt;75</td>\n",
289 " <td>是</td>\n",
290 " <td>0</td>\n",
291 " </tr>\n",
292 " </tbody>\n",
293 "</table>\n",
294 "</div>"
295 ],
296 "text/plain": [
297 " 天气 温度 湿度 是否有风 是否前往游乐场\n",
298 "0 晴 >26 >75 否 0\n",
299 "1 晴 <=26 >75 是 0\n",
300 "2 多云 >26 >75 否 1\n",
301 "3 雨 <=26 >75 否 1\n",
302 "4 雨 <=26 >75 否 1\n",
303 "5 雨 <=26 <=75 是 0\n",
304 "6 多云 <=26 <=75 是 1\n",
305 "7 晴 <=26 >75 否 0\n",
306 "8 晴 <=26 <=75 否 1\n",
307 "9 雨 <=26 >75 否 1\n",
308 "10 晴 <=26 <=75 是 1\n",
309 "11 多云 <=26 >75 是 1\n",
310 "12 多云 >26 <=75 否 1\n",
311 "13 雨 <=26 >75 是 0"
312 ]
313 },
314 "execution_count": 2,
315 "metadata": {},
316 "output_type": "execute_result"
317 }
318 ],
165 "execution_count": null,
166 "metadata": {},
167 "outputs": [],
319168 "source": [
320169 "# 原始数据\n",
321170 "datasets = [\n",
322 " ['晴',29,85,'否','0'],\n",
323 " ['晴',26,88,'是','0'],\n",
324 " ['多云',28,78,'否','1'],\n",
325 " ['雨',21,96,'否','1'],\n",
326 " ['雨',20,80,'否','1'],\n",
327 " ['雨',18,70,'是','0'],\n",
328 " ['多云',18,65,'是','1'],\n",
329 " ['晴',22,90,'否','0'],\n",
330 " ['晴',21,68,'否','1'],\n",
331 " ['雨',24,80,'否','1'],\n",
332 " ['晴',24,63,'是','1'],\n",
333 " ['多云',22,90,'是','1'],\n",
334 " ['多云',27,75,'否','1'],\n",
335 " ['雨',21,80,'是','0']\n",
171 " ['晴', 29, 85, '否', '0'],\n",
172 " ['晴', 26, 88, '是', '0'],\n",
173 " ['多云', 28, 78, '否', '1'],\n",
174 " ['雨', 21, 96, '否', '1'],\n",
175 " ['雨', 20, 80, '否', '1'],\n",
176 " ['雨', 18, 70, '是', '0'],\n",
177 " ['多云', 18, 65, '是', '1'],\n",
178 " ['晴', 22, 90, '否', '0'],\n",
179 " ['晴', 21, 68, '否', '1'],\n",
180 " ['雨', 24, 80, '否', '1'],\n",
181 " ['晴', 24, 63, '是', '1'],\n",
182 " ['多云', 22, 90, '是', '1'],\n",
183 " ['多云', 27, 75, '否', '1'],\n",
184 " ['雨', 21, 80, '是', '0']\n",
336185 "]\n",
337 "\n",
338186 "# 数据的列名\n",
339 "labels = ['天气','温度','湿度','是否有风','是否前往游乐场']\n",
340 "\n",
187 "labels = ['天气', '温度', '湿度', '是否有风', '是否前往游乐场']\n",
341188 "# 将湿度大小分为大于 75 和小于等于 75 这两个属性值,\n",
342189 "# 将温度大小分为大于 26 和小于等于 26 这两个属性值\n",
343190 "for i in range(len(datasets)):\n",
349196 " datasets[i][1] = '>26'\n",
350197 " else:\n",
351198 " datasets[i][1] = '<=26'\n",
352 "\n",
353199 "# 构建 dataframe 并查看数据\n",
354200 "df = pd.DataFrame(datasets, columns=labels)\n",
355 "df\n"
356 ]
357 },
358 {
359 "cell_type": "markdown",
360 "metadata": {},
361 "source": [
362 "## 2.2.2 构建决策树 \n",
363 "\n",
364 "**信息增益**是什么?\n",
365 "\n",
366 "**信息熵**是什么?"
367 ]
368 },
369 {
370 "cell_type": "markdown",
371 "metadata": {},
372 "source": [
373 "假设有 $K$ 个信息,其组成了集合样本 $D$ ,记第 $k$ 个信息发生的概率为 $p_k(1≤k≤K)$。 \n",
374 "这 $K$ 个信息的信息熵: \n",
375 "$$E(D)=-\\sum_{k=1}^{K}p_k log_{2} p_k$$\n",
376 "\n",
377 "需要指出:**所有 $p_k$ 累加起来的和为1**。"
201 "df"
202 ]
203 },
204 {
205 "cell_type": "markdown",
206 "metadata": {},
207 "source": [
208 "## 2.2.2 构建决策树 "
209 ]
210 },
211 {
212 "cell_type": "markdown",
213 "metadata": {},
214 "source": [
215 "<center><video src=\"http://files.momodel.cn/decision_tree_entropy.mp4\" controls=\"controls\" width=800px></center>"
216 ]
217 },
218 {
219 "cell_type": "markdown",
220 "metadata": {},
221 "source": [
222 "**想一想**"
223 ]
224 },
225 {
226 "cell_type": "markdown",
227 "metadata": {},
228 "source": [
229 "信息熵和信息增益分别是什么?"
230 ]
231 },
232 {
233 "cell_type": "markdown",
234 "metadata": {},
235 "source": [
236 "假设有 $K$ 个信息,其组成了集合样本 $D$ ,记第 $k$ 个信息发生的概率为 $p_k(1≤k≤K)$。 \n",
237 "\n",
238 "这 $K$ 个信息的信息熵该如何计算? \n",
239 "\n",
240 "所有 $p_k$ 累加起来的和是多少?"
241 ]
242 },
243 {
244 "cell_type": "markdown",
245 "metadata": {},
246 "source": [
247 "**动手练**"
248 ]
249 },
250 {
251 "cell_type": "markdown",
252 "metadata": {},
253 "source": [
254 "根据公式编写代码计算信息熵"
378255 ]
379256 },
380257 {
390267 " :param count_dict: 每类样本及其对应数目的字典\n",
391268 " :return: 信息熵\n",
392269 " \"\"\"\n",
393 " # 使用信息熵公式计算\n",
394 " ent = -sum([(p / total_num) * log(p / total_num, 2) for p in count_dict.values() if p != 0])\n",
395 " # 避免 print 显示异常\n",
396 " if ent == 0:\n",
397 " ent = 0\n",
270 " # todo 使用公式计算信息熵\n",
271 " ent = \n",
398272 " # 返回信息熵,精确到小数点后 4 位\n",
399273 " return round(ent, 4)\n"
400274 ]
407281 "\n",
408282 "现在用**信息熵**来构建决策树。数据中 14 个样本分为 “游客来游乐场 (9 个样本)” 和 “游客不来游乐场( 5 个样本)” 两个类别,即 K = 2。\n",
409283 "\n",
410 "记 “游客来游乐场” 和 “游客不来游乐场” 的概率分别为 $p_1$ 和 $p_2$ ,显然 $p_1=\\frac{9}{14}$,$p_2=\\frac{5}{14}$,则这 14 个样本所蕴含的信息熵:\n",
284 "记 “游客来游乐场” 和 “游客不来游乐场” 的概率分别为 $p_1$ 和 $p_2$ ,显然 $p_1=\\frac{9}{14}$,$p_2=\\frac{5}{14}$,则这 14 个样本所蕴含的信息熵:\n",
411285 "\n",
412286 "$$E(D)=-\\sum_{k=1}^{2}p_{k}log_{2}{p_k}=-(\\frac{9}{14}×log_{2}{\\frac{9}{14}}+\\frac{5}{14}×log_{2}{\\frac{5}{14}})=0.940$$"
413287 ]
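上面的数值可以用几行 Python 验证(示意实现,函数名 entropy 为演示自拟,与正文中的 calc_entropy 无关):

```python
from math import log2

def entropy(counts):
    """信息熵 E(D) = -Σ p_k·log2(p_k),counts 为各类别样本数"""
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c > 0)

# 14 个样本:9 个“前往”、5 个“不前往”
print(f"{entropy([9, 5]):.3f}")  # 0.940
```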
445319 "# 总样本数\n",
446320 "total_num = df.shape[0]\n",
447321 "# 每类样本及其对应数目的字典\n",
448 "count_dict = {'前往':df[df['是否前往游乐场']=='1'].shape[0], '不前往':df[df['是否前往游乐场']=='0'].shape[0]}\n",
322 "count_dict = {'前往': df[df['是否前往游乐场']=='1'].shape[0], '不前往': df[df['是否前往游乐场']=='0'].shape[0]}\n",
449323 "# 计算信息熵\n",
450324 "entropy = calc_entropy(total_num, count_dict)\n",
451325 "entropy\n"
455329 "cell_type": "markdown",
456330 "metadata": {},
457331 "source": [
332 "<center><video src=\"http://files.momodel.cn/decision_tree_build.mp4\" controls=\"controls\" width=800px></center>"
333 ]
334 },
335 {
336 "cell_type": "markdown",
337 "metadata": {},
338 "source": [
458339 "**计算天气状况所对应的信息熵**: \n",
459340 "天气状况的三个属性记为 $a_0=“晴”$ ,$a_1=“多云”$ ,$a_2=“雨”$ , \n",
460 "属性取值为 $a_i$ 对应分支节点所包含子样本集记为 $D_i$ ,该子样本集包含样本数量记为 $|D_i|$ 。"
341 "属性取值为 $a_i$ 对应分支节点所包含子样本集记为 $D_i$ ,该子样本集包含样本数量记为 $|D_i|$ 。"
461342 ]
462343 },
463344 {
482363 "cell_type": "markdown",
483364 "metadata": {},
484365 "source": [
485 "我们可以使用下面的写法,对 dataframe 进行多个条件的筛选。"
366 "现在,我们来编写代码完成上面的计算。\n",
367 "\n",
368 "首先,我们可以使用下面的写法,对 DataFrame 进行多个条件的筛选。"
486369 ]
487370 },
488371 {
575458 "cell_type": "markdown",
576459 "metadata": {},
577460 "source": [
461 "**动手练**"
462 ]
463 },
464 {
465 "cell_type": "markdown",
466 "metadata": {},
467 "source": [
578468 "使用上面的公式计算信息增益。"
579469 ]
580470 },
584474 "metadata": {},
585475 "outputs": [],
586476 "source": [
587 "# 信息增益\n",
588 "gain = entropy - (total_num_sun/total_num*ent_sun +\n",
589 " total_num_cloud/total_num*ent_cloud +\n",
590 " total_num_rain/total_num*ent_rain)\n",
477 "# todo 计算按天气状况分割的信息增益\n",
478 "gain = \n",
591479 "gain\n"
592480 ]
593481 },
595483 "cell_type": "markdown",
596484 "metadata": {},
597485 "source": [
486 "### 扩展内容\n",
487 "\n",
488 "**基尼指数**\n",
489 "\n",
490 "除了使用信息增益以外,我们也可以使用基尼指数来构建决策树。\n",
491 "\n",
492 "分类问题中,假设有 $K$ 个类,样本点属于第 $k$ 类的概率为 $p_{k}$,则概率分布的基尼指数定义为:\n",
493 "\n",
494 "$$\\operatorname{Gini}(p)=\\sum_{k=1}^{K} p_{k}\\left(1-p_{k}\\right)=1-\\sum_{k=1}^{K} p_{k}^{2}$$\n",
495 "\n",
496 "\n",
497 "对于给定的样本集合 $D$,其基尼指数为\n",
498 "\n",
499 "$$\\operatorname{Gini}(D)=1-\\sum_{k=1}^{K}\\left(\\frac{\\left|C_{k}\\right|}{|D|}\\right)^{2}$$\n",
500 "\n",
501 "这里,$C_{k}$ 是 $D$ 中属于第 $k$ 类的样本子集,$K$ 是类的个数。\n",
502 "\n",
503 "如果样本集合 $D$ 根据特征 $A$ 是否取某一可能值 $a$ 被分割为 $D_{1}$ 和 $D_{2}$ 两部分,即\n",
504 "\n",
505 "$$D_{1}=\\{(x, y) \\in D | A(x)=a\\}, \\quad D_{2}=D-D_{1}$$\n",
506 "\n",
507 "则在特征 $A$ 的条件下,集合 $D$ 的基尼指数定义为\n",
508 "\n",
509 "$$\\operatorname{Gini}(D, A)=\\frac{\\left|D_{1}\\right|}{|D|}\n",
510 "\\operatorname{Gini}\\left(D_{1}\\right)+\\frac{\\left|D_{2}\\right|}{|D|} \\operatorname{Gini}\\left(D_{2}\\right)$$\n",
511 "\n",
512 "基尼指数 $Gini(D)$ 表示集合 $D$ 的不确定性,基尼指数 $Gini(D, A)$ 表示经过分割后集合 $D$ 的不确定性。基尼指数值越大,样本集合的不确定性也就越大,这一点与信息熵相似。"
513 ]
514 },
515 {
516 "cell_type": "markdown",
517 "metadata": {},
518 "source": [
519 "**想一想**"
520 ]
521 },
522 {
523 "cell_type": "markdown",
524 "metadata": {},
525 "source": [
526 "对于二分类问题,若样本点属于第 1 个类的概率是 $p$,则概率分布的基尼指数是多少?"
527 ]
528 },
529 {
530 "cell_type": "markdown",
531 "metadata": {},
532 "source": [
598533 "### 思考与练习 "
599534 ]
600535 },
602537 "cell_type": "markdown",
603538 "metadata": {},
604539 "source": [
605 "1. 分别将天气状况、温度高低、湿度大小、风力强弱作为分支点来构建决策树,查看信息增益\n"
606 ]
607 },
608 {
609 "cell_type": "code",
610 "execution_count": 4,
540 "1. 分别将天气状况、温度高低、湿度大小、风力强弱作为分支点来构建决策树,查看信息增益。\n"
541 ]
542 },
543 {
544 "cell_type": "code",
545 "execution_count": null,
611546 "metadata": {},
612547 "outputs": [],
613548 "source": [
616551 },
617552 {
618553 "cell_type": "code",
619 "execution_count": 5,
554 "execution_count": null,
620555 "metadata": {},
621556 "outputs": [],
622557 "source": [
625560 },
626561 {
627562 "cell_type": "code",
628 "execution_count": 6,
563 "execution_count": null,
629564 "metadata": {},
630565 "outputs": [],
631566 "source": [
662597 "cell_type": "markdown",
663598 "metadata": {},
664599 "source": [
665 "实际上是否如此呢?"
600 "实际上是否如此呢?你能否想到其他的切分方式?"
666601 ]
667602 },
668603 {
734669 "from sklearn import tree\n",
735670 "from sklearn.tree import DecisionTreeClassifier\n",
736671 "# 初始化模型,可以调整 max_depth 来观察模型的表现\n",
737 "clf = tree.DecisionTreeClassifier(random_state=42, max_depth=2)\n",
672 "# 也可以调整 criterion 为 gini 来使用 gini 指数构建决策树\n",
673 "clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=2)\n",
738674 "# 训练模型\n",
739675 "clf = clf.fit(X_train, y_train)\n"
740676 ]
769705 "cell_type": "markdown",
770706 "metadata": {},
771707 "source": [
772 "我们看模型在测试集上的表现"
708 "我们看模型在测试集上的表现。"
773709 ]
774710 },
775711 {
849785 },
850786 {
851787 "cell_type": "code",
852 "execution_count": null,
788 "execution_count": null,
853789 "metadata": {},
854790 "outputs": [],
855791 "source": [
1010 "cell_type": "markdown",
1111 "metadata": {},
1212 "source": [
13 "决策树是一种通过**树形结构**进行分类的方法,使用层层推理来实现最终的分类。决策树由下面几种元素构成:\n",
14 "\n",
13 "决策树是一种通过**树形结构**进行分类的方法,使用层层推理来实现最终的分类。\n",
14 "\n",
15 "决策树由下面几种元素构成:\n",
1516 "+ 根节点:最顶层的分类条件。\n",
1617 "+ 决策节点(中间节点):中间分类条件。\n",
1718 "+ 叶子节点:代表标签类别。\n",
1819 "\n",
19 "\n",
20 "<img src=\"http://imgbed.momodel.cn//20200110170450.png\" width=500>\n",
21 "\n",
22 "\n",
23 "预测时,在树的内部节点处用某一属性值进行判断,根据判断结果决定进入哪个分支节点,直到到达叶节点处,得到分类结果。\n"
24 ]
25 },
26 {
27 "cell_type": "markdown",
28 "metadata": {},
29 "source": [
30 "上面的说法过于抽象,下面来看一个实际的例子。构建一棵结构简单的决策树,用于预测贷款用户是否具有偿还贷款的能力。\n",
20 "<img src=\"http://imgbed.momodel.cn//20200110170450.png\" width=400>\n",
21 "\n",
22 "预测时,在树的内部节点处用某一属性值进行判断,根据判断结果决定进入哪个分支节点,\n",
23 "\n",
24 "直到到达叶节点处,得到分类结果。\n"
25 ]
26 },
27 {
28 "cell_type": "markdown",
29 "metadata": {},
30 "source": [
31 "<br>"
32 ]
33 },
34 {
35 "cell_type": "markdown",
36 "metadata": {},
37 "source": [
38 "上面的说法过于抽象,下面来看一个实际的例子。\n",
39 "\n",
40 "构建一棵结构简单的决策树,用于预测贷款用户是否具有偿还贷款的能力。\n",
3141 "\n",
3242 "贷款用户主要具备三个属性:**是否拥有房产**,**是否结婚**,**平均月收入**。\n",
3343 "\n",
3444 "每一个内部节点都表示一个属性条件判断,叶子节点表示贷款用户是否具有偿还能力。\n",
35 "<img src=\"http://imgbed.momodel.cn//20200110171836.png\" width=500>\n",
45 "\n",
46 "<img src=\"http://imgbed.momodel.cn//20200110171836.png\" width=400>\n",
3647 "\n"
3748 ]
3849 },
4051 "cell_type": "markdown",
4152 "metadata": {},
4253 "source": [
43 "首先判断贷款用户是否拥有房产,如果用户拥有房产,则说明该用户具有偿还贷款的能力;否则需要判断该用户是否结婚,如果已经结婚则具有偿还贷款的能力;否则需要判断该用户的收入大小,如果该用户月收入小于 4K 元,则该用户不具有偿还贷款的能力,否则该用户是具有偿还能力的。\n",
44 "\n",
45 "现在有一个贷款用户A,其情况是月收入 3K、已经结婚、没有房产,那么他是否具有偿还贷款的能力呢?很显然,他是具有偿还贷款能力的。\n",
46 "\n",
47 "那么,上图中我们为啥要用“是否拥有房产”作根节点呢?可不可以用“是否结婚”和“平均月收入”做根节点呢?学完本章即可知道答案。"
54 "首先判断贷款用户是否拥有房产,\n",
55 "如果用户拥有房产,则说明该用户具有偿还贷款的能力;\n",
56 "否则需要判断该用户是否结婚,\n",
57 "如果已经结婚则具有偿还贷款的能力;\n",
58 "否则需要判断该用户的收入大小,\n",
59 "如果该用户月收入小于 4K 元,\n",
60 "则该用户不具有偿还贷款的能力,\n",
61 "否则该用户是具有偿还能力的。\n",
62 "\n",
63 "现在有一个贷款用户A,其情况是月收入 3K、已经结婚、没有房产,那么他是否具有偿还贷款的能力呢?\n",
64 "\n",
65 "很显然,他是具有偿还贷款能力的。\n",
66 "\n",
67 "那么,上图中我们为什么要用“是否拥有房产”作根节点呢?可不可以用“是否结婚”和“平均月收入”做根节点呢?\n",
68 "\n",
69 "学完本章即可知道答案。"
70 ]
71 },
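上面这棵贷款决策树的预测流程,可以直接写成一串嵌套的条件判断(示意代码,函数名与参数均为演示自拟):

```python
def can_repay(has_house, married, monthly_income_k):
    """沿上图决策树逐层判断:房产 -> 婚姻 -> 月收入(单位:千元)"""
    if has_house:                 # 根节点:是否拥有房产
        return True
    if married:                   # 决策节点:是否结婚
        return True
    return monthly_income_k >= 4  # 叶子节点:月收入是否不低于 4K

# 用户 A:月收入 3K、已婚、无房产
print(can_repay(False, True, 3))  # True
```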
72 {
73 "cell_type": "markdown",
74 "metadata": {},
75 "source": [
76 "<br>"
4877 ]
4978 },
5079 {
5281 "metadata": {},
5382 "source": [
5483 "## 2.2.1 决策树分类概念"
84 ]
85 },
86 {
87 "cell_type": "markdown",
88 "metadata": {},
89 "source": [
90 "<center><video src=\"http://files.momodel.cn/decision_tree_playground_demo.mp4\" controls=\"controls\" width=800px></center>"
91 ]
92 },
93 {
94 "cell_type": "markdown",
95 "metadata": {},
96 "source": [
97 "<br>"
5598 ]
5699 },
57100 {
89132 "cell_type": "markdown",
90133 "metadata": {},
91134 "source": [
92 "根节点是天气状况,具有雨、多云和晴三种属性取值。\n",
93 "+ 多云: 样本子集是 { 3, 7, 12, 13 } ,仅有“前往游乐场游玩”一个类别,即肯定去游乐场。 \n",
135 "<br>"
136 ]
137 },
138 {
139 "cell_type": "markdown",
140 "metadata": {},
141 "source": [
142 "根节点是天气状况,具有 **雨**、**多云** 和 **晴** 三种属性取值。\n",
143 "+ **多云**: 样本子集是 { 3, 7, 12, 13 } ,仅有“前往游乐场游玩”一个类别,即肯定去游乐场。 \n",
94144 " \n",
95145 " \n",
96 "+ 晴: 样本子集是 { 1, 2, 8, 9, 11 }\n",
146 "+ **晴**: 样本子集是 { 1, 2, 8, 9, 11 }\n",
97147 " + 湿度大于 75:样本子集为 { 1, 2, 8 },不前往游乐场。\n",
98148 " + 湿度不大于 75:样本子集 { 9, 11 },前往游乐场。\n",
99149 " \n",
100150 " \n",
101 "+ 雨:样本子集为 { 4, 5, 6, 10, 14 }\n",
151 "+ **雨**:样本子集为 { 4, 5, 6, 10, 14 }\n",
102152 " + 有风:样本子集 { 6, 14 },不去游乐场。\n",
103153 " + 无风:样本子集 { 4, 5, 10 },前往游乐场。\n",
104154 " "
118168 "cell_type": "markdown",
119169 "metadata": {},
120170 "source": [
121 "首先我们读取数据"
171 "<br>"
172 ]
173 },
174 {
175 "cell_type": "markdown",
176 "metadata": {},
177 "source": [
178 "下面,我们创建数据并进行一些预处理。"
122179 ]
123180 },
124181 {
131188 "import pandas as pd\n",
132189 "import matplotlib.pyplot as plt\n",
133190 "%matplotlib inline\n",
134 "\n",
135191 "import math\n",
136192 "from math import log\n",
137193 "import warnings\n",
138 "warnings.filterwarnings(\"ignore\")\n"
194 "warnings.filterwarnings(\"ignore\")"
139195 ]
140196 },
141197 {
146202 "source": [
147203 "# 原始数据\n",
148204 "datasets = [\n",
149 " ['晴',29,85,'否','0'],\n",
150 " ['晴',26,88,'是','0'],\n",
151 " ['多云',28,78,'否','1'],\n",
152 " ['雨',21,96,'否','1'],\n",
153 " ['雨',20,80,'否','1'],\n",
154 " ['雨',18,70,'是','0'],\n",
155 " ['多云',18,65,'是','1'],\n",
156 " ['晴',22,90,'否','0'],\n",
157 " ['晴',21,68,'否','1'],\n",
158 " ['雨',24,80,'否','1'],\n",
159 " ['晴',24,63,'是','1'],\n",
160 " ['多云',22,90,'是','1'],\n",
161 " ['多云',27,75,'否','1'],\n",
162 " ['雨',21,80,'是','0']\n",
205 " ['晴', 29, 85, '否', '0'],\n",
206 " ['晴', 26, 88, '是', '0'],\n",
207 " ['多云', 28, 78, '否', '1'],\n",
208 " ['雨', 21, 96, '否', '1'],\n",
209 " ['雨', 20, 80, '否', '1'],\n",
210 " ['雨', 18, 70, '是', '0'],\n",
211 " ['多云', 18, 65, '是', '1'],\n",
212 " ['晴', 22, 90, '否', '0'],\n",
213 " ['晴', 21, 68, '否', '1'],\n",
214 " ['雨', 24, 80, '否', '1'],\n",
215 " ['晴', 24, 63, '是', '1'],\n",
216 " ['多云', 22, 90, '是', '1'],\n",
217 " ['多云', 27, 75, '否', '1'],\n",
218 " ['雨', 21, 80, '是', '0']\n",
163219 "]\n",
164 "\n",
165220 "# 数据的列名\n",
166 "labels = ['天气','温度','湿度','是否有风','是否前往游乐场']\n",
167 "\n",
221 "labels = ['天气', '温度', '湿度', '是否有风', '是否前往游乐场']\n",
168222 "# 将湿度大小分为大于 75 和小于等于 75 这两个属性值,\n",
169223 "# 将温度大小分为大于 26 和小于等于 26 这两个属性值\n",
170224 "for i in range(len(datasets)):\n",
176230 " datasets[i][1] = '>26'\n",
177231 " else:\n",
178232 " datasets[i][1] = '<=26'\n",
179 "\n",
180233 "# 构建 dataframe 并查看数据\n",
181234 "df = pd.DataFrame(datasets, columns=labels)\n",
182 "df\n"
183 ]
184 },
185 {
186 "cell_type": "markdown",
187 "metadata": {},
188 "source": [
189 "## 2.2.2 构建决策树 \n",
190 "\n",
235 "df"
236 ]
237 },
238 {
239 "cell_type": "markdown",
240 "metadata": {},
241 "source": [
242 "<br>"
243 ]
244 },
245 {
246 "cell_type": "markdown",
247 "metadata": {},
248 "source": [
249 "## 2.2.2 构建决策树 "
250 ]
251 },
252 {
253 "cell_type": "markdown",
254 "metadata": {},
255 "source": [
256 "<center><video src=\"http://files.momodel.cn/decision_tree_entropy.mp4\" controls=\"controls\" width=800px></center>"
257 ]
258 },
259 {
260 "cell_type": "markdown",
261 "metadata": {},
262 "source": [
191263 "**信息增益**用来衡量样本集合复杂度(不确定性)所减少的程度。 \n",
192264 "\n",
193265 "**信息熵**用来度量信息量的大小。从信息论的角度来看,对信息的度量等于计算信息不确定性的多少。 "
197269 "cell_type": "markdown",
198270 "metadata": {},
199271 "source": [
200 "假设有 $K$ 个信息,其组成了集合样本 $D$ ,记第 $k$ 个信息发生的概率为$P_k(1≤k≤K)$。 \n",
272 "假设有 $K$ 个信息,其组成了集合样本 $D$ ,记第 $k$ 个信息发生的概率为 $p_k(1≤k≤K)$。 \n",
201273 "这 $K$ 个信息的信息熵: \n",
202274 "$$E(D)=-\\sum_{k=1}^{K}p_k log_{2} p_k$$\n",
203275 "\n",
204276 "需要指出:**所有 $p_k$ 累加起来的和为1**。"
277 ]
278 },
279 {
280 "cell_type": "markdown",
281 "metadata": {},
282 "source": [
283 "<br>"
205284 ]
206285 },
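按照上面的公式,可以先用几行代码直接验证信息熵的计算。下面是一个演示性写法(与后文的 `calc_entropy` 相互独立,仅用于核对公式):

```python
from math import log2

def entropy(counts):
    """按公式 E(D) = -sum(p_k * log2(p_k)) 计算信息熵,counts 为各类样本数"""
    total = sum(counts)
    return -sum((c / total) * log2(c / total) for c in counts if c > 0)

# 14 个样本:9 个“前往”、5 个“不前往”
print(round(entropy([9, 5]), 3))  # 0.94
```

只有一个类别时各 $p_k$ 中非零项为 1,熵为 0,与“完全确定、不含信息量”的直觉一致。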
207286 {
220299 " # 使用信息熵公式计算\n",
221300 " ent = -sum([(p / total_num) * log(p / total_num, 2) for p in count_dict.values() if p != 0])\n",
222301 " # 避免 print 显示异常\n",
223 " if ent == 0:\n",
302 " if not ent:\n",
224303 " ent = 0\n",
225 " # 返回信息熵,精确到小数点后 4 位\n",
226 " return round(ent, 4)\n"
304 " # 返回信息熵,精确到小数点后 3 位\n",
305 " return round(ent, 3)\n"
227306 ]
228307 },
229308 {
234313 "\n",
235314 "现在用**熵**来构建决策树。数据中 14 个样本分为 “游客来游乐场( 9 个样本)” 和 “游客不来游乐场( 5 个样本)” 两个类别,即 K = 2。\n",
236315 "\n",
237 "记 “游客来游乐场” 和 “游客不来游乐场” 的概率分别为 $p_1$ 和 $p_2$ ,显然 $p_1=\\frac{9}{14}$,$p_1=\\frac{5}{14}$ ,则这 14 个样本所蕴含的信息熵:\n",
316 "记 “游客来游乐场” 和 “游客不来游乐场” 的概率分别为 $p_1$ 和 $p_2$ ,显然 $p_1=\\frac{9}{14}$,$p_2=\\frac{5}{14}$ ,则这 14 个样本所蕴含的信息熵:\n",
238317 "\n",
239318 "$$E(D)=-\\sum_{k=1}^{2}p_{k}log_{2}{p_k}=-(\\frac{9}{14}×log_{2}{\\frac{9}{14}}+\\frac{5}{14}×log_{2}{\\frac{5}{14}})=0.940$$"
240319 ]
273352 "total_num = df.shape[0]\n",
274353 "\n",
275354 "# 每类样本及其对应数目的字典\n",
276 "count_dict = {'前往':df[df['是否前往游乐场']=='1'].shape[0], '不前往':df[df['是否前往游乐场']=='1'].shape[1]}\n",
355 "count_dict = {'前往': df[df['是否前往游乐场']=='1'].shape[0], '不前往': df[df['是否前往游乐场']=='0'].shape[0]}\n",
277356 "\n",
278357 "# 计算信息熵\n",
279358 "entropy = calc_entropy(total_num, count_dict)\n",
284363 "cell_type": "markdown",
285364 "metadata": {},
286365 "source": [
366 "<br>"
367 ]
368 },
369 {
370 "cell_type": "markdown",
371 "metadata": {},
372 "source": [
373 "<center><video src=\"http://files.momodel.cn/decision_tree_build.mp4\" controls=\"controls\" width=800px></center>"
374 ]
375 },
376 {
377 "cell_type": "markdown",
378 "metadata": {},
379 "source": [
380 "<br>"
381 ]
382 },
383 {
384 "cell_type": "markdown",
385 "metadata": {},
386 "source": [
287387 "**计算天气状况所对应的信息熵**: \n",
288388 "天气状况的三个属性记为 $a_0=“晴”$ ,$a_1=“多云”$ ,$a_2=“雨”$ , \n",
289 "属性取值为 $a_i$ 对应分支节点所包含子样本集记为 $D_i$ ,该子样本集包含样本数量记为 $|D_i|$ 。"
389 "属性取值为 $a_i$ 对应分支节点所包含子样本集记为 $D_i$ ,该子样本集包含样本数量记为 $|D_i|$ 。"
290390 ]
291391 },
292392 {
311411 "cell_type": "markdown",
312412 "metadata": {},
313413 "source": [
314 "我们可以使用下面的写法,对 dataframe 进行多个条件的筛选。"
414 "现在,我们来编写代码完成上面的计算。\n",
415 "\n",
416 "首先,我们可以使用下面的写法,对 DataFrame 进行多个条件的筛选。"
315417 ]
316418 },
317419 {
322424 "source": [
323425 "# 筛选出 天气为晴并且去游乐场的样本数据\n",
324426 "df[(df['天气']=='晴') & (df['是否前往游乐场']=='1')]\n"
427 ]
428 },
429 {
430 "cell_type": "markdown",
431 "metadata": {},
432 "source": [
433 "<br>"
434 ]
435 },
436 {
437 "cell_type": "markdown",
438 "metadata": {},
439 "source": [
440 "然后,我们使用上面的筛选方法,分别计算不同天气下的信息熵。"
325441 ]
326442 },
327443 {
385501 "cell_type": "markdown",
386502 "metadata": {},
387503 "source": [
388 "计算天气状况的信息增益: \n",
504 "<br>"
505 ]
506 },
507 {
508 "cell_type": "markdown",
509 "metadata": {},
510 "source": [
511 "得到不同天气下的信息熵后,我们可以计算天气状况的信息增益: \n",
389512 "$$Gain(D,A)=E(D)-\\sum_{i=1}^{n}\\frac{|D_i|}{|D|}E(D_i)$$"
390513 ]
391514 },
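信息增益公式可以结合前面的样本直接演算。下面的演示代码不依赖前面的 `df`,而是按前文表格统计好各天气取值下(前往, 不前往)的样本数,计算“天气状况”的信息增益:

```python
from math import log2

def entropy(counts):
    """E(D) = -sum(p_k * log2(p_k))"""
    total = sum(counts)
    return -sum(c / total * log2(c / total) for c in counts if c > 0)

# 各天气取值下(前往, 不前往)的样本数:晴 (2, 3)、多云 (4, 0)、雨 (3, 2)
branches = [(2, 3), (4, 0), (3, 2)]
total = sum(sum(b) for b in branches)                      # 14
base = entropy([9, 5])                                     # E(D) ≈ 0.940
cond = sum(sum(b) / total * entropy(b) for b in branches)  # 加权条件熵
gain = base - cond
print(round(gain, 3))  # 0.247
```

信息增益越大,说明按该属性划分后样本的不确定性下降得越多,该属性也就越适合作为当前的分支节点。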
404527 "cell_type": "markdown",
405528 "metadata": {},
406529 "source": [
407 "使用上面的公式计算信息增益。"
530 "我们来编写代码,使用上面的公式计算信息增益。"
408531 ]
409532 },
410533 {
424547 "cell_type": "markdown",
425548 "metadata": {},
426549 "source": [
550 "<br>"
551 ]
552 },
553 {
554 "cell_type": "markdown",
555 "metadata": {},
556 "source": [
557 "### 扩展内容\n",
558 "\n",
559 "**基尼指数**\n",
560 "\n",
561 "除了使用信息增益以外,我们也可以使用基尼指数来构建决策树。\n",
562 "\n",
563 "分类问题中,假设有 $K$ 个类,样本点属于第 $k$ 类的概率为 $p_{k}$,则概率分布的基尼指数定义为:\n",
564 "\n",
565 "$$\\operatorname{Gini}(p)=\\sum_{k=1}^{K} p_{k}\\left(1-p_{k}\\right)=1-\\sum_{k=1}^{K} p_{k}^{2}$$\n",
566 "\n",
567 "\n",
568 "对于给定的样本集合 $D$,其基尼指数为\n",
569 "\n",
570 "$$\\operatorname{Gini}(D)=1-\\sum_{k=1}^{K}\\left(\\frac{\\left|C_{k}\\right|}{|D|}\\right)^{2}$$\n",
571 "\n",
572 "这里,$C_{k}$ 是 $D$ 中属于第 $k$ 类的样本子集,$K$ 是类的个数。\n",
573 "\n",
574 "如果样本集合 $D$ 根据特征 $A$ 是否取某一可能值 $a$ 被分割为 $D_{1}$ 和 $D_{2}$ 两部分,即\n",
575 "\n",
576 "$$D_{1}=\\{(x, y) \\in D | A(x)=a\\}, \\quad D_{2}=D-D_{1}$$\n",
577 "\n",
578 "则在特征 $A$ 的条件下,集合 $D$ 的基尼指数定义为\n",
579 "\n",
580 "$$\\operatorname{Gini}(D, A)=\\frac{\\left|D_{1}\\right|}{|D|}\n",
581 "\\operatorname{Gini}\\left(D_{1}\\right)+\\frac{\\left|D_{2}\\right|}{|D|} \\operatorname{Gini}\\left(D_{2}\\right)$$\n",
582 "\n",
583 "基尼指数 $Gini(D)$ 表示集合 $D$ 的不确定性,基尼指数 $Gini(D, A)$ 表示经过分割后集合 $D$ 的不确定性。基尼指数值越大,样本集合的不确定性也就越大,这一点与信息熵相似。"
584 ]
585 },
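基尼指数的计算同样可以写成几行代码。下面的演示写法对同样的 9 / 5 样本划分计算 $Gini(D)$,并顺带验证二分类时 $Gini(p)=2p(1-p)$ 的化简:

```python
def gini(counts):
    """Gini(D) = 1 - sum(p_k^2),counts 为各类样本数"""
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts)

# 14 个样本:9 个“前往”、5 个“不前往”
print(round(gini([9, 5]), 3))  # 0.459

# 二分类时 Gini(p) = 2p(1-p)
p = 9 / 14
print(abs(gini([9, 5]) - 2 * p * (1 - p)) < 1e-12)  # True
```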
586 {
587 "cell_type": "markdown",
588 "metadata": {},
589 "source": [
590 "**想一想**:"
591 ]
592 },
593 {
594 "cell_type": "markdown",
595 "metadata": {},
596 "source": [
597 "对于二分类问题,若样本点属于第 1 个类的概率是 $p$,则概率分布的基尼指数是多少?\n",
598 "\n",
599 "$$\\text { Gini }(p)=2 p(1-p)$$"
600 ]
601 },
602 {
603 "cell_type": "markdown",
604 "metadata": {},
605 "source": [
606 "<br>"
607 ]
608 },
609 {
610 "cell_type": "markdown",
611 "metadata": {},
612 "source": [
427613 "### 思考与练习 "
428614 ]
429615 },
431617 "cell_type": "markdown",
432618 "metadata": {},
433619 "source": [
434 "1. 分别将天气状况、温度高低、湿度大小、风力强弱作为分支点来构建决策树,查看信息增益\n"
620 "1. 在上面的游客去往游乐场的例子中,分别将温度高低、湿度大小、风力强弱作为分支点来构建决策树,查看信息增益。\n"
621 ]
622 },
623 {
624 "cell_type": "markdown",
625 "metadata": {},
626 "source": [
627 "计算按温度高低进行切分的信息增益。"
435628 ]
436629 },
437630 {
463656 ]
464657 },
465658 {
659 "cell_type": "markdown",
660 "metadata": {},
661 "source": [
662 "<br>"
663 ]
664 },
665 {
666 "cell_type": "markdown",
667 "metadata": {},
668 "source": [
669 "计算按湿度大小进行切分的信息增益。"
670 ]
671 },
672 {
466673 "cell_type": "code",
467674 "execution_count": null,
468675 "metadata": {},
491698 ]
492699 },
493700 {
701 "cell_type": "markdown",
702 "metadata": {},
703 "source": [
704 "<br>"
705 ]
706 },
707 {
708 "cell_type": "markdown",
709 "metadata": {},
710 "source": [
711 "计算按风力强弱进行切分的信息增益。"
712 ]
713 },
714 {
494715 "cell_type": "code",
495716 "execution_count": null,
496717 "metadata": {},
522743 "cell_type": "markdown",
523744 "metadata": {},
524745 "source": [
525 "2. 每朵鸢尾花有萼片长度、萼片宽度、花瓣长度、花瓣宽度四个特征。现在需要根据这四个特征将鸢尾花分为杂色鸢尾花、维吉尼亚鸢尾和山鸢尾三类,试构造决策树进行分类。\n",
746 "<br>"
747 ]
748 },
749 {
750 "cell_type": "markdown",
751 "metadata": {},
752 "source": [
753 "2. 每朵鸢尾花有**萼片长度**、**萼片宽度**、**花瓣长度**、**花瓣宽度**四个特征。现在需要根据这四个特征将鸢尾花分为**杂色鸢尾**、**维吉尼亚鸢尾**和**山鸢尾**三类,试构造决策树进行分类。\n",
526754 "\n",
527755 "|序号|萼片长度|萼片宽度|花瓣长度|花瓣宽度|种类|\n",
528756 "|:--:|:--:|:--:|:--:|:--:|:--:|\n",
537765 "cell_type": "markdown",
538766 "metadata": {},
539767 "source": [
540 "观察上表中的五笔数据,我们可以看到 杂色鸢尾 和 维吉尼亚鸢尾 的花瓣宽度明显大于山鸢尾,所以可以通过判断花瓣宽度是否大于 0.7,来将山鸢尾区从其他两种鸢尾中区分出来。\n",
541 "\n",
542 "同时,杂色鸢尾 和 维吉尼亚鸢尾 的花瓣长度明显大于山鸢尾,所以也可以通过判断花瓣长度是否大于 2.4,来将山鸢尾区从其他两种鸢尾中区分出来。\n",
543 "\n",
544 "然后我们观察到 维吉尼亚鸢尾的花瓣长度明显大于杂色鸢尾,所以可以通过判断花瓣长度是否大于 4.75,来将杂色鸢尾和维吉尼亚鸢尾区分出来。"
545 ]
546 },
547 {
548 "cell_type": "markdown",
549 "metadata": {},
550 "source": [
551 "实际上是否如此呢?"
552 ]
553 },
554 {
555 "cell_type": "markdown",
556 "metadata": {},
557 "source": [
558 "上面的表格只是 Iris 数据集的一小部分,完整的数据集包含 150 个数据样本,分为 3 类,每类 50 个数据,每个数据包含 4 个属性。即花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性。\n",
768 "观察上表中的五笔数据,我们可以看到 **杂色鸢尾** 和 **维吉尼亚鸢尾** 的花瓣宽度明显大于 **山鸢尾**,所以可以通过判断花瓣宽度是否大于 0.7,来将 **山鸢尾** 从其他两种鸢尾中区分出来。\n",
769 "\n",
770 "同时,**杂色鸢尾** 和 **维吉尼亚鸢尾** 的花瓣长度明显大于 **山鸢尾**,所以也可以通过判断花瓣长度是否大于 2.4,来将 **山鸢尾** 从其他两种鸢尾中区分出来。\n",
771 "\n",
772 "然后我们观察到 **维吉尼亚鸢尾** 的花瓣长度明显大于 **杂色鸢尾**,所以可以通过判断花瓣长度是否大于 4.75,来将 **杂色鸢尾** 和 **维吉尼亚鸢尾** 区分出来。"
773 ]
774 },
775 {
776 "cell_type": "markdown",
777 "metadata": {},
778 "source": [
779 "实际上是否如此呢?你能否想到其他的切分方式?"
780 ]
781 },
782 {
783 "cell_type": "markdown",
784 "metadata": {},
785 "source": [
786 "<br>"
787 ]
788 },
789 {
790 "cell_type": "markdown",
791 "metadata": {},
792 "source": [
793 "上面的表格只是 Iris 数据集的一小部分,完整的数据集包含 150 个数据样本,分为 3 类,每类 50 个数据,每个数据包含**花萼长度**、**花萼宽度**、**花瓣长度**、**花瓣宽度** 4 个属性。\n",
559794 "\n",
560795 "我们使用 sklearn 工具包来构建决策树模型,先导入数据集。"
561796 ]
579814 "cell_type": "markdown",
580815 "metadata": {},
581816 "source": [
582 "setosa 是山鸢尾,versicolor是杂色鸢尾,virginica是维吉尼亚鸢尾。\n",
583 "\n",
584 "sepal length, sepal width,petal length,petal width 分别是萼片长度,萼片宽度,花瓣长度,花瓣宽度。"
817 "setosa 是**山鸢尾**,versicolor 是**杂色鸢尾**,virginica 是**维吉尼亚鸢尾**。\n",
818 "\n",
819 "sepal length,sepal width,petal length,petal width 分别是**萼片长度**,**萼片宽度**,**花瓣长度**,**花瓣宽度**。"
585820 ]
586821 },
587822 {
608843 "cell_type": "markdown",
609844 "metadata": {},
610845 "source": [
846 "<br>"
847 ]
848 },
849 {
850 "cell_type": "markdown",
851 "metadata": {},
852 "source": [
611853 "接下来,我们在训练集数据上训练决策树模型。"
612854 ]
613855 },
619861 "source": [
620862 "from sklearn import tree\n",
621863 "from sklearn.tree import DecisionTreeClassifier\n",
622 "# 初始化模型,可以调整 max_depth 来观察模型的表现\n",
623 "clf = tree.DecisionTreeClassifier(random_state=42, max_depth=2)\n",
864 "# 初始化模型,可以调整 max_depth 来观察模型的表现,\n",
865 "# 也可以将 criterion 调整为 'gini' 来使用基尼指数构建决策树\n",
866 "clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=2)\n",
624867 "# 训练模型\n",
625868 "clf = clf.fit(X_train, y_train)\n"
626869 ]
33 "cell_type": "markdown",
44 "metadata": {},
55 "source": [
6 "# 2.3 回归分析\n",
7 "\n",
6 "# 2.3 回归分析"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
813 "**回归分析**:分析不同变量之间存在关系的研究。 \n",
914 "**回归模型**:刻画不同变量之间关系的模型。"
1015 ]
1318 "cell_type": "markdown",
1419 "metadata": {},
1520 "source": [
16 "## 2.3.1 回归分析的基本概念\n",
17 "\n",
21 "<center><video src=\"http://files.momodel.cn/regression_intro.mp4\" controls=\"controls\" width=800px></center>"
22 ]
23 },
24 {
25 "cell_type": "markdown",
26 "metadata": {},
27 "source": [
28 "## 2.3.1 回归分析的基本概念"
29 ]
30 },
31 {
32 "cell_type": "markdown",
33 "metadata": {},
34 "source": [
35 "<center><video src=\"http://files.momodel.cn/regression_basic_concept.mp4\" controls=\"controls\" width=800px></center>"
36 ]
37 },
38 {
39 "cell_type": "markdown",
40 "metadata": {},
41 "source": [
1842 "**数据**:下表给出了莫纳罗亚山从 1970 年到 2005 年间每 5 年的二氧化碳浓度,单位是百万分比浓度(parts per million,简称ppm)\n",
1943 "\n",
2044 "<table>\n",
4670 "</table>\n",
4771 "\n",
4872 "\n",
49 "**目标**:分析时间年份和二氧化碳浓度之间的关联关系,由此预测2010年二氧化碳浓度。\n"
73 "**目标**:分析时间年份和二氧化碳浓度之间的关联关系,由此预测2010年二氧化碳浓度。"
5074 ]
5175 },
5276 {
84108 "cell_type": "markdown",
85109 "metadata": {},
86110 "source": [
87 "## 2.3.2 回归分析中参数计算\n",
88 "\n",
89 "最简单的线性回归是**一元线性回归模型**,只包含一个自变量 $x$ 和一个因变量 $y$,并且假定自变量和因变量之间存在 $y=ax+b$ 的线性关系。求解参数 $a$ 和 $b$,需要给定若干组 $(x,y)$ 数据,然后从这些数据出发来计算参数 $a$ 和 $b$。\n"
111 "## 2.3.2 回归分析中参数计算"
112 ]
113 },
114 {
115 "cell_type": "markdown",
116 "metadata": {},
117 "source": [
118 "<center><video src=\"http://files.momodel.cn/regression_solve_params.mp4\" controls=\"controls\" width=800px></center>"
119 ]
120 },
121 {
122 "cell_type": "markdown",
123 "metadata": {},
124 "source": [
125 "最简单的线性回归是**一元线性回归模型**,只包含一个自变量 $x$ 和一个因变量 $y$,并且假定自变量和因变量之间存在 $y=ax+b$ 的线性关系。\n",
126 "\n",
127 "求解参数 $a$ 和 $b$,需要给定若干组 $(x,y)$ 数据,然后从这些数据出发来计算参数 $a$ 和 $b$。"
90128 ]
91129 },
92130 {
95133 "source": [
96134 "在一元线性回归模型中,最关键的问题是如何计算参数 $a$ 和参数 $b$ 使误差最小化。\n",
97135 "\n",
98 "最拟合直线 $y=ax+b$ 应该与这 8 组样本数据点距离都很近,最好的情况是这些样本数据点都在该直线上(不现实),让所有样本数据点离直线尽可能的近(被定义为预测数值和实际数值之间的差)。\n",
99 "\n",
100 "**预测值**:\n",
101 "\n",
102 "**真实值**:\n",
103 "\n",
104 "**残差**:\n"
105 ]
106 },
107 {
108 "cell_type": "markdown",
109 "metadata": {},
110 "source": [
111 "我们根据公式编写如下的方法来求解 $a$ 和 $b$。"
136 "最拟合直线 $y=ax+b$ 应该与这 8 组样本数据点的距离都很近。最好的情况是这些样本数据点都在该直线上(通常不现实),因此退而求其次,让所有样本数据点离直线尽可能地近(这里的“距离”被定义为预测数值和实际数值之间的差)。"
137 ]
138 },
139 {
140 "cell_type": "markdown",
141 "metadata": {},
142 "source": [
143 "**想一想**"
144 ]
145 },
146 {
147 "cell_type": "markdown",
148 "metadata": {},
149 "source": [
150 "预测值,真实值,残差分别是什么?"
151 ]
152 },
153 {
154 "cell_type": "markdown",
155 "metadata": {},
156 "source": [
157 "**动手练**"
158 ]
159 },
160 {
161 "cell_type": "markdown",
162 "metadata": {},
163 "source": [
164 "根据书中的计算公式编写代码来求解 $a$ 和 $b$。"
112165 ]
113166 },
114167 {
124177 " :param y: np array 格式的因变量\n",
125178 " :return: 系数 a 和 b\n",
126179 " \"\"\"\n",
127 " pass\n",
180 " # todo 完成求解参数 a,b 的代码\n",
128181 " return a, b\n",
182 "\n",
129183 "a, b = cal_a_b(x, y)\n",
130184 "print(a, b)"
131185 ]
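作为参考,下面给出上述闭式求解公式的一种实现示意(变量名与上面的练习无关;数据是虚构的、精确落在直线 $y=1.5x+2$ 上的小样本,用于自检):

```python
import numpy as np

def cal_a_b(x, y):
    """按最小二乘闭式解计算一元线性回归的系数 a 和 b"""
    x_mean, y_mean = x.mean(), y.mean()
    a = (np.sum(x * y) - len(x) * x_mean * y_mean) / (np.sum(x * x) - len(x) * x_mean ** 2)
    b = y_mean - a * x_mean
    return a, b

x = np.array([1.0, 2.0, 3.0, 4.0])
y = 1.5 * x + 2          # 虚构数据,精确落在直线上
a, b = cal_a_b(x, y)
print(round(a, 4), round(b, 4))  # 1.5 2.0
```

数据无噪声时闭式解应精确恢复参数;换成真实数据后,同一套公式给出的是误差平方和最小的拟合直线。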
134188 "cell_type": "markdown",
135189 "metadata": {},
136190 "source": [
137 "综上:得到的预测莫纳罗亚山地区二氧化碳浓度的一元线性回归模型为:。 \n",
138 "我们可以据此绘制出拟合直线。"
191 "综上:得到的预测莫纳罗亚山地区二氧化碳浓度的一元线性回归模型是什么? "
192 ]
193 },
194 {
195 "cell_type": "markdown",
196 "metadata": {},
197 "source": [
198 "根据求解结果绘制出拟合直线。"
139199 ]
140200 },
141201 {
161221 "cell_type": "markdown",
162222 "metadata": {},
163223 "source": [
164 "然后我们可以对该地区1970年之前和2005年之后的二氧化碳浓度进行估算。"
224 "然后对该地区1970年之前和2005年之后的二氧化碳浓度进行估算。"
165225 ]
166226 },
167227 {
178238 "cell_type": "markdown",
179239 "metadata": {},
180240 "source": [
181 "填写你的最终的预测结果在下表中: \n",
241 "将你最终的预测结果填写在下表中: \n",
182242 "\n",
183243 "<table>\n",
184244 "<tbody>\n",
220280 },
221281 {
222282 "cell_type": "code",
223 "execution_count": 1,
224 "metadata": {},
225 "outputs": [
226 {
227 "name": "stdout",
228 "output_type": "stream",
229 "text": [
230 "[[1.53438095]]\n",
231 "[-2698.87714286]\n"
232 ]
233 }
234 ],
283 "execution_count": null,
284 "metadata": {},
285 "outputs": [],
235286 "source": [
236287 "# 导入工具包\n",
237288 "import numpy as np\n",
277328 "\n",
278329 "<img src=\"http://imgbed.momodel.cn//20200115014016.png\" width=350>\n",
279330 " \n",
280 "其中 J 是代价函数,$\\theta_{0},\\theta_{1}$ 是待求参数, α 是学习率,它决定了我们沿着能让代价函数下降程度最大的方向向下迈出的步子有多大。 "
331 "其中 $J$ 是代价函数,$\\theta_{0},\\theta_{1}$ 是待求参数,$\\alpha$ 是学习率,它决定了我们沿着能让代价函数下降程度最大的方向向下迈出的步子有多大。 "
281332 ]
282333 },
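上面的梯度下降更新规则可以用一个极简示例来体会。下面的代码是示意性实现(数据为虚构的小样本,近似满足 $y=2x+1$),按代价函数 $J$ 对 $\theta_{0},\theta_{1}$ 的偏导迭代更新:

```python
import numpy as np

# 虚构的演示数据,精确满足 y = 2x + 1
x = np.array([0.0, 1.0, 2.0, 3.0])
y = np.array([1.0, 3.0, 5.0, 7.0])

theta0, theta1 = 0.0, 0.0
alpha = 0.05                          # 学习率
for _ in range(5000):
    pred = theta0 + theta1 * x
    # J = 1/(2m) * sum((pred - y)^2) 对 theta0、theta1 的偏导
    grad0 = (pred - y).mean()
    grad1 = ((pred - y) * x).mean()
    theta0 -= alpha * grad0
    theta1 -= alpha * grad1

print(round(theta0, 2), round(theta1, 2))  # 1.0 2.0
```

学习率 $\alpha$ 取得过大时迭代会发散,过小时收敛很慢,可以修改 `alpha` 自行观察。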
283334 {
309360 "cell_type": "markdown",
310361 "metadata": {},
311362 "source": [
312 "## 探究莫纳罗亚山地区二氧化碳与温度之间的关系\n",
363 "#### 探究莫纳罗亚山地区二氧化碳与温度之间的关系\n",
313364 "\n",
314365 "该地区 1970 年到 2005 年间每 5 年的二氧化碳浓度以及全球温度(相对于 1961 - 1990 年经过平滑处理的平均温度增长量)\n",
315366 "\n",
561612 "cell_type": "markdown",
562613 "metadata": {},
563614 "source": [
564 "**问题 1**:观察上图,𝑧 与时间 𝑥 之间是否存在线性关系?如果是,我们用上面写好的方法来求解系数。"
615 "**问题 1**:观察上图,$z$ 与时间 $x$ 之间是否存在线性关系?如果是,我们用上面写好的方法来求解系数。"
565616 ]
566617 },
567618 {
33 "cell_type": "markdown",
44 "metadata": {},
55 "source": [
6 "# 2.3 回归分析\n",
7 "\n",
6 "# 2.3 回归分析"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
813 "**回归分析**:分析不同变量之间存在关系的研究。 \n",
914 "**回归模型**:刻画不同变量之间关系的模型。"
1015 ]
1318 "cell_type": "markdown",
1419 "metadata": {},
1520 "source": [
16 "## 2.3.1 回归分析的基本概念\n",
17 "\n",
21 "<center><video src=\"http://files.momodel.cn/regression_intro.mp4\" controls=\"controls\" width=800px></center>"
22 ]
23 },
24 {
25 "cell_type": "markdown",
26 "metadata": {},
27 "source": [
28 "<br>"
29 ]
30 },
31 {
32 "cell_type": "markdown",
33 "metadata": {},
34 "source": [
35 "## 2.3.1 回归分析的基本概念"
36 ]
37 },
38 {
39 "cell_type": "markdown",
40 "metadata": {},
41 "source": [
42 "<center><video src=\"http://files.momodel.cn/regression_basic_concept.mp4\" controls=\"controls\" width=800px></center>"
43 ]
44 },
45 {
46 "cell_type": "markdown",
47 "metadata": {},
48 "source": [
49 "<br>"
50 ]
51 },
52 {
53 "cell_type": "markdown",
54 "metadata": {},
55 "source": [
56 "#### 探究莫纳罗亚山地区二氧化碳与温度之间的关系"
57 ]
58 },
59 {
60 "cell_type": "markdown",
61 "metadata": {},
62 "source": [
1863 "**数据**:下表给出了莫纳罗亚山从 1970 年到 2005 年间每 5 年的二氧化碳浓度,单位是百万分比浓度(parts per million,简称ppm)\n",
1964 "\n",
2065 "<table>\n",
4691 "</table>\n",
4792 "\n",
4893 "\n",
49 "**目标**:分析时间年份和二氧化碳浓度之间的关联关系,由此预测2010年二氧化碳浓度。\n"
94 "**目标**:分析时间年份和二氧化碳浓度之间的关联关系,由此预测2010年二氧化碳浓度。"
5095 ]
5196 },
5297 {
84129 "cell_type": "markdown",
85130 "metadata": {},
86131 "source": [
87 "## 2.3.2 回归分析中参数计算\n",
88 "\n",
89 "最简单的线性回归是**一元线性回归模型**,只包含一个自变量 $x$ 和一个因变量 $y$,并且假定自变量和因变量之间存在 $y=ax+b$ 的线性关系。求解参数 $a$ 和 $b$,需要给定若干组 $(x,y)$ 数据,然后从这些数据出发来计算参数 $a$ 和 $b$。\n"
132 "<br>"
133 ]
134 },
135 {
136 "cell_type": "markdown",
137 "metadata": {},
138 "source": [
139 "## 2.3.2 回归分析中参数计算"
140 ]
141 },
142 {
143 "cell_type": "markdown",
144 "metadata": {},
145 "source": [
146 "<center><video src=\"http://files.momodel.cn/regression_solve_params.mp4\" controls=\"controls\" width=800px></center>"
147 ]
148 },
149 {
150 "cell_type": "markdown",
151 "metadata": {},
152 "source": [
153 "<br>"
154 ]
155 },
156 {
157 "cell_type": "markdown",
158 "metadata": {},
159 "source": [
160 "最简单的线性回归是**一元线性回归模型**,只包含一个自变量 $x$ 和一个因变量 $y$,并且假定自变量和因变量之间存在 $y=ax+b$ 的线性关系。\n",
161 "\n",
162 "求解参数 $a$ 和 $b$,需要给定若干组 $(x,y)$ 数据,然后从这些数据出发来计算参数 $a$ 和 $b$。"
90163 ]
91164 },
92165 {
112185 "$a=\\frac{x_{1}y_{1}+x_2y_2+...+x_8y_8-8\\overline{x}\\overline{y}}{x_{1}^{2}+x_{2}^{2}+...+x_{8}^{2}-8\\overline{x}^2}=1.5344$\n",
113186 "\n",
114187 "$b = \\overline{y}-a\\overline{x}=-2698.9$"
188 ]
189 },
190 {
191 "cell_type": "markdown",
192 "metadata": {},
193 "source": [
194 "<br>"
115195 ]
116196 },
117197 {
149229 " n - len(x) * x_avarage * x_avarage)\n",
150230 " b = y_avarage - a * x_avarage\n",
151231 " return a, b\n",
232 "\n",
152233 "a, b = cal_a_b(x, y)\n",
153234 "print(a, b)"
154235 ]
157238 "cell_type": "markdown",
158239 "metadata": {},
159240 "source": [
160 "综上:预测莫纳罗亚山地区二氧化碳浓度的一元线性回归模型为:$y=1.5344x-2698.9$。 \n",
241 "综上:预测莫纳罗亚山地区二氧化碳浓度的一元线性回归模型为:$y=1.5344x-2698.9$。 "
242 ]
243 },
244 {
245 "cell_type": "markdown",
246 "metadata": {},
247 "source": [
248 "<br>"
249 ]
250 },
251 {
252 "cell_type": "markdown",
253 "metadata": {},
254 "source": [
161255 "我们可以据此绘制出拟合直线。"
162256 ]
163257 },
230324 "cell_type": "markdown",
231325 "metadata": {},
232326 "source": [
327 "<br>"
328 ]
329 },
330 {
331 "cell_type": "markdown",
332 "metadata": {},
333 "source": [
233334 "### 扩展内容\n",
234335 "\n",
235336 "**1.使用 sklearn 工具包构建回归模型**"
292393 "\n",
293394 "<img src=\"http://imgbed.momodel.cn//20200115014016.png\" width=350>\n",
294395 " \n",
295 "其中 J 是代价函数,$\\theta_{0},\\theta_{1}$ 是待求参数, α 是学习率,它决定了我们沿着能让代价函数下降程度最大的方向向下迈出的步子有多大。 "
396 "其中 $J$ 是代价函数,$\\theta_{0},\\theta_{1}$ 是待求参数,$\\alpha$ 是学习率,它决定了我们沿着能让代价函数下降程度最大的方向向下迈出的步子有多大。"
296397 ]
297398 },
298399 {
324425 "cell_type": "markdown",
325426 "metadata": {},
326427 "source": [
327 "## 探究莫纳罗亚山地区二氧化碳与温度之间的关系\n",
428 "#### 探究莫纳罗亚山地区二氧化碳与温度之间的关系\n",
328429 "\n",
329430 "该地区 1970 年到 2005 年间每 5 年的二氧化碳浓度以及全球温度(相对于 1961 - 1990 年经过平滑处理的平均温度增长量)\n",
330431 "\n",
575676 "cell_type": "markdown",
576677 "metadata": {},
577678 "source": [
578 "我们看到 𝑧 与时间 𝑥 之间是否存在线性关系,我们用上面写好的方法来求解系数。"
679 "可以看到 $z$ 与时间 $x$ 之间存在线性关系,我们用上面写好的方法来求解系数。"
579680 ]
580681 },
581682 {
33 "cell_type": "markdown",
44 "metadata": {},
55 "source": [
6 "# 2.4 贝叶斯分析\n",
6 "# 2.4 贝叶斯分析"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
713 "贝叶斯分析是一种根据概率统计知识对数据进行分析的方法,属于统计学分类的范畴。"
814 ]
915 },
1117 "cell_type": "markdown",
1218 "metadata": {},
1319 "source": [
14 "## 2.4.1 贝叶斯公式\n",
15 "\n",
16 "- **频率学派**\n",
17 "- **贝叶斯学派**\n",
18 "\n",
19 "贝叶斯概率计算公式表达:\n",
20 "后验概率 = \n",
21 "\n",
22 "**条件概率**:\n",
23 "\n",
24 "$P(A|B)$:\n",
25 "\n",
26 "$P(A|B)$:\n",
27 "\n",
28 "$P(B|A)$:\n",
29 "\n",
30 "$P(B|A)$:\n",
31 "\n",
32 "由于:\n",
33 "\n",
34 "$$P(B|A){P(A)}=P(A|B){P(B)}={P(A ∩ B)}$$\n",
35 "\n",
36 "可得**贝叶斯公式**:\n",
20 "## 2.4.1 贝叶斯公式"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "<center><video src=\"http://files.momodel.cn/bayes_theorem.mp4\" controls=\"controls\" width=800px></center>"
28 ]
29 },
30 {
31 "cell_type": "markdown",
32 "metadata": {},
33 "source": [
34 "<br>"
35 ]
36 },
37 {
38 "cell_type": "markdown",
39 "metadata": {},
40 "source": [
41 "**想一想**"
42 ]
43 },
44 {
45 "cell_type": "markdown",
46 "metadata": {},
47 "source": [
48 "频率学派 和 贝叶斯学派各自的主张是什么?\n",
49 "\n",
50 "**贝叶斯公式**:\n",
3751 "\n",
3852 "$$P(A|B) = \\frac{P(B|A)P(A)}{P(B)}$$\n",
3953 "\n",
40 "其中:\n",
41 "- $P(A)$ 是事件 $A$ 发生的先验概率,与事件 $B$ 是否发生无关;\n",
42 "- $P(B|A)$是事件 $A$ 发生前提下,事件 $B$ 发生的概率,也称为**似然概率**; \n",
43 "- $P(B)$ 是事件 $B$ 发生的先验概率,也称为**标准化常量**;\n",
44 "- $P(A|B)$是事件 $B$ 发生前提下,事件 $A$ 发生的概率,也是 $A$ 的后验概率。 "
45 ]
46 },
47 {
48 "cell_type": "markdown",
49 "metadata": {},
50 "source": [
51 "## 2.4.2 贝叶斯推断\n",
54 "中的各项分别代表什么意思? "
55 ]
56 },
57 {
58 "cell_type": "markdown",
59 "metadata": {},
60 "source": [
61 "## 2.4.2 贝叶斯推断"
62 ]
63 },
64 {
65 "cell_type": "markdown",
66 "metadata": {},
67 "source": [
5268 "贝叶斯推断是一种基于贝叶斯公式进行分析的统计学方法。"
5369 ]
5470 },
5672 "cell_type": "markdown",
5773 "metadata": {},
5874 "source": [
59 "小例子:根据邮件中的 “红包” 字样判别该邮件是不是垃圾邮件"
75 "<center><video src=\"http://files.momodel.cn/bayes_inference.mp4\" controls=\"controls\" width=800px></center>"
76 ]
77 },
78 {
79 "cell_type": "markdown",
80 "metadata": {},
81 "source": [
82 "#### 根据邮件中的 “红包” 字样判别该邮件是不是垃圾邮件\n"
6083 ]
6184 },
6285 {
112135 "source": [
113136 "### 扩展内容\n",
114137 "\n",
115 "**化验结果为阳性就代表你真的患病了吗?**\n",
116 "\n",
117 "某同学 A 身体不舒服,去医院作了验血检查,看他是否得了 X 疾病,检查结果居然为阳性,他吓了一跳,赶紧上网查询。他看到网上有资料说,实验总是有误差的,这种实验有“百分之一的假阳性率和百分之一的假阴性率”。也就是说,在确实得了 X 疾病的人里面, 会有 1% 的人是假阴性,99%的人是真阳性, 也就是会有 。而没得病的人去做检查,有 1% 的人是假阳性,99% 的人是真阴性。 于是,他认为,既然误检的概率这么低,那么他确实患病的概率应该是非常高的。\n",
138 "#### 化验结果为阳性就代表你真的患病了吗?\n",
139 "\n",
140 "某同学 A 身体不舒服,去医院作了验血检查,看他是否得了 X 疾病,检查结果居然为阳性,他吓了一跳,赶紧上网查询。他看到网上有资料说,实验总是有误差的,这种实验有“百分之一的假阳性率和百分之一的假阴性率”。也就是说,在确实得了 X 疾病的人里面, 会有 1% 的人是假阴性,99%的人是真阳性, 也就是会有 1% 的几率被误诊为没病。而没得病的人去做检查,有 1% 的人是假阳性,99% 的人是真阴性,也就是会有 1% 的几率被误诊为有病。 于是,他认为,既然误检的概率这么低,那么他确实患病的概率应该是非常高的。\n",
118141 "\n",
119142 "可是,医生却告诉他,他被感染的概率只有 0.09 左右。这是怎么回事呢?\n",
120143 "\n",
127150 "cell_type": "markdown",
128151 "metadata": {},
129152 "source": [
153 "**动手练**"
154 ]
155 },
156 {
157 "cell_type": "markdown",
158 "metadata": {},
159 "source": [
130160 "$A$ : 普通人患 X 病\n",
131161 "\n",
132162 "$B$ : 化验结果为阳性\n",
139169 "\n",
140170 "$P(B|A)$:一个人患 X 病,其检测结果为阳性的概率, 99%\n",
141171 "\n",
142 "根据**贝叶斯公式**:\n",
143 "\n",
144 "$$\n",
145 "\\begin{align}\n",
146 "&P(A|B)=\\frac{P(B|A){P(A)}}{P(B)}\\\\\n",
147 "=&\\frac{99\\%*(1/1000)}{99\\%*(1/1000) + 1\\%*(999/1000)}\\\\\n",
148 "=&\\frac{99}{1098}\\\\\n",
149 "≈ & 9\\%\n",
150 "\\end{align}\n",
151 "$$\n"
152 ]
153 },
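上面的计算可以用几行代码核对一遍(演示写法,数字取自文中:患病率 1/1000、真阳性率 99%、假阳性率 1%):

```python
p_a = 1 / 1000          # P(A): 患 X 病的先验概率
p_b_given_a = 0.99      # P(B|A): 患病且检测为阳性
p_b_given_not_a = 0.01  # P(B|not A): 未患病但检测为阳性

# 用全概率公式展开 P(B),再代入贝叶斯公式
p_b = p_b_given_a * p_a + p_b_given_not_a * (1 - p_a)
p_a_given_b = p_b_given_a * p_a / p_b
print(round(p_a_given_b, 4))  # 0.0902
```

可见后验概率被极低的先验患病率“拉低”了:阳性结果大多来自人数庞大的健康人群中的假阳性。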
154 {
155 "cell_type": "markdown",
156 "metadata": {},
157 "source": [
158 "## 2.4.3 朴素贝叶斯分类器 \n",
159 "一种常用的分类算法,其基本假设为 。"
160 ]
161 },
162 {
163 "cell_type": "markdown",
164 "metadata": {},
165 "source": [
166 "小例子:预测同学会不会在某店铺订餐。\n",
167 "\n",
168 "**目标**:根据某同学的订单记录,如果向他推荐一家“价位低、口味偏甜、距离远”的店铺,判断他会下单吗?\n",
169 "\n",
172 "根据**贝叶斯公式**,计算 $P(A|B)$\n"
173 ]
174 },
175 {
176 "cell_type": "markdown",
177 "metadata": {},
178 "source": [
179 "## 2.4.3 朴素贝叶斯分类器 "
180 ]
181 },
182 {
183 "cell_type": "markdown",
184 "metadata": {},
185 "source": [
186 "<center><video src=\"http://files.momodel.cn/bayes_naive.mp4\" controls=\"controls\" width=800px></center>"
187 ]
188 },
189 {
190 "cell_type": "markdown",
191 "metadata": {},
192 "source": [
193 "**想一想**"
194 ]
195 },
196 {
197 "cell_type": "markdown",
198 "metadata": {},
199 "source": [
200 "朴素贝叶斯分类器作为一种常用的分类算法,其基本假设是什么?"
201 ]
202 },
203 {
204 "cell_type": "markdown",
205 "metadata": {},
206 "source": [
207 "#### 根据某同学的订单记录,判断其是否会对某店铺下单"
208 ]
209 },
210 {
211 "cell_type": "markdown",
212 "metadata": {},
213 "source": [
170214 "**数据**:该同学的下单记录如下\n",
171215 "\n",
172216 "|店铺价位|店铺口味|店铺距离|是否下单|\n",
178222 "|低|偏甜|近|是|\n",
179223 "|低|偏甜|近|是|\n",
180224 "|低|清淡|远|否|\n",
181 "|低|偏辣|远|是|\n"
225 "|低|偏辣|远|是|\n",
226 "\n",
227 "**目标**:根据某同学的订单记录,如果向他推荐一家“价位低、口味偏甜、距离远”的店铺,判断他会下单吗?"
182228 ]
183229 },
184230 {
259305 "$$\n",
260306 "\n",
261307 "上述两个计算公式分母相同,对计算结果不影响,因此就从计算过程中略去了。"
308 ]
309 },
310 {
311 "cell_type": "markdown",
312 "metadata": {},
313 "source": [
314 "<br>"
262315 ]
263316 },
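朴素贝叶斯的计算流程可以抽象成几行代码。下面的示意实现只取了上表中展示出来的 4 行记录(并非完整数据,结果仅反映这一子集),演示“特征条件独立”假设下未归一化后验的比较方式:

```python
# 取自上表中展示的几行:(价位, 口味, 距离, 是否下单)
data = [
    ('低', '偏甜', '近', '是'),
    ('低', '偏甜', '近', '是'),
    ('低', '清淡', '远', '否'),
    ('低', '偏辣', '远', '是'),
]

def naive_bayes_score(sample, label):
    """P(label) * prod P(feature_i | label),即未归一化的后验"""
    rows = [r for r in data if r[-1] == label]
    score = len(rows) / len(data)           # 先验 P(label)
    for i, v in enumerate(sample):
        score *= sum(1 for r in rows if r[i] == v) / len(rows)
    return score

sample = ('低', '偏甜', '远')               # 待判断的店铺
scores = {c: naive_bayes_score(sample, c) for c in ('是', '否')}
print(max(scores, key=scores.get))
```

两个类别的分母(即 $P(B)$)相同,所以比较未归一化的分子即可;换成完整数据后各条件概率会变化,计算框架不变。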
264317 {
360413 "cell_type": "markdown",
361414 "metadata": {},
362415 "source": [
416 "<br>"
417 ]
418 },
419 {
420 "cell_type": "markdown",
421 "metadata": {},
422 "source": [
363423 "3.根据 **MNIST** 训练集训练朴素贝叶斯分类器"
364424 ]
365425 },
445505 "\n",
446506 " print(\"数字 %s 的样本个数:%4s,预测正确的个数:%4s,准确率:%.4s%%\" % (\n",
447507 " i, class_num[i], predict_num[i][i], class_accuracy[i]))"
508 ]
509 },
510 {
511 "cell_type": "markdown",
512 "metadata": {},
513 "source": [
514 "<br>"
448515 ]
449516 },
450517 {
33 "cell_type": "markdown",
44 "metadata": {},
55 "source": [
6 "# 2.4 贝叶斯分析\n",
6 "# 2.4 贝叶斯分析"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
713 "贝叶斯分析是一种根据概率统计知识对数据进行分析的方法,属于统计学分类的范畴。"
814 ]
915 },
1117 "cell_type": "markdown",
1218 "metadata": {},
1319 "source": [
14 "## 2.4.1 贝叶斯公式\n",
15 "\n",
20 "## 2.4.1 贝叶斯公式"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "<center><video src=\"http://files.momodel.cn/bayes_theorem.mp4\" controls=\"controls\" width=800px></center>"
28 ]
29 },
30 {
31 "cell_type": "markdown",
32 "metadata": {},
33 "source": [
34 "<br>"
35 ]
36 },
37 {
38 "cell_type": "markdown",
39 "metadata": {},
40 "source": [
1641 "- **频率学派**:从历史数据中计算某个事件的概率,认为只要采样足够多,则事件发生的频率就可以无限逼近真实概率。\n",
1742 "- **贝叶斯学派**:认为某个事件发生的概率不仅与先前这个事件发生的概率相关(称为**先验概率**),也与后期计算该事件概率时所观测的“新近”信息有关(称为**似然概率**)。\n",
1843 "\n",
4873 "cell_type": "markdown",
4974 "metadata": {},
5075 "source": [
51 "## 2.4.2 贝叶斯推断\n",
76 "<br>"
77 ]
78 },
79 {
80 "cell_type": "markdown",
81 "metadata": {},
82 "source": [
83 "## 2.4.2 贝叶斯推断"
84 ]
85 },
86 {
87 "cell_type": "markdown",
88 "metadata": {},
89 "source": [
5290 "贝叶斯推断是一种基于贝叶斯公式进行分析的统计学方法。"
5391 ]
5492 },
5694 "cell_type": "markdown",
5795 "metadata": {},
5896 "source": [
59 "小例子:根据邮件中的 “红包” 字样判别该邮件是不是垃圾邮件"
97 "<center><video src=\"http://files.momodel.cn/bayes_inference.mp4\" controls=\"controls\" width=800px></center>"
98 ]
99 },
100 {
101 "cell_type": "markdown",
102 "metadata": {},
103 "source": [
104 "<br>"
105 ]
106 },
107 {
108 "cell_type": "markdown",
109 "metadata": {},
110 "source": [
111 "#### 根据邮件中的 “红包” 字样判别该邮件是不是垃圾邮件"
60112 ]
61113 },
62114 {
110162 "cell_type": "markdown",
111163 "metadata": {},
112164 "source": [
165 "<br>"
166 ]
167 },
168 {
169 "cell_type": "markdown",
170 "metadata": {},
171 "source": [
113172 "### 扩展内容\n",
114173 "\n",
115 "**化验结果为阳性就代表你真的患病了吗?**\n",
116 "\n",
117 "某同学 A 身体不舒服,去医院作了验血检查,看他是否得了 X 疾病,检查结果居然为阳性,他吓了一跳,赶紧上网查询。他看到网上有资料说,实验总是有误差的,这种实验有“百分之一的假阳性率和百分之一的假阴性率”。也就是说,在确实得了 X 疾病的人里面, 会有 1% 的人是假阴性,99%的人是真阳性, 也就是会有 。而没得病的人去做检查,有 1% 的人是假阳性,99% 的人是真阴性。 于是,他认为,既然误检的概率这么低,那么他确实患病的概率应该是非常高的。\n",
174 "#### 化验结果为阳性就代表你真的患病了吗?\n",
175 "\n",
176 "某同学 A 身体不舒服,去医院作了验血检查,看他是否得了 X 疾病,检查结果居然为阳性,他吓了一跳,赶紧上网查询。他看到网上有资料说,实验总是有误差的,这种实验有“百分之一的假阳性率和百分之一的假阴性率”。也就是说,在确实得了 X 疾病的人里面, 会有 1% 的人是假阴性,99%的人是真阳性, 也就是会有 1% 的几率被误诊为没病。而没得病的人去做检查,有 1% 的人是假阳性,99% 的人是真阴性,也就是会有 1% 的几率被误诊为有病。 于是,他认为,既然误检的概率这么低,那么他确实患病的概率应该是非常高的。\n",
118177 "\n",
119178 "可是,医生却告诉他,他被感染的概率只有 0.09 左右。这是怎么回事呢?\n",
120179 "\n",
155214 "cell_type": "markdown",
156215 "metadata": {},
157216 "source": [
158 "## 2.4.3 朴素贝叶斯分类器 \n",
217 "<br>"
218 ]
219 },
220 {
221 "cell_type": "markdown",
222 "metadata": {},
223 "source": [
224 "## 2.4.3 朴素贝叶斯分类器 "
225 ]
226 },
227 {
228 "cell_type": "markdown",
229 "metadata": {},
230 "source": [
159231 "一种常用的分类算法,其假设**样本各个特征之间相互独立、互不影响**。"
160232 ]
161233 },
163235 "cell_type": "markdown",
164236 "metadata": {},
165237 "source": [
166 "小例子:预测同学会不会在某店铺订餐。\n",
167 "\n",
168 "**目标**:根据某同学的订单记录,如果向他推荐一家“价位低、口味偏甜、距离远”的店铺,判断他会下单吗?\n",
169 "\n",
238 "<center><video src=\"http://files.momodel.cn/bayes_naive.mp4\" controls=\"controls\" width=800px></center>"
239 ]
240 },
241 {
242 "cell_type": "markdown",
243 "metadata": {},
244 "source": [
245 "<br>"
246 ]
247 },
248 {
249 "cell_type": "markdown",
250 "metadata": {},
251 "source": [
170252 "**数据**:该同学的下单记录如下\n",
171253 "\n",
172254 "|店铺价位|店铺口味|店铺距离|是否下单|\n",
178260 "|低|偏甜|近|是|\n",
179261 "|低|偏甜|近|是|\n",
180262 "|低|清淡|远|否|\n",
181 "|低|偏辣|远|是|\n"
263 "|低|偏辣|远|是|\n",
264 "\n",
265 "**目标**:根据某同学的订单记录,如果向他推荐一家“价位低、口味偏甜、距离远”的店铺,判断他会下单吗?"
182266 ]
183267 },
184268 {
259343 "$$\n",
260344 "\n",
261345 "上述两个计算公式分母相同,对计算结果不影响,因此就从计算过程中略去了。"
346 ]
347 },
348 {
349 "cell_type": "markdown",
350 "metadata": {},
351 "source": [
352 "<br>"
262353 ]
263354 },
264355 {
360451 "cell_type": "markdown",
361452 "metadata": {},
362453 "source": [
454 "<br>"
455 ]
456 },
457 {
458 "cell_type": "markdown",
459 "metadata": {},
460 "source": [
363461 "3.根据 **MNIST** 训练集训练朴素贝叶斯分类器"
364462 ]
365463 },
445543 "\n",
446544 " print(\"数字 %s 的样本个数:%4s,预测正确的个数:%4s,准确率:%.4s%%\" % (\n",
447545 " i, class_num[i], predict_num[i][i], class_accuracy[i]))"
546 ]
547 },
548 {
549 "cell_type": "markdown",
550 "metadata": {},
551 "source": [
552 "<br>"
448553 ]
449554 },
450555 {
520625 "cell_type": "markdown",
521626 "metadata": {},
522627 "source": [
523 "## 扩展阅读\n",
628 "## 扩展阅读"
629 ]
630 },
631 {
632 "cell_type": "markdown",
633 "metadata": {},
634 "source": [
524635 "1. [一文读懂概率论学习:贝叶斯理论](https://www.jiqizhixin.com/articles/2019-11-21)\n",
525636 "2. [朴素贝叶斯法讲解](https://www.bilibili.com/video/av57126177?from=search&seid=1588787263892359481)\n",
526637 "3. [sklearn 贝叶斯方法](https://scikit-learn.org/stable/modules/naive_bayes.html)"
1818 "metadata": {},
1919 "source": [
2020 "## 2.5.1 人脑神经机制"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "<center><video src=\"http://files.momodel.cn/nn_humanbrain.mp4\" controls=\"controls\" width=800px></center>"
2128 ]
2229 },
2330 {
4451 "cell_type": "markdown",
4552 "metadata": {},
4653 "source": [
54 "人工智能中神经网络正是体现“逐层抽象、渐进学习”机制的学习模型。"
55 ]
56 },
57 {
58 "cell_type": "markdown",
59 "metadata": {},
60 "source": [
61 "<img src=\"http://imgbed.momodel.cn/微信图片_20200114133755.png\"/>"
62 ]
63 },
64 {
65 "cell_type": "markdown",
66 "metadata": {},
67 "source": [
68 "人眼在辨识图片时,会先提取边缘特征,再识别部件,最后得到最高层的模式。"
69 ]
70 },
71 {
72 "cell_type": "markdown",
73 "metadata": {},
74 "source": [
75 "<img src=\"http://imgbed.momodel.cn//20200103102429.png\" width=500>"
76 ]
77 },
78 {
79 "cell_type": "markdown",
80 "metadata": {},
81 "source": [
82 "## 2.5.2 感知机模型"
83 ]
84 },
85 {
86 "cell_type": "markdown",
87 "metadata": {},
88 "source": [
89 "<center><video src=\"http://files.momodel.cn/nn_perceptron.mp4\" controls=\"controls\" width=800px></center>"
90 ]
91 },
92 {
93 "cell_type": "markdown",
94 "metadata": {},
95 "source": [
96 "**感知机模型**:\n",
97 "\n",
4798 "<table>\n",
4899 " <tr>\n",
49 " <td ><center><img src=\"http://imgbed.momodel.cn/微信图片_20200114133800.png\" width=600px/></center></td>\n",
50 " <td><center><img src=\"http://imgbed.momodel.cn/微信图片_20200114133808.png\" width=600px/></center></td>\n",
51 " </tr>\n",
52 "</table>"
53 ]
54 },
55 {
56 "cell_type": "markdown",
57 "metadata": {},
58 "source": [
59 "人工智能中神经网络正是体现“逐层抽象、渐进学习”机制的学习模型。"
60 ]
61 },
62 {
63 "cell_type": "markdown",
64 "metadata": {},
65 "source": [
66 "<img src=\"http://imgbed.momodel.cn/微信图片_20200114133755.png\"/>"
67 ]
68 },
69 {
70 "cell_type": "markdown",
71 "metadata": {},
72 "source": [
73 "人眼在辨识图片时,会先提取边缘特征,再识别部件,最后再得到最高层的模式。"
74 ]
75 },
76 {
77 "cell_type": "markdown",
78 "metadata": {},
79 "source": [
80 "<img src=\"http://imgbed.momodel.cn//20200103102429.png\" width=500>"
81 ]
82 },
83 {
84 "cell_type": "markdown",
85 "metadata": {},
86 "source": [
87 "## 2.5.2 感知机模型"
88 ]
89 },
90 {
91 "cell_type": "markdown",
92 "metadata": {},
93 "source": [
94 "**感知机模型**:\n",
95 "\n",
96 "<table>\n",
97 " <tr>\n",
98 " <td ><center><img src=\"http://imgbed.momodel.cn/感知器模型.png\" width=400/></center></td>\n",
99 " <td><center><img src=\"http://imgbed.momodel.cn/微信图片_20200114135559.png\" width=400/>\n",
100 " <img src=\"http://imgbed.momodel.cn/微信图片_20200114135643.png\" width=400/></center></td>\n",
100 " <td ><center><img src=\"http://imgbed.momodel.cn/感知器模型.png\" width=300/></center></td>\n",
101 " <td><center><img src=\"http://imgbed.momodel.cn//20200208141322.png\" width=400></center></td>\n",
101102 " </tr>\n",
102103 "</table>"
103104 ]
222223 "cell_type": "markdown",
223224 "metadata": {},
224225 "source": [
225 "我们根据上面的定义可以编写一个简单的感知机模型。"
226 "**动手练**"
227 ]
228 },
229 {
230 "cell_type": "markdown",
231 "metadata": {},
232 "source": [
233 "编写一个简单的感知机模型。"
226234 ]
227235 },
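作为参考,下面是一个不依赖任何第三方库的感知机最小示意:按阈值(阶跃函数)产生输出,并用感知机学习规则更新权重。训练数据选用逻辑与(AND),学习率 0.1、迭代 20 轮均为示意性取值:

```python
# 感知机最小示意:阶跃激活 + 感知机学习规则(学习率与轮数为示意取值)
def train_perceptron(samples, labels, lr=0.1, epochs=20):
    w, b = [0.0, 0.0], 0.0
    for _ in range(epochs):
        for (x1, x2), y in zip(samples, labels):
            # 阶跃激活:加权和超过阈值 0 时输出 1,否则输出 0
            out = 1 if w[0] * x1 + w[1] * x2 + b > 0 else 0
            err = y - out                 # 误差驱动的权重更新
            w[0] += lr * err * x1
            w[1] += lr * err * x2
            b += lr * err
    return w, b

samples = [(0, 0), (0, 1), (1, 0), (1, 1)]
labels = [0, 0, 0, 1]                     # 逻辑与(AND)
w, b = train_perceptron(samples, labels)
preds = [1 if w[0] * x1 + w[1] * x2 + b > 0 else 0 for x1, x2 in samples]
print(w, b, preds)
```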
228236 {
262270 "cell_type": "markdown",
263271 "metadata": {},
264272 "source": [
273 "<center><video src=\"http://files.momodel.cn/nn_fullconnect.mp4\" controls=\"controls\" width=800px></center>"
274 ]
275 },
276 {
277 "cell_type": "markdown",
278 "metadata": {},
279 "source": [
280 "神经网络架构示意图如下:\n",
281 "\n",
265282 "<img src=\"http://imgbed.momodel.cn//20200103111837.png\" width=400>\n",
266283 "\n",
267284	 	    "与感知机不同,神经网络:\n",
268285 "+ 输入层和输出层之间存在若干隐藏层。\n",
269286 "+ 每个隐藏层中包含若干神经元。\n",
270287 "\n",
271 "通过下面几个小视频,你会更好的理解神经网络:\n",
272 "<center><video src=\"./nn_media1.mp4\" controls=\"controls\" width=800px></center>\n",
273 "<center><video src=\"./nn_media2.mp4\" controls=\"controls\" width=800px></center>\n",
274 "<center><video src=\"./nn_media3.mp4\" controls=\"controls\" width=800px></center>\n",
275 "<center><video src=\"./nn_media4.mp4\" controls=\"controls\" width=800px></center>\n",
276 "<center><video src=\"./nn_media5.mp4\" controls=\"controls\" width=800px></center>\n",
277 "<center><video src=\"./nn_media6.mp4\" controls=\"controls\" width=800px></center>\n"
 	288	    "通过下面几个小视频,你会更好地理解神经网络:"
289 ]
290 },
291 {
292 "cell_type": "markdown",
293 "metadata": {},
294 "source": [
295 "<center><video src=\"http://files.momodel.cn/nn_media1.mp4\" controls=\"controls\" width=800px></center>"
296 ]
297 },
298 {
299 "cell_type": "markdown",
300 "metadata": {},
301 "source": [
302 "<center><video src=\"http://files.momodel.cn/nn_media2.mp4\" controls=\"controls\" width=800px></center>"
303 ]
304 },
305 {
306 "cell_type": "markdown",
307 "metadata": {},
308 "source": [
309 "<br>"
310 ]
311 },
312 {
313 "cell_type": "markdown",
314 "metadata": {},
315 "source": [
316 "<center><video src=\"http://files.momodel.cn/nn_media3.mp4\" controls=\"controls\" width=800px></center>"
317 ]
318 },
319 {
320 "cell_type": "markdown",
321 "metadata": {},
322 "source": [
323 "<br>"
324 ]
325 },
326 {
327 "cell_type": "markdown",
328 "metadata": {},
329 "source": [
330 "<center><video src=\"http://files.momodel.cn/nn_media4.mp4\" controls=\"controls\" width=800px></center>"
331 ]
332 },
333 {
334 "cell_type": "markdown",
335 "metadata": {},
336 "source": [
337 "<br>"
338 ]
339 },
340 {
341 "cell_type": "markdown",
342 "metadata": {},
343 "source": [
344 "<center><video src=\"http://files.momodel.cn/nn_media5.mp4\" controls=\"controls\" width=800px></center>"
345 ]
346 },
347 {
348 "cell_type": "markdown",
349 "metadata": {},
350 "source": [
351 "<br>"
352 ]
353 },
354 {
355 "cell_type": "markdown",
356 "metadata": {},
357 "source": [
358 "<center><video src=\"http://files.momodel.cn/nn_media6.mp4\" controls=\"controls\" width=800px></center>"
359 ]
360 },
361 {
362 "cell_type": "markdown",
363 "metadata": {},
364 "source": [
365 "### 扩展内容\n",
366 "\n",
367 "**卷积神经网络**"
368 ]
369 },
370 {
371 "cell_type": "markdown",
372 "metadata": {},
373 "source": [
374 "<center><video src=\"http://files.momodel.cn/nn_conv.mp4\" controls=\"controls\" width=800px></center>"
278375 ]
279376 },
280377 {
366463 " \"\"\"\n",
367464 " # 选择模型,选择序贯模型(Sequential())\n",
368465 " model = Sequential()\n",
369 " \n",
370466 " # 添加全连接层,共 512 个神经元\n",
371467 " model.add(Dense(512,input_shape=(784,),kernel_initializer='he_normal'))\n",
372 " \n",
373468 " # 添加激活层,激活函数选择 relu \n",
374469 " model.add(Activation('relu'))\n",
375 " \n",
376470 " # 添加全连接层,共 512 个神经元\n",
377471 " model.add(Dense(512,kernel_initializer='he_normal'))\n",
378 " \n",
379472 " # 添加激活层,激活函数选择 relu \n",
380473 " model.add(Activation('relu'))\n",
381 " \n",
382474 " # 添加全连接层,共 10 个神经元\n",
383475 " model.add(Dense(nb_classes))\n",
384 " \n",
385476 " # 添加激活层,激活函数选择 softmax\n",
386477 " model.add(Activation('softmax'))\n",
387 " \n",
388478 " return model\n",
389479 "\n",
390480 "model = create_model()"
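补充说明:每个全连接层的参数量等于 输入维度 × 输出维度(权重)加上输出维度(偏置)。对上面 784 → 512 → 512 → 10 的结构,可以手工验算各层参数量,并与 `model.summary()` 打印的 Param # 对照:

```python
# 全连接层参数量 = 输入维度 × 输出维度(权重) + 输出维度(偏置)
def dense_params(in_dim, out_dim):
    return in_dim * out_dim + out_dim

layers = [(784, 512), (512, 512), (512, 10)]   # 对应上面三个 Dense 层
counts = [dense_params(i, o) for i, o in layers]
print(counts, sum(counts))
```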
412502 " \"\"\"\n",
413503 " # 编译模型\n",
414504 " model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
415 " \n",
416505 " # 模型训练\n",
417506 " model.fit(X_train, y_train, epochs=5, batch_size=64, verbose=1, validation_split=0.05)\n",
418 " \n",
419507 " # 保存模型\n",
420508 " model.save(model_path)\n",
421 " \n",
422509 " # 模型评估,获取测试集的损失值和准确率\n",
423510 " loss, accuracy = model.evaluate(X_test, y_test)\n",
424 "\n",
425511 " # 打印结果\n",
426512 " print('Test loss:', loss)\n",
427513 " print(\"Accuracy:\", accuracy)\n",
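上面 `evaluate` 返回的 loss 与 accuracy 可以用一个极小的 NumPy 例子来理解:categorical_crossentropy 是真实类别所得概率的负对数对样本取均值,accuracy 是 argmax 预测与标签一致的比例(下面的概率数据为虚构):

```python
import numpy as np

# 虚构的 2 个样本、3 个类别的小例子
y_true = np.array([[1, 0, 0], [0, 1, 0]])              # one-hot 标签
y_pred = np.array([[0.8, 0.1, 0.1], [0.3, 0.6, 0.1]])  # 模型输出的概率

# categorical_crossentropy:真实类别所得概率的负对数,再对样本取均值
loss = -np.mean(np.sum(y_true * np.log(y_pred), axis=1))
# accuracy:argmax 预测与标签一致的比例
accuracy = np.mean(np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1))
print(loss, accuracy)
```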
555641 "2. [利用 pix2pix 将你的草图变成图片](https://momodel.cn/explore/5c0cb5df1afd945819064752?type=app)\n",
556642 "3. [自动生成图片描述](https://momodel.cn/explore/5ba33f578fe30b412042ac08?&type=app&tab=1)"
557643 ]
558 },
559 {
560 "cell_type": "code",
561 "execution_count": null,
562 "metadata": {},
563 "outputs": [],
564 "source": []
565644 }
566645 ],
567646 "metadata": {
1818 "metadata": {},
1919 "source": [
2020 "## 2.5.1 人脑神经机制"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "<center><video src=\"http://files.momodel.cn/nn_humanbrain.mp4\" controls=\"controls\" width=800px></center>"
28 ]
29 },
30 {
31 "cell_type": "markdown",
32 "metadata": {},
33 "source": [
34 "<br>"
2135 ]
2236 },
2337 {
3751 "cell_type": "markdown",
3852 "metadata": {},
3953 "source": [
54 "高层的特征是低层特征的组合,从低层到高层的特征表示越来越抽象,越来越能表现语义。\n"
55 ]
56 },
57 {
58 "cell_type": "markdown",
59 "metadata": {},
60 "source": [
61 "人工智能中神经网络正是体现“逐层抽象、渐进学习”机制的学习模型。"
62 ]
63 },
64 {
65 "cell_type": "markdown",
66 "metadata": {},
67 "source": [
68 "<img src=\"http://imgbed.momodel.cn/微信图片_20200114133755.png\"/>\n"
69 ]
70 },
71 {
72 "cell_type": "markdown",
73 "metadata": {},
74 "source": [
75 "<br>"
76 ]
77 },
78 {
79 "cell_type": "markdown",
80 "metadata": {},
81 "source": [
82 "人眼在辨识图片时,会先提取边缘特征,再识别部件,最后再得到最高层的模式。"
83 ]
84 },
85 {
86 "cell_type": "markdown",
87 "metadata": {},
88 "source": [
89 "<img src=\"http://imgbed.momodel.cn//20200103102429.png\" width=500>"
90 ]
91 },
92 {
93 "cell_type": "markdown",
94 "metadata": {},
95 "source": [
96 "<br>"
97 ]
98 },
99 {
100 "cell_type": "markdown",
101 "metadata": {},
102 "source": [
103 "## 2.5.2 感知机模型"
104 ]
105 },
106 {
107 "cell_type": "markdown",
108 "metadata": {},
109 "source": [
110 "<center><video src=\"http://files.momodel.cn/nn_perceptron.mp4\" controls=\"controls\" width=800px></center>"
111 ]
112 },
113 {
114 "cell_type": "markdown",
115 "metadata": {},
116 "source": [
117 "<br>"
118 ]
119 },
120 {
121 "cell_type": "markdown",
122 "metadata": {},
123 "source": [
124 "**感知机模型**:\n",
125 "\n",
40126 "<table>\n",
41127 " <tr>\n",
42 " <td ><center><img src=\"http://imgbed.momodel.cn/微信图片_20200114133800.png\" width=600px/></center></td>\n",
43 " <td><center><img src=\"http://imgbed.momodel.cn/微信图片_20200114133808.png\" width=600px/></center></td>\n",
44 " </tr>\n",
45 "</table>"
46 ]
47 },
48 {
49 "cell_type": "markdown",
50 "metadata": {},
51 "source": [
52 "高层的特征是低层特征的组合,从低层到高层的特征表示越来越抽象,越来越能表现语义。\n"
53 ]
54 },
55 {
56 "cell_type": "markdown",
57 "metadata": {},
58 "source": [
59 "人工智能中神经网络正是体现“逐层抽象、渐进学习”机制的学习模型。"
60 ]
61 },
62 {
63 "cell_type": "markdown",
64 "metadata": {},
65 "source": [
66 "<img src=\"http://imgbed.momodel.cn/微信图片_20200114133755.png\"/>\n"
67 ]
68 },
69 {
70 "cell_type": "markdown",
71 "metadata": {},
72 "source": [
73 "人眼在辨识图片时,会先提取边缘特征,再识别部件,最后再得到最高层的模式。"
74 ]
75 },
76 {
77 "cell_type": "markdown",
78 "metadata": {},
79 "source": [
80 "<img src=\"http://imgbed.momodel.cn//20200103102429.png\" width=500>"
81 ]
82 },
83 {
84 "cell_type": "markdown",
85 "metadata": {},
86 "source": [
87 "## 2.5.2 感知机模型"
88 ]
89 },
90 {
91 "cell_type": "markdown",
92 "metadata": {},
93 "source": [
94 "**感知机模型**:\n",
95 "\n",
96 "<table>\n",
97 " <tr>\n",
98 " <td ><center><img src=\"http://imgbed.momodel.cn/感知器模型.png\" width=400/></center></td>\n",
99 " <td><center><img src=\"http://imgbed.momodel.cn/微信图片_20200114135559.png\" width=400/>\n",
100 " <img src=\"http://imgbed.momodel.cn/微信图片_20200114135643.png\" width=400/></center></td>\n",
128 " <td ><img src=\"http://imgbed.momodel.cn/感知器模型.png\" width=300/></td>\n",
129 " <td><img src=\"http://imgbed.momodel.cn//20200208141322.png\" width=400></td>\n",
101130 " </tr>\n",
102131 "</table>"
103132 ]
121150 "cell_type": "markdown",
122151 "metadata": {},
123152 "source": [
124 "在输出的判断上,其实不仅可以简单的按照阈值来判断,可以通过一个函数来进行计算,这个函数称为激活函数。常见的激活函数有: sigmoid,tanh,relu 等。下面我们看看这些激活函数的曲线图。"
153 "<br>"
154 ]
155 },
156 {
157 "cell_type": "markdown",
158 "metadata": {},
159 "source": [
 	160	    "在输出的判断上,不仅可以简单地按照阈值来判断,还可以通过一个函数来计算,这个函数称为**激活函数**。\n",
161 "\n",
162 "常见的激活函数有: sigmoid,tanh,relu 等。\n",
163 "\n",
164 "下面我们看看这些激活函数的曲线图。"
125165 ]
126166 },
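在看曲线图之前,先给出这三个激活函数的最小 NumPy 实现作为参考(这里是通用写法,notebook 中用于绘图的函数实现可能略有差异):

```python
import numpy as np

def sigmoid(x):
    return 1 / (1 + np.exp(-x))   # 输出压缩到 (0, 1)

def tanh(x):
    return np.tanh(x)             # 输出压缩到 (-1, 1)

def relu(x):
    return np.maximum(0, x)       # 负数置 0,正数原样输出

x = np.array([-2.0, 0.0, 2.0])
print(sigmoid(x), tanh(x), relu(x))
```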
127167 {
162202 ]
163203 },
164204 {
205 "cell_type": "markdown",
206 "metadata": {},
207 "source": [
208 "<br>"
209 ]
210 },
211 {
165212 "cell_type": "code",
166213 "execution_count": null,
167214 "metadata": {},
215262 "\n",
216263 "# 绘制 relu 函数\n",
217264 "plot_activation_function(relu)"
265 ]
266 },
267 {
268 "cell_type": "markdown",
269 "metadata": {},
270 "source": [
271 "<br>"
218272 ]
219273 },
220274 {
257311 "cell_type": "markdown",
258312 "metadata": {},
259313 "source": [
314 "<br>"
315 ]
316 },
317 {
318 "cell_type": "markdown",
319 "metadata": {},
320 "source": [
260321 "## 2.5.3 神经网络"
261322 ]
262323 },
264325 "cell_type": "markdown",
265326 "metadata": {},
266327 "source": [
328 "<center><video src=\"http://files.momodel.cn/nn_fullconnect.mp4\" controls=\"controls\" width=800px></center>"
329 ]
330 },
331 {
332 "cell_type": "markdown",
333 "metadata": {},
334 "source": [
335 "<br>"
336 ]
337 },
338 {
339 "cell_type": "markdown",
340 "metadata": {},
341 "source": [
342 "神经网络架构示意图如下:\n",
343 "\n",
267344 "<img src=\"http://imgbed.momodel.cn//20200103111837.png\" width=400>\n",
268345 "\n",
269346	 	    "与感知机不同,神经网络:\n",
270347 "+ 输入层和输出层之间存在若干隐藏层。\n",
271348 "+ 每个隐藏层中包含若干神经元。\n",
272349 "\n",
273 "通过下面几个小视频,你会更好的理解神经网络:\n",
274 "<center><video src=\"./nn_media1.mp4\" controls=\"controls\" width=800px></center>\n",
275 "<center><video src=\"./nn_media2.mp4\" controls=\"controls\" width=800px></center>\n",
276 "<center><video src=\"./nn_media3.mp4\" controls=\"controls\" width=800px></center>\n",
277 "<center><video src=\"./nn_media4.mp4\" controls=\"controls\" width=800px></center>\n",
278 "<center><video src=\"./nn_media5.mp4\" controls=\"controls\" width=800px></center>\n",
279 "<center><video src=\"./nn_media6.mp4\" controls=\"controls\" width=800px></center>"
 	350	    "通过下面几个小视频,你会更好地理解神经网络:"
351 ]
352 },
353 {
354 "cell_type": "markdown",
355 "metadata": {},
356 "source": [
357 "<center><video src=\"http://files.momodel.cn/nn_media1.mp4\" controls=\"controls\" width=800px></center>"
358 ]
359 },
360 {
361 "cell_type": "markdown",
362 "metadata": {},
363 "source": [
364 "<br>"
365 ]
366 },
367 {
368 "cell_type": "markdown",
369 "metadata": {},
370 "source": [
371 "<center><video src=\"http://files.momodel.cn/nn_media2.mp4\" controls=\"controls\" width=800px></center>"
372 ]
373 },
374 {
375 "cell_type": "markdown",
376 "metadata": {},
377 "source": [
378 "<br>"
379 ]
380 },
381 {
382 "cell_type": "markdown",
383 "metadata": {},
384 "source": [
385 "<center><video src=\"http://files.momodel.cn/nn_media3.mp4\" controls=\"controls\" width=800px></center>"
386 ]
387 },
388 {
389 "cell_type": "markdown",
390 "metadata": {},
391 "source": [
392 "<br>"
393 ]
394 },
395 {
396 "cell_type": "markdown",
397 "metadata": {},
398 "source": [
399 "<center><video src=\"http://files.momodel.cn/nn_media4.mp4\" controls=\"controls\" width=800px></center>"
400 ]
401 },
402 {
403 "cell_type": "markdown",
404 "metadata": {},
405 "source": [
406 "<br>"
407 ]
408 },
409 {
410 "cell_type": "markdown",
411 "metadata": {},
412 "source": [
413 "<center><video src=\"http://files.momodel.cn/nn_media5.mp4\" controls=\"controls\" width=800px></center>"
414 ]
415 },
416 {
417 "cell_type": "markdown",
418 "metadata": {},
419 "source": [
420 "<br>"
421 ]
422 },
423 {
424 "cell_type": "markdown",
425 "metadata": {},
426 "source": [
427 "<center><video src=\"http://files.momodel.cn/nn_media6.mp4\" controls=\"controls\" width=800px></center>"
428 ]
429 },
430 {
431 "cell_type": "markdown",
432 "metadata": {},
433 "source": [
434 "<br>"
435 ]
436 },
437 {
438 "cell_type": "markdown",
439 "metadata": {},
440 "source": [
441 "### 扩展内容\n",
442 "\n",
443 "**卷积神经网络**"
444 ]
445 },
446 {
447 "cell_type": "markdown",
448 "metadata": {},
449 "source": [
450 "<center><video src=\"http://files.momodel.cn/nn_conv.mp4\" controls=\"controls\" width=800px></center>"
451 ]
452 },
453 {
454 "cell_type": "markdown",
455 "metadata": {},
456 "source": [
457 "<br>"
280458 ]
281459 },
282460 {
352530 "cell_type": "markdown",
353531 "metadata": {},
354532 "source": [
533 "<br>"
534 ]
535 },
536 {
537 "cell_type": "markdown",
538 "metadata": {},
539 "source": [
355540 "3. 搭建神经网络模型"
356541 ]
357542 },
368553 " \"\"\"\n",
369554 " # 选择模型,选择序贯模型(Sequential())\n",
370555 " model = Sequential()\n",
371 " \n",
372556 " # 添加全连接层,共 512 个神经元\n",
373557 " model.add(Dense(512,input_shape=(784,),kernel_initializer='he_normal'))\n",
374 " \n",
375558 " # 添加激活层,激活函数选择 relu \n",
376559 " model.add(Activation('relu'))\n",
377 " \n",
378560 " # 添加全连接层,共 512 个神经元\n",
379561 " model.add(Dense(512,kernel_initializer='he_normal'))\n",
380 " \n",
381562 " # 添加激活层,激活函数选择 relu \n",
382563 " model.add(Activation('relu'))\n",
383 " \n",
384564 " # 添加全连接层,共 10 个神经元\n",
385565 " model.add(Dense(nb_classes))\n",
386 " \n",
387566 " # 添加激活层,激活函数选择 softmax\n",
388567 " model.add(Activation('softmax'))\n",
389 " \n",
390568 " return model\n",
391569 "\n",
392570 "model = create_model()"
571 ]
572 },
573 {
574 "cell_type": "markdown",
575 "metadata": {},
576 "source": [
577 "<br>"
393578 ]
394579 },
395580 {
414599 " \"\"\"\n",
415600 " # 编译模型\n",
416601 " model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
417 " \n",
418602 " # 模型训练\n",
419603 " model.fit(X_train, y_train, epochs=5, batch_size=64, verbose=1, validation_split=0.05)\n",
420 " \n",
421604 " # 保存模型\n",
422605 " model.save(model_path)\n",
423 " \n",
424606 " # 模型评估,获取测试集的损失值和准确率\n",
425607 " loss, accuracy = model.evaluate(X_test, y_test)\n",
426 "\n",
427608 " # 打印结果\n",
428609 " print('Test loss:', loss)\n",
429610 " print(\"Accuracy:\", accuracy)\n",
430 "\n",
611 " \n",
431612 "# 训练模型和评估模型\n",
432613 "fit_and_predict(model, model_path='./model.h5')"
433614 ]
436617 "cell_type": "markdown",
437618 "metadata": {},
438619 "source": [
620 "<br>"
621 ]
622 },
623 {
624 "cell_type": "markdown",
625 "metadata": {},
626 "source": [
439627 "### 实践与体验\n",
440628 "#### 调节神经网络结构和参数\n",
441629 "\n",
455643 " \"\"\"\n",
456644 " # 选择模型,选择序贯模型(Sequential())\n",
457645 " model = Sequential()\n",
458 "\n",
459646 " # 添加全连接层,共 512 个神经元\n",
460647 " model.add(Dense(512, input_shape=(784,), kernel_initializer='he_normal'))\n",
461 "\n",
462648 " # 添加激活层,激活函数选择 relu\n",
463649 " model.add(Activation('relu'))\n",
464 "\n",
465650 " # 添加全连接层,共 10 个神经元\n",
466651 " model.add(Dense(nb_classes))\n",
467 "\n",
468652 " # 添加激活层,激活函数选择 softmax\n",
469653 " model.add(Activation('softmax'))\n",
470 "\n",
471654 " return model\n",
472655 "\n",
473656 "# 搭建神经网络\n",
474657 "model1 = create_model1()\n",
475 "\n",
476658 "# 训练神经网络模型,保存模型和评估模型\n",
477659 "fit_and_predict(model1, model_path='./model1.h5')"
478660 ]
481663 "cell_type": "markdown",
482664 "metadata": {},
483665 "source": [
666 "<br>"
667 ]
668 },
669 {
670 "cell_type": "markdown",
671 "metadata": {},
672 "source": [
484673 "2. 修改两层隐藏层神经元的数量,然后训练模型得出准确率。"
485674 ]
486675 },
497686 " \"\"\"\n",
498687 " # 选择模型,选择序贯模型(Sequential())\n",
499688 " model = Sequential()\n",
500 "\n",
501689 " # 添加全连接层,共 256 个神经元\n",
502690 " model.add(Dense(256, input_shape=(784,), kernel_initializer='he_normal'))\n",
503 "\n",
504691 " # 添加激活层,激活函数选择 relu\n",
505692 " model.add(Activation('relu'))\n",
506 "\n",
507693 " # 添加全连接层,共 256 个神经元\n",
508694 " model.add(Dense(256, kernel_initializer='he_normal'))\n",
509 "\n",
510695 " # 添加激活层,激活函数选择 relu\n",
511696 " model.add(Activation('relu'))\n",
512 "\n",
513697 " # 添加全连接层,共 10 个神经元\n",
514698 " model.add(Dense(nb_classes))\n",
515 "\n",
516699 " # 添加激活层,激活函数选择 softmax\n",
517700 " model.add(Activation('softmax'))\n",
518 "\n",
519701 " return model\n",
520702 "\n",
521703 "# 搭建神经网络模型\n",
522704 "model2 = create_model2()\n",
523 "\n",
524705 "# 训练神经网络模型,保存模型并评估模型\n",
525 "fit_and_predict(model2,model_path='./model2.h5')\n",
526 "\n"
706 "fit_and_predict(model2,model_path='./model2.h5')\n"
707 ]
708 },
709 {
710 "cell_type": "markdown",
711 "metadata": {},
712 "source": [
713 "<br>"
527714 ]
528715 },
529716 {
570757 "cell_type": "markdown",
571758 "metadata": {},
572759 "source": [
760 "<br>"
761 ]
762 },
763 {
764 "cell_type": "markdown",
765 "metadata": {},
766 "source": [
573767 "## 扩展阅读\n",
574768 "1. [利用starGan算法改变人物的面部特征](https://momodel.cn/explore/5c0cc4591afd945c5177fb51?type=app)\n",
575769 "2. [利用 pix2pix 将你的草图变成图片](https://momodel.cn/explore/5c0cb5df1afd945819064752?type=app)\n",
576770 "3. [自动生成图片描述](https://momodel.cn/explore/5ba33f578fe30b412042ac08?&type=app&tab=1)"
577771 ]
578 },
579 {
580 "cell_type": "code",
581 "execution_count": null,
582 "metadata": {},
583 "outputs": [],
584 "source": []
585772 }
586773 ],
587774 "metadata": {
0 {
1 "cells": [
2 {
3 "cell_type": "code",
4 "execution_count": 1,
5 "metadata": {},
6 "outputs": [],
7 "source": [
8 "import networkx as nx"
9 ]
10 },
11 {
12 "cell_type": "code",
13 "execution_count": 9,
14 "metadata": {},
15 "outputs": [],
16 "source": [
17 "def bfs_search(G, max_depth, start_node, target_node):\n",
18 " # 待访问的路径\n",
19 " to_search = [(start_node, 0)]\n",
20 " # 存储所有的历史路径,及此路径的距离\n",
21 " bfs_path = []\n",
22 " # 正确的路径列表,及此路径的距离\n",
23 " bfs_correct_path = []\n",
24 " # 当还有待访问的路径时\n",
25 " while to_search:\n",
26 " # 从待访问的路径中取第一个待访问路径及其路径长度,例如 AC\n",
27 " this_path, this_path_dis = to_search.pop(0)\n",
28 " # 如果待访问的路径达到最大搜索深度,跳出循环\n",
 	29	    "        if len(this_path) > max_depth:\n",
30 " break\n",
31 " # 把刚取出的路径存入历史路径中\n",
32 " bfs_path.append((this_path, this_path_dis))\n",
33 " # 如果路径的最后一个节点是目标节点,路径 AC 的最后一个节点是 C\n",
34 " if this_path[-1] == target_node:\n",
35 " # 其为一条正确的路径,将存入正确的路径列表中,\n",
36 " # 并不再继续往其子节点进行探索\n",
37 " bfs_correct_path.append((this_path, this_path_dis))\n",
38 " continue\n",
39 " # 找到路径最后一个节点的相邻节点\n",
40 " for ne in sorted(G[this_path[-1]]):\n",
41 " # 如果相邻节点不在路径中,即不存在回路\n",
42 " if ne not in this_path:\n",
43 " # 则加入到待访问的路径中\n",
44 " to_search.append((this_path + ne,\n",
45 " this_path_dis + G[this_path[-1]][ne][\n",
46 " 'weight']))\n",
47 " return bfs_path, bfs_correct_path"
48 ]
49 },
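上面的 `bfs_search` 依赖 networkx 的图对象。为便于单独理解算法本身,下面用普通字典邻接表改写一个等价的最小版本;图数据与正文示例相同,`max_depth` 取 4(取 3 时含 4 个节点的路径 A -> B -> D -> G 不会被展开):

```python
# 邻接表形式的无向图,边权与正文的 weighted_edges_list 相同
graph = {
    'A': {'B': 8, 'C': 20},
    'B': {'A': 8, 'D': 20, 'E': 30, 'F': 40},
    'C': {'A': 20, 'D': 10},
    'D': {'B': 20, 'C': 10, 'E': 10, 'G': 10},
    'E': {'B': 30, 'D': 10, 'F': 30},
    'F': {'B': 40, 'E': 30, 'G': 30},
    'G': {'D': 10, 'F': 30},
}

def bfs_paths(graph, max_depth, start, target):
    """宽度优先搜索:返回所有不超过 max_depth 个节点的 start->target 路径及距离"""
    to_search = [(start, 0)]   # 待访问的 (路径, 距离)
    correct = []               # 找到的正确路径
    while to_search:
        path, dist = to_search.pop(0)        # 取队首路径
        if len(path) > max_depth:
            break
        if path[-1] == target:               # 到达目标,不再扩展
            correct.append((path, dist))
            continue
        for ne in sorted(graph[path[-1]]):   # 扩展相邻节点
            if ne not in path:               # 避免回路
                to_search.append((path + ne, dist + graph[path[-1]][ne]))
    return correct

found = bfs_paths(graph, 4, 'A', 'G')
print(min(found, key=lambda p: p[1]))   # 距离最短的一条路径
```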
50 {
51 "cell_type": "code",
52 "execution_count": 10,
53 "metadata": {},
54 "outputs": [],
55 "source": [
56 "# 定义节点列表\n",
57 "node_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G']\n",
58 "\n",
59 "# 定义边及权重列表\n",
60 "weighted_edges_list = [('A', 'B', 8), ('A', 'C', 20),\n",
61 " ('B', 'F', 40), ('B', 'E', 30),\n",
62 " ('B', 'D', 20), ('C', 'D', 10), \n",
63 " ('D', 'G', 10), ('D', 'E', 10),\n",
64 " ('E', 'F', 30), ('F', 'G', 30)]\n",
65 "\n",
66 "# 定义绘图中各个节点的坐标\n",
67 "nodes_pos = {\"A\": (1, 1), \"B\": (3, 3), \"C\": (5, 0), \"D\": (9, 2),\n",
68 " \"E\": (7, 4), \"F\": (6,6),\"G\": (11,5)}\n",
69 "\n",
70 "G = nx.Graph()\n",
71 "G.add_nodes_from(node_list)\n",
72 "G.add_weighted_edges_from(weighted_edges_list)\n",
73 "\n"
74 ]
75 },
76 {
77 "cell_type": "code",
78 "execution_count": 13,
79 "metadata": {},
80 "outputs": [],
81 "source": [
82 "bfs_path, bfs_correct_path = bfs_search(G, 3, 'A', 'G')"
83 ]
84 },
85 {
86 "cell_type": "code",
87 "execution_count": 14,
88 "metadata": {},
89 "outputs": [],
90 "source": [
91 "paths = [e[0] for e in bfs_path]"
92 ]
93 },
94 {
95 "cell_type": "code",
96 "execution_count": 23,
97 "metadata": {},
98 "outputs": [],
99 "source": [
100 "import collections\n",
101 "def get_search_tree_node_position(paths):\n",
102 " \"\"\"得到绘图时各个节点的坐标\n",
103 " \"\"\"\n",
104 " max_depth = 3 \n",
105 " # 得到每条路径的子路径\n",
106 " path_childern = {}\n",
107 " for path in paths:\n",
108 " father = path[:-1]\n",
109 " if father in paths:\n",
110 " if father in path_childern:\n",
111 " path_childern[father].append(path)\n",
112 " else:\n",
113 " path_childern[father] = [path]\n",
114 " # 对每条子路径排序\n",
115 " o_path_childern = collections.OrderedDict(\n",
116 " sorted(path_childern.items()))\n",
117 " # 计算每个树图中每个节点的位置\n",
118 " tree_node_position = {paths[0][0]: (1, 0, 2)}\n",
119 " for path, sub_paths in o_path_childern.items():\n",
120 " y_pos = -1.0 / max_depth * len(path)\n",
121 " dx = tree_node_position[path][2] / len(sub_paths)\n",
122 " sub_paths.sort()\n",
123 " for index, e_s in enumerate(sub_paths):\n",
124 " x_pos = tree_node_position[path][0] - tree_node_position[path][\n",
125 " 2] / 2 + dx / 2 + dx * index\n",
126 " tree_node_position[e_s] = (x_pos, y_pos, dx)\n",
127 " print(tree_node_position)"
128 ]
129 },
130 {
131 "cell_type": "code",
132 "execution_count": 24,
133 "metadata": {},
134 "outputs": [
135 {
136 "name": "stdout",
137 "output_type": "stream",
138 "text": [
139 "{'A': (1, 0, 2), 'ACD': (1.5, -0.6666666666666666, 1.0), 'AC': (1.5, -0.3333333333333333, 1.0), 'ABD': (0.16666666666666666, -0.6666666666666666, 0.3333333333333333), 'AB': (0.5, -0.3333333333333333, 1.0), 'ABF': (0.8333333333333333, -0.6666666666666666, 0.3333333333333333), 'ABE': (0.5, -0.6666666666666666, 0.3333333333333333)}\n"
140 ]
141 }
142 ],
143 "source": [
144 "get_search_tree_node_position(paths)"
145 ]
146 },
147 {
148 "cell_type": "code",
149 "execution_count": null,
150 "metadata": {},
151 "outputs": [],
152 "source": [
153 "import matplotlib.pyplot as plt\n",
154 "import collections\n",
155 "from IPython import display\n",
156 "import networkx as nx\n",
157 "import numpy as np\n",
158 "import time\n",
159 "\n",
160 "\n",
161 "class SearchGraph():\n",
162 " def __init__(self,\n",
163 " node_list, \n",
164 " weighted_edges_list, \n",
165 " start_node,\n",
166 " target_node,\n",
167 " max_depth=1000,\n",
168 " nodes_pos=None,\n",
169 " help_info=None,):\n",
170 " # 图中的节点\n",
171 " self.node_list = node_list\n",
172 " self.weighted_edges_list = weighted_edges_list\n",
173 " self.start_node = start_node\n",
174 " self.target_node = target_node\n",
175 " self.nodes_pos = nodes_pos\n",
176 " self.max_depth = min(max_depth, len(node_list))\n",
177 " self.temp_best_path = None\n",
178 " \n",
179 " self.weighted_edges_dic = {frozenset([e[0],e[1]]):e[2] for e in weighted_edges_list}\n",
180 " self.help_info = help_info\n",
181 " self.path_score={self.start_node:0}\n",
182 " \n",
183 " self.animation_type = 'dfs'\n",
184 " \n",
185 " self.basic_node_color = '#6CB6FF'\n",
186 " self.start_node_color = 'y'\n",
187 " self.target_node_color = 'r'\n",
188 " self.visited_node_color = 'g'\n",
189 " \n",
190 " self.basic_edge_color = 'b'\n",
191 " self.visited_edge_color = 'g'\n",
192 " \n",
193 " self.success_color = 'r'\n",
194 " \n",
195 " self.correct_paths={}\n",
196 " self.show_correct_path = []\n",
197 " self.build_graph()\n",
198 " self.get_search_tree_node_position()\n",
199 " self.bfs_search()\n",
200 " \n",
201 " \n",
202 "\n",
203 " def build_graph(self):\n",
204 " self.G = nx.Graph()\n",
205 " self.G.add_nodes_from(self.node_list)\n",
206 " self.G.add_weighted_edges_from(self.weighted_edges_list)\n",
207 " \n",
208 " def get_search_tree_node_position(self):\n",
209 " \"\"\"得到绘图的点的坐标\n",
210 " \"\"\"\n",
211 " self.dfs_search()\n",
212 " # 得到 dfs 的搜索路径图\n",
213 " paths = self.dfs_path\n",
214 " # 得到每条路径的子路径\n",
215 " path_childern = {}\n",
216 " for path in paths:\n",
217 " father = path[:-1]\n",
218 " if father in paths:\n",
219 " if father in path_childern:\n",
220 " path_childern[father].append(path)\n",
221 " else:\n",
222 " path_childern[father] = [path]\n",
223 " # 对每条子路径排序\n",
224 " o_path_childern = collections.OrderedDict(sorted(path_childern.items()))\n",
225 " # 计算每个树图中每个节点的位置\n",
226 " tree_node_position = {self.start_node:(1, 0, 2)}\n",
227 " for path, sub_paths in o_path_childern.items():\n",
228 " y_pos = -1.0/self.max_depth * len(path)\n",
229 " dx = tree_node_position[path][2]/len(sub_paths) \n",
230 " sub_paths.sort()\n",
231 " for index, e_s in enumerate(sub_paths):\n",
232 " x_pos = tree_node_position[path][0] - tree_node_position[path][2]/2 + dx/2 + dx*index\n",
233 " tree_node_position[e_s]=(x_pos,y_pos, dx)\n",
234 " self.tree_node_position = tree_node_position\n",
235 " \n",
236 " def show_edge_labels(self, ax, pos1, pos2, label):\n",
237 " (x1, y1) = pos1\n",
238 " (x2, y2) = pos2\n",
239 " (x, y) = (x1*0.5 + x2*0.5, y1*0.5 + y2*0.5)\n",
240 "\n",
241 " angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360\n",
242 " if angle > 90:\n",
243 " angle -= 180\n",
244 " if angle < - 90:\n",
245 " angle += 180\n",
246 " xy = np.array((x, y))\n",
247 " trans_angle = ax.transData.transform_angles(np.array((angle,)),\n",
248 " xy.reshape((1, 2)))[0]\n",
249 " bbox = dict(boxstyle='round',\n",
250 " ec=(1.0, 1.0, 1.0),\n",
251 " fc=(1.0, 1.0, 1.0),\n",
252 " )\n",
253 " label = str(label) \n",
254 " ax.text(x, y,\n",
255 " label,\n",
256 " size=16,\n",
257 " color='k',\n",
258 " alpha=1,\n",
259 " horizontalalignment='center',\n",
260 " verticalalignment='center',\n",
261 " rotation=trans_angle,\n",
262 " transform=ax.transData,\n",
263 " bbox=bbox,\n",
264 " zorder=1,\n",
265 " clip_on=True,\n",
266 " )\n",
267 " \n",
268 " def show_search_tree(self, \n",
269 " this_path=None, \n",
270 " show_success_color=False,\n",
271 " best_path=None\n",
272 " ):\n",
273 " \"\"\"展示搜索树\n",
274 " \"\"\"\n",
275 " # 画出树图 \n",
276 " fig, ax = plt.subplots()\n",
277 " fig.set_figwidth(15)\n",
278 " fig.set_figheight(self.max_depth*1.5)\n",
279 " plt.axis('off')\n",
280 " \n",
281 " for path, pos in self.tree_node_position.items():\n",
282 " if path[-1] == self.start_node:\n",
283 " node_color = self.start_node_color\n",
284 " edge_color = self.basic_edge_color\n",
285 " elif this_path and path in this_path:\n",
286 " if show_success_color:\n",
287 " node_color = self.success_color\n",
288 " edge_color = self.success_color\n",
289 " else:\n",
290 " node_color = self.visited_node_color\n",
291 " edge_color = self.visited_edge_color\n",
292 " elif path[-1] == self.target_node:\n",
293 " node_color = self.target_node_color\n",
294 " edge_color = self.basic_edge_color\n",
295 " else:\n",
296 " node_color = self.basic_node_color\n",
297 " edge_color = self.basic_edge_color\n",
298 " ax.scatter(pos[0], pos[1], c=node_color, s=1000,zorder=1)\n",
299 " plt.annotate(\n",
300 " path[-1],\n",
301 " xy=(pos[0], pos[1]),\n",
302 " xytext=(0, 0),\n",
303 " textcoords='offset points',\n",
304 " ha='center',\n",
305 " va='center',\n",
306 " size=15,)\n",
307 " if len(path)>1:\n",
308 " plt.plot([self.tree_node_position[path[:-1]][0],pos[0]], \n",
309 " [self.tree_node_position[path[:-1]][1],pos[1]], \n",
310 " color=edge_color,\n",
311 " zorder=0)\n",
312 " if len(path)>1:\n",
313 " label = self.weighted_edges_dic[frozenset([path[-2],path[-1]])]\n",
314 " if self.animation_type in ['greedy','a_star']:\n",
315 " label = self.help_info_weight*self.help_info[path[-1]] + self.origin_info_weight*label\n",
316 " self.show_edge_labels(ax, self.tree_node_position[path[:-1]][0:2], pos[0:2], label)\n",
317 " display.clear_output(wait=True)\n",
318 " \n",
319 " show_res_text = \"\"\n",
320 " for e_c in self.show_correct_path:\n",
321 " show_res_text += '找到一条路径: %-7s' % e_c + '。距离为:' +str(self.correct_paths[e_c]) + '\\n'\n",
322 " plt.text(0, -1.1, show_res_text, fontsize=18,horizontalalignment='left', verticalalignment='top',)\n",
323 " \n",
324 " if best_path:\n",
325 " top_text = '最终最短路径为: %-7s' % this_path + '。距离为:' +str(self.correct_paths[this_path]) + '\\n'\n",
326 " elif this_path and self.animation_type in ['dfs','bfs']:\n",
327 " top_text = '当前路径: %-7s' % this_path + '。距离为:' +str(self.path_score[this_path]) + '\\n' \n",
328 " if self.temp_best_path:\n",
329 " top_text += '当前最短路径为: %-7s' % self.temp_best_path + '。距离为:' +str(self.correct_paths[self.temp_best_path]) + '\\n'\n",
330 " else:\n",
331 " top_text = ''\n",
332 "\n",
333 " plt.text(0, 0, \n",
334 " top_text, \n",
335 " fontsize=18,\n",
336 " horizontalalignment='left', \n",
337 " verticalalignment='top',)\n",
338 " \n",
339 " if self.animation_type in ['greedy','a_star']:\n",
340 " show_greedy_text = self.generate_greedy_help_text(this_path)\n",
341 " plt.text(0, 0, show_greedy_text, fontsize=18, horizontalalignment='left', verticalalignment='top',)\n",
342 " plt.show()\n",
343 " \n",
344 " def animation_search_tree(self,search_method='dfs', help_info_weight=1, origin_info_weight=1):\n",
345 " \"\"\"动画展示搜索过程\n",
346 " \"\"\"\n",
347 " self.animation_type = search_method\n",
348 " self.show_correct_path = []\n",
349 " self.temp_best_path = None\n",
350 " if search_method == 'bfs':\n",
351 " paths = self.bfs_path\n",
352 " elif search_method == 'dfs':\n",
353 " paths = self.dfs_path\n",
354 " elif search_method == 'greedy':\n",
355 " self.greedy_search()\n",
356 " paths = self.greedy_search_path\n",
357 " elif search_method == 'a_star':\n",
358 " self.a_star_search(help_info_weight=help_info_weight, origin_info_weight=origin_info_weight)\n",
359 " paths = self.greedy_search_path\n",
360 " else:\n",
361 " paths = []\n",
362 " for e_path in paths:\n",
363 " self.show_search_tree(e_path)\n",
364 " if e_path in self.correct_paths:\n",
365 " if not self.temp_best_path:\n",
366 " self.temp_best_path = e_path\n",
367 " elif self.path_score[e_path] < self.path_score[self.temp_best_path]:\n",
368 " self.temp_best_path = e_path\n",
369 " self.show_correct_path.append(e_path)\n",
370 " self.show_search_tree(e_path, True)\n",
371 " if search_method in ['greedy', 'a_star']:\n",
372 " time.sleep(5)\n",
373 " if search_method in ['bfs', 'dfs']:\n",
374 " if self.correct_paths:\n",
375 " best_path = min(self.correct_paths, key=self.correct_paths.get)\n",
376 " self.show_search_tree(best_path, True, True)\n",
377 " \n",
378 " def animation_graph(self, search_method='bfs', help_info_weight=1, origin_info_weight=1):\n",
379 " \n",
 	380	    "        \"\"\"动画展示图上的搜索过程\n",
381 " \"\"\"\n",
382 " self.animation_type = search_method\n",
383 " self.show_correct_path = []\n",
384 " if search_method == 'bfs':\n",
385 " paths = self.bfs_path\n",
386 " elif search_method == 'dfs':\n",
387 " paths = self.dfs_path\n",
388 " elif search_method == 'greedy':\n",
389 " self.greedy_search()\n",
390 " paths = self.greedy_search_path\n",
391 " elif search_method == 'a_star':\n",
392 " self.a_star_search(help_info_weight=help_info_weight, origin_info_weight=origin_info_weight)\n",
393 " paths = self.greedy_search_path\n",
394 " else:\n",
395 " paths = []\n",
396 " for e_path in paths:\n",
397 " self.show_graph(e_path)\n",
398 " if e_path in self.correct_paths:\n",
399 " self.show_correct_path.append(e_path)\n",
400 " self.show_graph(e_path, True)\n",
401 " time.sleep(5)\n",
402 "        if search_method in ['bfs', 'dfs'] and self.correct_paths:\n",
403 " best_path = min(self.correct_paths, key=self.correct_paths.get)\n",
404 " self.show_graph(best_path, True, True)\n",
405 " \n",
406 " def show_graph(self, this_path='', \n",
407 " show_success_color=False,\n",
408 " best_path=None):\n",
409 "        \"\"\"Draw the graph, coloring the nodes and edges along this_path.\n",
410 "        \"\"\"\n",
413 " fig, ax = plt.subplots()\n",
414 " fig.set_figwidth(6)\n",
415 " fig.set_figheight(8)\n",
416 " plt.axis('off')\n",
417 "\n",
418 "        # Collect the edges visited along the current path\n",
419 " visited_edges = []\n",
420 " if not this_path:\n",
421 " this_path = self.start_node\n",
422 " path_node_list = list(this_path)\n",
423 "        for i in range(1, len(path_node_list)):\n",
424 "            visited_edges.append(frozenset([path_node_list[i], path_node_list[i-1]]))\n",
425 " \n",
426 "        # Node and edge labels\n",
427 " nlabels = dict(zip(self.node_list, self.node_list))\n",
428 "        edge_labels = {(u, v): d['weight'] for u, v, d in self.G.edges(data=True)}\n",
429 " \n",
430 "        # Node colors: the start, target and visited nodes get special colors\n",
431 " val_map = {self.target_node: self.target_node_color}\n",
432 " if path_node_list:\n",
433 " for i in path_node_list:\n",
434 " if show_success_color:\n",
435 " val_map[i] = self.success_color\n",
436 " else:\n",
437 " val_map[i] = self.visited_node_color\n",
438 " val_map[self.start_node] = self.start_node_color \n",
439 " values = [val_map.get(node, self.basic_node_color) for node in self.G.nodes()]\n",
440 "\n",
441 "        # Edge colors\n",
442 " edge_colors = []\n",
443 " for edge in self.G.edges():\n",
444 "            # If an edge lies on the current path (visited_edges), use the success\n",
445 "            # color when show_success_color is set, otherwise the visited-edge color;\n",
446 "            # edges not on the path keep the basic edge color.\n",
447 " if frozenset(edge) in visited_edges:\n",
448 " if show_success_color:\n",
449 " edge_colors.append(self.success_color)\n",
450 " else:\n",
451 " edge_colors.append(self.visited_edge_color)\n",
452 " else:\n",
453 " edge_colors.append(self.basic_edge_color)\n",
454 "\n",
455 "        # Draw nodes and their labels\n",
456 "        nx.draw_networkx_nodes(self.G, self.nodes_pos, node_size=800, node_color=values)\n",
457 "        nx.draw_networkx_labels(self.G, self.nodes_pos, nlabels, font_size=20)\n",
458 "        # Draw edges and their labels\n",
459 " nx.draw_networkx_edges(self.G, self.nodes_pos, edge_color=edge_colors, width=2.0, alpha=1.0)\n",
460 " nx.draw_networkx_edge_labels(self.G, self.nodes_pos, edge_labels=edge_labels, font_size=18)\n",
461 "\n",
462 " display.clear_output(wait=True)\n",
467 " \n",
479 " plt.show()\n",
480 " \n",
481 "    def _dfs_helper(self, G, node, father, target_node, level, res, path):\n",
482 "        path += str(node)\n",
483 "        if len(path) > 1:\n",
484 "            self.path_score[path] = self.path_score[path[:-1]] + self.weighted_edges_dic[frozenset([path[-2], path[-1]])]\n",
485 "        res.append(path)\n",
486 "        # Reached the target: stop expanding this branch\n",
487 "        if node == target_node:\n",
488 "            return\n",
489 "        if level < self.max_depth:\n",
490 "            for neighbor in sorted(G[node]):\n",
491 "                if str(neighbor) not in path:\n",
492 "                    self._dfs_helper(G, neighbor, node, target_node, level+1, res, path)\n",
493 " \n",
494 "    def dfs_search(self):\n",
495 "        dfs_path = []\n",
496 "        this_path = ''\n",
497 "        if self.start_node:\n",
498 "            self._dfs_helper(self.G, self.start_node, None, self.target_node, 0, dfs_path, this_path)\n",
499 "        self.dfs_path = dfs_path\n",
500 "        for p in dfs_path:\n",
501 "            if p[-1] == self.target_node and p not in self.correct_paths:\n",
502 "                self.correct_paths[p] = self.cal_dis(p)\n",
503 " \n",
504 "    def bfs_search(self):\n",
505 "        to_search = [self.start_node]\n",
506 "        bfs_path = []\n",
507 "        bfs_correct_path = []\n",
508 "        while to_search:\n",
509 "            this_search = to_search.pop(0)\n",
510 "            if len(this_search) > self.max_depth + 1:\n",
511 "                break\n",
512 "            bfs_path.append(this_search)\n",
513 "            if this_search[-1] == self.target_node:\n",
514 "                bfs_correct_path.append(this_search)\n",
515 "                continue\n",
516 "            for ne in sorted(self.G[this_search[-1]]):\n",
517 "                if ne not in this_search:\n",
518 "                    to_search.append(this_search + ne)\n",
519 "        self.bfs_path = bfs_path\n",
520 "        for p in bfs_path:\n",
521 "            if p[-1] == self.target_node and p not in self.correct_paths:\n",
522 "                self.correct_paths[p] = self.cal_dis(p)\n",
524 " \n",
525 " def greedy_search(self, help_info_weight=1, origin_info_weight=0):\n",
526 " self.help_info_weight = help_info_weight\n",
527 " self.origin_info_weight = origin_info_weight\n",
528 " search_path = self.start_node\n",
529 "        # Record the candidate nodes and their scores at each step, for display during animation\n",
530 "        search_scores = {}\n",
531 "        while len(search_path) <= self.max_depth:\n",
532 "            this_node = search_path[-1]\n",
533 "            neighbour_nodes = [e_n for e_n in sorted(self.G[this_node]) if e_n not in search_path]\n",
534 "            if len(neighbour_nodes) == 0:\n",
535 "                search_scores[search_path] = {}\n",
536 "                break\n",
537 "            if self.help_info:\n",
538 "                scores = {e_n: help_info_weight*self.help_info[e_n] + origin_info_weight*self.weighted_edges_dic[frozenset([this_node, e_n])] for e_n in neighbour_nodes}\n",
539 "            else:\n",
540 "                scores = {e_n: self.weighted_edges_dic[frozenset([this_node, e_n])]\n",
541 "                          for e_n in neighbour_nodes}\n",
542 "            search_scores[search_path] = scores\n",
543 " nearest_node = min(scores, key=scores.get)\n",
544 " search_path += nearest_node\n",
545 " if nearest_node == self.target_node:\n",
546 " break\n",
547 " self.greedy_search_path = [search_path[0:index+1] for index in range(len(search_path))]\n",
548 " self.search_scores = search_scores\n",
549 " \n",
550 "    def a_star_search(self, help_info_weight=1, origin_info_weight=1):\n",
551 "        \"\"\"A* here reuses greedy_search, combining the heuristic and the edge cost via the two weights.\"\"\"\n",
552 "        self.greedy_search(help_info_weight, origin_info_weight)\n",
553 "\n",
554 "    def generate_greedy_help_text(self, path):\n",
555 "        if path[-1] == self.target_node:\n",
556 "            return 'Reached the target node ' + str(self.target_node)\n",
557 "        elif path not in self.search_scores:\n",
558 "            return 'Reached the maximum search depth without finding the target node'\n",
559 "        \n",
560 "        base_text = 'Candidate child nodes and their scores:\\n' + \\\n",
561 "                    str(self.search_scores[path]) + '\\n'\n",
562 "        if self.target_node in self.search_scores[path]:\n",
563 "            return base_text + 'The candidates include the target node,\\nso the target node is chosen'\n",
564 "        elif len(self.search_scores[path]) == 1:\n",
565 "            return base_text + 'Only one child node is available, so it is chosen'\n",
566 "        else:\n",
567 "            return base_text + \\\n",
568 "                   str(min(self.search_scores[path], key=self.search_scores[path].get)) + \\\n",
569 "                   ' has the smallest score, so it is chosen'\n",
570 " \n",
571 "    def cal_dis(self, path):\n",
572 "        \"\"\"Total edge-weight distance along a path.\"\"\"\n",
573 "        dis = 0\n",
574 "        if len(path) > 1:\n",
575 "            for i in range(len(path)-1):\n",
576 "                dis += self.weighted_edges_dic[frozenset([path[i], path[i+1]])]\n",
577 "        return dis\n",
577 " "
578 ]
579 }
580 ],
581 "metadata": {
582 "kernelspec": {
583 "display_name": "Python 3",
584 "language": "python",
585 "name": "python3"
586 },
587 "language_info": {
588 "codemirror_mode": {
589 "name": "ipython",
590 "version": 3
591 },
592 "file_extension": ".py",
593 "mimetype": "text/x-python",
594 "name": "python",
595 "nbconvert_exporter": "python",
596 "pygments_lexer": "ipython3",
597 "version": "3.5.2"
598 }
599 },
600 "nbformat": 4,
601 "nbformat_minor": 2
602 }
Binary diff not shown
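The `bfs_search` method in the diff above encodes each partial path as a string of single-character node names and expands paths with a FIFO queue. The following is a minimal standalone sketch of that idea; the function name `bfs_paths` and the adjacency dict are hypothetical examples, not taken from the notebook.

```python
# Standalone sketch of the string-encoded BFS used by bfs_search:
# paths are strings such as 'ABDG', expanded in breadth-first order.

def bfs_paths(graph, start, target, max_depth=10):
    """Return every path visited by the BFS, in visit order."""
    to_search = [start]                  # FIFO queue of partial paths
    visited_paths = []
    while to_search:
        path = to_search.pop(0)
        if len(path) > max_depth + 1:    # depth limit, as in bfs_search
            break
        visited_paths.append(path)
        if path[-1] == target:           # goal test: stop extending this path
            continue
        for neighbour in sorted(graph[path[-1]]):
            if neighbour not in path:    # do not revisit nodes on this path
                to_search.append(path + neighbour)
    return visited_paths

# Hypothetical example graph with single-character node names
graph = {
    'A': ['B', 'C'],
    'B': ['A', 'D'],
    'C': ['A', 'D'],
    'D': ['B', 'C', 'G'],
    'G': ['D'],
}
paths = bfs_paths(graph, 'A', 'G')
complete = [p for p in paths if p[-1] == 'G']
print(complete)  # ['ABDG', 'ACDG']
```

Because the queue is FIFO, shorter paths are always expanded before longer ones, which is why `bfs_search` can collect all complete paths and pick the best one afterwards.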
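The `greedy_search` method follows a simple step rule: score the unvisited neighbours of the current node and move to the one with the smallest score. A minimal sketch of that rule, using a pure heuristic score; the function name `greedy_path`, the graph, and the heuristic values are made-up for illustration:

```python
# Sketch of the greedy step rule: repeatedly move to the neighbour
# with the smallest heuristic score until the target or a dead end.

def greedy_path(graph, heuristic, start, target, max_depth=10):
    """Follow the locally best-scoring neighbour from start toward target."""
    path = start
    while len(path) <= max_depth:
        node = path[-1]
        candidates = [n for n in sorted(graph[node]) if n not in path]
        if not candidates:
            break                        # dead end: no unvisited neighbour
        # Score each candidate by its heuristic value (pure greedy best-first)
        scores = {n: heuristic[n] for n in candidates}
        best = min(scores, key=scores.get)
        path += best
        if best == target:
            break
    return path

# Hypothetical example: adjacency as strings, heuristic = estimated
# remaining distance to the target G
graph = {'A': 'BC', 'B': 'AD', 'C': 'AD', 'D': 'BCG', 'G': 'D'}
heuristic = {'A': 30, 'B': 25, 'C': 26, 'D': 14, 'G': 0}
print(greedy_path(graph, heuristic, 'A', 'G'))  # ABDG
```

The notebook's `greedy_search` generalizes this by mixing the heuristic with the next edge's weight via `help_info_weight` and `origin_info_weight`, which is also how `a_star_search` reuses it.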