4b0cb67
joyvan 6 years ago
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 1. Python 基础"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "## 1.1 Python 简介\n",
14 "\n",
15 "<img src=\"https://www.python.org/static/img/python-logo@2x.png\" width=300/> \n",
16 "\n",
17     "**Python** 是一种高层次的脚本语言,结合了解释性、编译性、交互性和面向对象的特性。\n",
18 "\n",
19 "- **Python** 是一种解释型语言: 开发过程中没有了编译这个环节。\n",
20 "\n",
21 "- **Python** 是交互式语言: 可以在一个 **Python** 提示符 `>>>` 后直接执行代码。\n",
22 "\n",
23 "- **Python** 是面向对象语言: **Python** 支持面向对象的风格或代码封装在对象的编程技术。\n",
24 "\n",
25 "- **Python** 对初学者非常友好:**Python** 语法简单,可以快速上手,但异常强大,应用也十分广泛,从 `web` 开发,网络爬虫到机器学习,人工智能,金融量化分析都有广泛的应用。\n",
26 " \n",
27 "### 第一行 Python 代码"
28 ]
29 },
30 {
31 "cell_type": "code",
32 "execution_count": null,
33 "metadata": {},
34 "outputs": [],
35 "source": [
36 "print (\"hello world!\")"
37 ]
38 },
39 {
40 "cell_type": "markdown",
41 "metadata": {},
42 "source": [
43 "## 1.2 数据类型\n",
44 "\n",
45 "| 类型| 例子|\n",
46 "| ----- | ----- |\n",
47 "| 整数 | `-100` |\n",
48 "| 浮点数 | `3.1416` |\n",
49 "| 字符串 | `'hello'` |\n",
50 "| 列表 | `[1, 1.2, 'hello']` |\n",
51 "| 字典 | `{'dogs': 5, 'pigs': 3}`|\n",
52     "| 复数 | `1 + 2j` |\n",
53     "| 布尔型 | `True, False` |\n",
54     "| 元组 | `('ring', 1000)` |\n",
55     "| 集合 | `{1, 2, 3}` |"
56 ]
57 },
58 {
59 "cell_type": "markdown",
60 "metadata": {},
61 "source": [
62 "使用`type()`函数来查看变量类型:"
63 ]
64 },
65 {
66 "cell_type": "code",
67 "execution_count": null,
68 "metadata": {},
69 "outputs": [],
70 "source": [
71 "a = 1\n",
72 "type(a)"
73 ]
74 },
75 {
76 "cell_type": "markdown",
77 "metadata": {},
78 "source": [
79 "在 **Python** 中运算是有优先级的,优先级即算术的先后顺序,比如“先乘除后加减”和“先算括号里面的”都是两种优先级的规则,优先级从高到低排列如下:\n",
80 "\n",
81 "- `( )` 括号\n",
82 "- `**` 幂指数运算\n",
83 "- `* / // %` 乘,除,整数除法,取余运算\n",
84 "- `+ -` 加减\n"
85 ]
86 },
87 {
88 "cell_type": "code",
89 "execution_count": null,
90 "metadata": {},
91 "outputs": [],
92 "source": [
93 "a = 4\n",
94 "b = 3\n",
95 "print(\"加:\", a + b)\n",
96 "print(\"减:\", a - b)\n",
97 "print(\"乘:\", a * b)\n",
98 "print(\"除:\", a / b)\n",
99 "print('幂:', a ** b)\n",
100     "print('取余:', a % b)\n",
101 "print('取商:', a // b)"
102 ]
103 },
104 {
105 "cell_type": "markdown",
106 "metadata": {},
107 "source": [
108 "### 常见的数学函数\n",
109 "\n",
110 "绝对值:"
111 ]
112 },
113 {
114 "cell_type": "code",
115 "execution_count": null,
116 "metadata": {},
117 "outputs": [],
118 "source": [
119 "abs(-12.4)"
120 ]
121 },
122 {
123 "cell_type": "markdown",
124 "metadata": {},
125 "source": [
126     "按指定位数对小数进行舍入:"
127 ]
128 },
129 {
130 "cell_type": "code",
131 "execution_count": null,
132 "metadata": {},
133 "outputs": [],
134 "source": [
135 "round(21.6445, 2)"
136 ]
137 },
138 {
139 "cell_type": "markdown",
140 "metadata": {},
141 "source": [
142 "最大最小值:"
143 ]
144 },
145 {
146 "cell_type": "code",
147 "execution_count": null,
148 "metadata": {},
149 "outputs": [],
150 "source": [
151 "print (min(2, 3, 4, 5))\n",
152 "print (max(2, 4, 3))"
153 ]
154 },
155 {
156 "cell_type": "markdown",
157 "metadata": {},
158 "source": [
159 "### 类型转换 \n",
160 "浮点数转整型,只保留整数部分:"
161 ]
162 },
163 {
164 "cell_type": "code",
165 "execution_count": null,
166 "metadata": {},
167 "outputs": [],
168 "source": [
169 "print (int(-3.32))"
170 ]
171 },
172 {
173 "cell_type": "markdown",
174 "metadata": {},
175 "source": [
176 "整型转浮点型:"
177 ]
178 },
179 {
180 "cell_type": "code",
181 "execution_count": null,
182 "metadata": {},
183 "outputs": [],
184 "source": [
185 "print (float(1))"
186 ]
187 },
188 {
189 "cell_type": "markdown",
190 "metadata": {},
191 "source": [
192 "## 1.3 字符串"
193 ]
194 },
195 {
196 "cell_type": "markdown",
197 "metadata": {},
198 "source": [
199 "使用一对单引号`' '`或者双引号`\" \"`生成字符串。"
200 ]
201 },
202 {
203 "cell_type": "code",
204 "execution_count": null,
205 "metadata": {},
206 "outputs": [],
207 "source": [
208 "s = \"hello, world\"\n",
209 "print (s)"
210 ]
211 },
212 {
213 "cell_type": "code",
214 "execution_count": null,
215 "metadata": {},
216 "outputs": [],
217 "source": [
218 "s = 'hello, world'\n",
219 "print (s)"
220 ]
221 },
222 {
223 "cell_type": "markdown",
224 "metadata": {},
225 "source": [
226 "### 常见的操作\n",
227 "\n",
228 "**加法**:"
229 ]
230 },
231 {
232 "cell_type": "code",
233 "execution_count": null,
234 "metadata": {},
235 "outputs": [],
236 "source": [
237 "a = 'hello'\n",
238 "b = 'world'\n",
239 "a + b"
240 ]
241 },
242 {
243 "cell_type": "markdown",
244 "metadata": {},
245 "source": [
246 "**乘法**:"
247 ]
248 },
249 {
250 "cell_type": "code",
251 "execution_count": null,
252 "metadata": {},
253 "outputs": [],
254 "source": [
255 "c = a * 3\n",
256 "c"
257 ]
258 },
259 {
260 "cell_type": "markdown",
261 "metadata": {},
262 "source": [
263 "**分割**: \n",
264 "`s.split()` 将字符串 s 按照空格(包括多个空格,制表符`\\t`,换行符`\\n`等)分割,并返回所有分割得到的字符串。"
265 ]
266 },
267 {
268 "cell_type": "code",
269 "execution_count": null,
270 "metadata": {},
271 "outputs": [],
272 "source": [
273 "line = \"1 2 3 4 5\"\n",
274 "numbers = line.split()\n",
275 "print (numbers)"
276 ]
277 },
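{
"cell_type": "markdown",
"metadata": {},
"source": [
"`split` 也可以接收一个分隔符参数,按指定的字符串进行分割,例如按逗号分割:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"line = '1,2,3,4,5'\n",
"print (line.split(','))"
]
},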
278 {
279 "cell_type": "markdown",
280 "metadata": {},
281 "source": [
282 "**连接**: \n",
283 "与分割相反,`s.join(sequence)` 的作用是以 `s` 为连接符将序列 `sequence` 中的元素连接起来,并返回连接后得到的新字符串。\n"
284 ]
285 },
286 {
287 "cell_type": "code",
288 "execution_count": null,
289 "metadata": {},
290 "outputs": [],
291 "source": [
292 "s = ' '\n",
293 "s.join(numbers)"
294 ]
295 },
296 {
297 "cell_type": "markdown",
298 "metadata": {},
299 "source": [
300 "**替换**: \n",
301 "`s.replace(part1, part2)` 将字符串 `s` 中指定的部分 `part1` 替换成想要的部分 `part2`,并返回新的字符串。"
302 ]
303 },
304 {
305 "cell_type": "code",
306 "execution_count": null,
307 "metadata": {},
308 "outputs": [],
309 "source": [
310 "s = \"hello world\"\n",
311 "s.replace('world', 'python')"
312 ]
313 },
314 {
315 "cell_type": "markdown",
316 "metadata": {},
317 "source": [
318 "**大小写转换**: \n",
319 "\n",
320 "`s.upper()` 方法返回一个将 `s` 中的字母全部大写的新字符串。\n",
321 "\n",
322 "`s.lower()` 方法返回一个将 `s` 中的字母全部小写的新字符串。"
323 ]
324 },
325 {
326 "cell_type": "code",
327 "execution_count": null,
328 "metadata": {},
329 "outputs": [],
330 "source": [
331 "\"hello world\".upper()"
332 ]
333 },
334 {
335 "cell_type": "code",
336 "execution_count": null,
337 "metadata": {},
338 "outputs": [],
339 "source": [
340 "s = \"HELLO WORLD\"\n",
341 "print (s.lower())\n",
342 "\n",
343 "# 不会改变原来s的值\n",
344 "print (s)"
345 ]
346 },
347 {
348 "cell_type": "markdown",
349 "metadata": {},
350 "source": [
351 "**字符串的长度**:"
352 ]
353 },
354 {
355 "cell_type": "code",
356 "execution_count": null,
357 "metadata": {},
358 "outputs": [],
359 "source": [
360 "len(s)"
361 ]
362 },
363 {
364 "cell_type": "markdown",
365 "metadata": {},
366 "source": [
367 "## 1.4 索引和分片\n",
368 "\n",
369 "### 索引\n",
370 "\n",
371     "对于一个有序序列,可以通过索引的方法来访问对应位置的值。字符串便是一个有序序列,**Python** 使用 **下标** 来对有序序列进行索引。索引是从 `0` 开始的,所以索引 `0` 对应于序列的第 `1` 个元素。"
372 ]
373 },
374 {
375 "cell_type": "code",
376 "execution_count": null,
377 "metadata": {},
378 "outputs": [],
379 "source": [
380 "s = \"hello\"\n",
381 "s[0]"
382 ]
383 },
384 {
385 "cell_type": "markdown",
386 "metadata": {},
387 "source": [
388 "除了正向索引,**Python** 还引入了负索引值的用法,即从后向前开始计数,例如,索引 `-1` 表示倒数第 `1` 个元素:"
389 ]
390 },
391 {
392 "cell_type": "code",
393 "execution_count": null,
394 "metadata": {},
395 "outputs": [],
396 "source": [
397 "s[-1]"
398 ]
399 },
400 {
401 "cell_type": "markdown",
402 "metadata": {},
403 "source": [
404 "单个索引大于等于字符串的长度时,会报错:"
405 ]
406 },
407 {
408 "cell_type": "code",
409 "execution_count": null,
410 "metadata": {},
411 "outputs": [],
412 "source": [
413 "s[6]"
414 ]
415 },
416 {
417 "cell_type": "markdown",
418 "metadata": {},
419 "source": [
420 "### 分片 \n",
421 "\n",
422 "分片用来从序列中提取出想要的子序列,其用法为: \n",
423 "\n",
424 " var[start_index: stop_index: step] \n",
425 "\n",
426     "其范围包括 `start_index` ,但不包括 `stop_index` ,即 `[start_index, stop_index)` ;`step` 表示取值间隔大小,省略时默认为 `1`。"
427 ]
428 },
429 {
430 "cell_type": "code",
431 "execution_count": null,
432 "metadata": {},
433 "outputs": [],
434 "source": [
435 "s = \"hello\"\n",
436 "s[::2]"
437 ]
438 },
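{
"cell_type": "markdown",
"metadata": {},
"source": [
"再看几个同时给出 `start_index` 和 `stop_index` 的分片例子:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s = 'hello world'\n",
"print (s[0:5])\n",
"print (s[6:])\n",
"print (s[-5:])"
]
},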
439 {
440 "cell_type": "markdown",
441 "metadata": {},
442 "source": [
443 "## 1.5 列表\n",
444 "\n",
445 "列表是一个有序的序列。\n",
446 "\n",
447 "列表用一对 `[ ]` 生成,中间的元素用 `,` 隔开,其中的元素不需要是同一类型,同时列表的长度也不固定。"
448 ]
449 },
450 {
451 "cell_type": "code",
452 "execution_count": null,
453 "metadata": {},
454 "outputs": [],
455 "source": [
456 "l = [1, 2.0, 'hello']\n",
457 "print (l)"
458 ]
459 },
460 {
461 "cell_type": "markdown",
462 "metadata": {},
463 "source": [
464 "空列表可以用 `[]` 或者 `list()` 生成:"
465 ]
466 },
467 {
468 "cell_type": "code",
469 "execution_count": null,
470 "metadata": {},
471 "outputs": [],
472 "source": [
473 "empty_list = []\n",
474 "empty_list"
475 ]
476 },
477 {
478 "cell_type": "code",
479 "execution_count": null,
480 "metadata": {},
481 "outputs": [],
482 "source": [
483 "empty_list = list()\n",
484 "empty_list\n"
485 ]
486 },
487 {
488 "cell_type": "markdown",
489 "metadata": {},
490 "source": [
491 "### 列表的常见操作\n",
492 "\n",
493     "**长度**:用 `len` 查看列表长度:"
494 ]
495 },
496 {
497 "cell_type": "code",
498 "execution_count": null,
499 "metadata": {},
500 "outputs": [],
501 "source": [
502 "l = [1, 2.0, 'hello']\n",
503 "len(l)"
504 ]
505 },
506 {
507 "cell_type": "markdown",
508 "metadata": {},
509 "source": [
510 "**加法**: 相当于将两个列表按顺序连接"
511 ]
512 },
513 {
514 "cell_type": "code",
515 "execution_count": null,
516 "metadata": {},
517 "outputs": [],
518 "source": [
519 "a = [1, 2, 3]\n",
520 "b = [3.2, 'hello']\n",
521 "a + b"
522 ]
523 },
524 {
525 "cell_type": "markdown",
526 "metadata": {},
527 "source": [
528 "**乘法**:列表与整数相乘,相当于将列表重复相加"
529 ]
530 },
531 {
532 "cell_type": "code",
533 "execution_count": null,
534 "metadata": {},
535 "outputs": [],
536 "source": [
537 "a * 3"
538 ]
539 },
540 {
541 "cell_type": "markdown",
542 "metadata": {},
543 "source": [
544 "### 索引和分片\n",
545 "列表和字符串一样可以通过索引和分片来查看它的元素。"
546 ]
547 },
548 {
549 "cell_type": "markdown",
550 "metadata": {},
551 "source": [
552 "**索引**:"
553 ]
554 },
555 {
556 "cell_type": "code",
557 "execution_count": null,
558 "metadata": {},
559 "outputs": [],
560 "source": [
561 "a = [10, 11, 12, 13, 14]\n",
562 "a[0]"
563 ]
564 },
565 {
566 "cell_type": "markdown",
567 "metadata": {},
568 "source": [
569 "**反向索引**:"
570 ]
571 },
572 {
573 "cell_type": "code",
574 "execution_count": null,
575 "metadata": {},
576 "outputs": [],
577 "source": [
578 "a[-1]"
579 ]
580 },
581 {
582 "cell_type": "markdown",
583 "metadata": {},
584 "source": [
585 "**分片**:"
586 ]
587 },
588 {
589 "cell_type": "code",
590 "execution_count": null,
591 "metadata": {},
592 "outputs": [],
593 "source": [
594 "a[2:-1]"
595 ]
596 },
597 {
598 "cell_type": "markdown",
599 "metadata": {},
600 "source": [
601 "### **添加元素**\n",
602 "\n",
603 "**append**:向列表添加单个元素 \n",
604 "`l.append(ob)` 将元素 `ob` 添加到列表 `l` 的最后。"
605 ]
606 },
607 {
608 "cell_type": "code",
609 "execution_count": null,
610 "metadata": {},
611 "outputs": [],
612 "source": [
613 "a = [10, 11, 12]\n",
614 "a.append(11)\n",
615 "print (a)"
616 ]
617 },
618 {
619 "cell_type": "markdown",
620 "metadata": {},
621 "source": [
622 "`append` 每次只添加一个元素,并不会因为这个元素是序列而将其展开:"
623 ]
624 },
625 {
626 "cell_type": "code",
627 "execution_count": null,
628 "metadata": {},
629 "outputs": [],
630 "source": [
631 "a = [10, 11, 12]\n",
632 "a.append(['a','b'])\n",
633 "print (a)"
634 ]
635 },
636 {
637 "cell_type": "markdown",
638 "metadata": {},
639 "source": [
640 "**extend**: 向列表添加序列元素 \n",
641 "`l.extend(lst)` 将序列 `lst` 的元素依次添加到列表 `l` 的最后,作用相当于 `l += lst`。"
642 ]
643 },
644 {
645 "cell_type": "code",
646 "execution_count": null,
647 "metadata": {},
648 "outputs": [],
649 "source": [
650 "a = [10, 11, 12]\n",
651 "a.extend(['a','b'])\n",
652 "print (a)"
653 ]
654 },
655 {
656 "cell_type": "markdown",
657 "metadata": {},
658 "source": [
659 "**insert**: 插入元素 \n",
660 "`l.insert(idx, ob)` 在索引 `idx` 处插入 `ob` ,之后的元素依次后移。"
661 ]
662 },
663 {
664 "cell_type": "code",
665 "execution_count": null,
666 "metadata": {},
667 "outputs": [],
668 "source": [
669 "a = [10, 11, 12, 13, 11]\n",
670 "# 在索引 3 插入 'a'\n",
671 "a.insert(3, 'a')\n",
672 "print (a)"
673 ]
674 },
675 {
676 "cell_type": "markdown",
677 "metadata": {},
678 "source": [
679 "### **删除元素** "
680 ]
681 },
682 {
683 "cell_type": "markdown",
684 "metadata": {},
685 "source": [
686 "**del**:根据下标进行删除"
687 ]
688 },
689 {
690 "cell_type": "code",
691 "execution_count": null,
692 "metadata": {},
693 "outputs": [],
694 "source": [
695 "# 根据下标进行删除\n",
696 "a = [1002, 'a', 'b', 'c']\n",
697 "del a[0]\n",
698 "print (a)"
699 ]
700 },
701 {
702 "cell_type": "markdown",
703 "metadata": {},
704 "source": [
705 "**pop**:弹出元素 \n",
706 "`l.pop(idx)` 会将索引 `idx` 处的元素删除,并返回这个元素。未指定 `idx` 时,默认为列表最后一个元素。"
707 ]
708 },
709 {
710 "cell_type": "code",
711 "execution_count": null,
712 "metadata": {},
713 "outputs": [],
714 "source": [
715 "a = [1002, 'a', 'b', 'c']\n",
716 "a.pop()\n",
717 "print (a)"
718 ]
719 },
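{
"cell_type": "markdown",
"metadata": {},
"source": [
"`pop` 也可以指定索引,并接收它返回的被删除元素:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = [1002, 'a', 'b', 'c']\n",
"b = a.pop(0)\n",
"print (b)\n",
"print (a)"
]
},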
720 {
721 "cell_type": "markdown",
722 "metadata": {},
723 "source": [
724 "**remove**:根据元素的值进行删除 \n",
725 "`l.remove(ob)` 会将列表中第一个出现的 `ob` 删除,如果 `ob` 不在 `l` 中会报错。"
726 ]
727 },
728 {
729 "cell_type": "code",
730 "execution_count": null,
731 "metadata": {},
732 "outputs": [],
733 "source": [
734 "a = [1002, 'a', 'b', 'c', 'b']\n",
735 "a.remove(\"b\")\n",
736 "print (a)"
737 ]
738 },
739 {
740 "cell_type": "markdown",
741 "metadata": {},
742 "source": [
743 "### **测试从属关系**\n",
744     "用 `in` 来判断某个元素是否在某个序列(不仅仅是列表)中;\n",
745     "\n",
746     "用 `not in` 来判断某个元素是否不在某个序列中。"
747 ]
748 },
749 {
750 "cell_type": "code",
751 "execution_count": null,
752 "metadata": {},
753 "outputs": [],
754 "source": [
755 "a = [10, 11, 12, 13, 14]\n",
756 "print (10 in a)\n",
757 "print (10 not in a)"
758 ]
759 },
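{
"cell_type": "markdown",
"metadata": {},
"source": [
"`in` 和 `not in` 同样适用于字符串等其他序列:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"s = 'hello world'\n",
"print ('he' in s)\n",
"print ('python' not in s)"
]
},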
760 {
761 "cell_type": "markdown",
762 "metadata": {},
763 "source": [
764 "用 `index` 查找某个元素在列表中的位置:\n",
765 "\n",
766 "`l.index(ob)` 返回列表中元素 `ob` 第一次出现的索引位置,如果 `ob` 不在 `l` 中会报错。"
767 ]
768 },
769 {
770 "cell_type": "code",
771 "execution_count": null,
772 "metadata": {},
773 "outputs": [],
774 "source": [
775 "a = [11, 12, 13, 12, 11]\n",
776 "a.index(12)"
777 ]
778 },
779 {
780 "cell_type": "markdown",
781 "metadata": {},
782 "source": [
783 "`count` 查找列表中某个元素出现的次数:"
784 ]
785 },
786 {
787 "cell_type": "code",
788 "execution_count": null,
789 "metadata": {},
790 "outputs": [],
791 "source": [
792 "a = [11, 12, 13, 12, 11]\n",
793 "a.count(11)"
794 ]
795 },
796 {
797 "cell_type": "markdown",
798 "metadata": {},
799 "source": [
800 "### **修改元素** \n",
801 "\n",
802     "修改元素的时候,要通过下标来确定要修改的是哪个元素,然后才能进行修改:"
803 ]
804 },
805 {
806 "cell_type": "code",
807 "execution_count": null,
808 "metadata": {},
809 "outputs": [],
810 "source": [
811 "a = [10, 11, 12, 13, 11]\n",
812 "a[0] = \"a\"\n",
813 "a"
814 ]
815 },
816 {
817 "cell_type": "markdown",
818 "metadata": {},
819 "source": [
820 "### **排序**"
821 ]
822 },
823 {
824 "cell_type": "markdown",
825 "metadata": {},
826 "source": [
827     "`sort` 方法将列表按特定顺序重新排列,默认为由小到大;指定参数 `reverse=True` 可改为倒序,由大到小。"
828 ]
829 },
830 {
831 "cell_type": "code",
832 "execution_count": null,
833 "metadata": {},
834 "outputs": [],
835 "source": [
836 "# 从小到大排序\n",
837 "a = [10, 1, 11, 13, 11, 2]\n",
838 "a.sort()\n",
839 "print (a)"
840 ]
841 },
842 {
843 "cell_type": "code",
844 "execution_count": null,
845 "metadata": {},
846 "outputs": [],
847 "source": [
848 "# 从大到小排序\n",
849 "a = [10, 1, 11, 13, 11, 2]\n",
850 "a.sort(reverse=True)\n",
851 "print (a)"
852 ]
853 },
854 {
855 "cell_type": "markdown",
856 "metadata": {},
857 "source": [
858 "如果不想改变原来列表中的值,可以使用 `sorted` 函数:"
859 ]
860 },
861 {
862 "cell_type": "code",
863 "execution_count": null,
864 "metadata": {},
865 "outputs": [],
866 "source": [
867 "a = [10, 1, 11, 13, 11, 2]\n",
868 "b = sorted(a)\n",
869 "print (\"a:\",a)\n",
870 "print (\"b:\",b)"
871 ]
872 },
873 {
874 "cell_type": "markdown",
875 "metadata": {},
876 "source": [
877 "### 列表反向"
878 ]
879 },
880 {
881 "cell_type": "markdown",
882 "metadata": {},
883 "source": [
884 "`l.reverse()` 会将列表中的元素从后向前排列。"
885 ]
886 },
887 {
888 "cell_type": "code",
889 "execution_count": null,
890 "metadata": {},
891 "outputs": [],
892 "source": [
893 "a = [10, 1, 11, 13, 11, 2]\n",
894 "a.reverse()\n",
895 "print (a)"
896 ]
897 },
898 {
899 "cell_type": "markdown",
900 "metadata": {},
901 "source": [
902 "如果不想改变原来列表中的值,可以使用分片:"
903 ]
904 },
905 {
906 "cell_type": "code",
907 "execution_count": null,
908 "metadata": {},
909 "outputs": [],
910 "source": [
911 "a = [10, 1, 11, 13, 11, 2]\n",
912 "b = a[::-1]\n",
913 "print (\"a:\",a)\n",
914 "print (\"b:\",b)"
915 ]
916 },
917 {
918 "cell_type": "markdown",
919 "metadata": {},
920 "source": [
921 "## 1.6 字典\n",
922 "\n",
923 "字典 `dictionary` ,在一些编程语言中也称为 `hash` , `map` ,是一种由键值对组成的数据结构。\n",
924 "\n",
925 "顾名思义,我们把键想象成字典中的单词,值想象成词对应的定义,那么——\n",
926 "\n",
927     "在字典中,一个键对应一个值(值可以是列表等任意对象),并且这个值只能通过对应的键来进行查询。\n",
928 "\n",
929 "**Python** 使用`key: value`这样的结构来表示字典中的元素结构。"
930 ]
931 },
932 {
933 "cell_type": "markdown",
934 "metadata": {},
935 "source": [
936 "### 空字典\n",
937 "\n",
938 "**Python** 使用 `{}` 或者 `dict()` 来创建一个空的字典:"
939 ]
940 },
941 {
942 "cell_type": "code",
943 "execution_count": null,
944 "metadata": {},
945 "outputs": [],
946 "source": [
947 "a = {}\n",
948 "type(a)"
949 ]
950 },
951 {
952 "cell_type": "code",
953 "execution_count": null,
954 "metadata": {},
955 "outputs": [],
956 "source": [
957 "a = dict()\n",
958 "type(a)"
959 ]
960 },
961 {
962 "cell_type": "markdown",
963 "metadata": {},
964 "source": [
965 "### 插入键值"
966 ]
967 },
968 {
969 "cell_type": "code",
970 "execution_count": null,
971 "metadata": {},
972 "outputs": [],
973 "source": [
974 "a[\"one\"] = \"this is number 1\"\n",
975 "a[\"two\"] = \"this is number 2\"\n",
976 "a[\"three\"] = \"this is number 3\"\n",
977 "a"
978 ]
979 },
980 {
981 "cell_type": "markdown",
982 "metadata": {},
983 "source": [
984 "**注意:** \n",
985     "1.字典的键必须是可哈希的不可变对象,如数字、字符串、元组,不能是列表、字典、集合等可变对象。 \n",
986     "2.字典没有顺序:当我们 `print` 一个字典时,**Python** 并不一定按照插入键值的先后顺序进行显示,因为字典中的键本身不一定是有序的(注:从 Python 3.7 起,字典会保持插入顺序)。"
987 ]
988 },
989 {
990 "cell_type": "code",
991 "execution_count": null,
992 "metadata": {},
993 "outputs": [],
994 "source": [
995 "# 查看键值\n",
996 "a['one']"
997 ]
998 },
999 {
1000 "cell_type": "markdown",
1001 "metadata": {},
1002 "source": [
1003 "### 更新键值"
1004 ]
1005 },
1006 {
1007 "cell_type": "code",
1008 "execution_count": null,
1009 "metadata": {},
1010 "outputs": [],
1011 "source": [
1012 "a[\"one\"] = \"this is number 1, too\"\n",
1013 "a"
1014 ]
1015 },
1016 {
1017 "cell_type": "markdown",
1018 "metadata": {},
1019 "source": [
1020 "### `get`方法"
1021 ]
1022 },
1023 {
1024 "cell_type": "markdown",
1025 "metadata": {},
1026 "source": [
1027 "用键可以找到该键对应的值,但是当字典中没有这个键的时候,**Python** 会报错,这时候可以使用字典的 `get` 方法来处理这种情况,其用法如下:\n",
1028 "\n",
1029 "`d.get(key, default = None)`\n",
1030 "\n",
1031 "返回字典中键 `key` 对应的值,如果没有这个键,返回 `default` 指定的值(默认是 `None` )。"
1032 ]
1033 },
1034 {
1035 "cell_type": "code",
1036 "execution_count": null,
1037 "metadata": {},
1038 "outputs": [],
1039 "source": [
1040 "a = {}\n",
1041 "a[\"one\"] = \"this is number 1\"\n",
1042 "a[\"two\"] = \"this is number 2\"\n",
1043 "\n",
1044 "a.get(\"three\", \"undefined\")"
1045 ]
1046 },
1047 {
1048 "cell_type": "markdown",
1049 "metadata": {},
1050 "source": [
1051 "### `keys` 方法,`values` 方法和`items` 方法"
1052 ]
1053 },
1054 {
1055 "cell_type": "markdown",
1056 "metadata": {},
1057 "source": [
1058     "`d.keys()` :返回由所有键组成的序列;\n",
1059     "\n",
1060     "`d.values()` :返回由所有值组成的序列;\n",
1061     "\n",
1062     "`d.items()` :返回由所有键值对元组组成的序列。注意在 Python 3 中,这三个方法返回的是视图(view)对象,而不是列表。"
1063 ]
1064 },
1065 {
1066 "cell_type": "code",
1067 "execution_count": null,
1068 "metadata": {},
1069 "outputs": [],
1070 "source": [
1071 "a = {}\n",
1072 "a[\"one\"] = \"this is number 1\"\n",
1073 "a[\"two\"] = \"this is number 2\"\n",
1074 "\n",
1075 "a.keys()"
1076 ]
1077 },
1078 {
1079 "cell_type": "code",
1080 "execution_count": null,
1081 "metadata": {},
1082 "outputs": [],
1083 "source": [
1084 "a.values()"
1085 ]
1086 },
1087 {
1088 "cell_type": "code",
1089 "execution_count": null,
1090 "metadata": {},
1091 "outputs": [],
1092 "source": [
1093 "a.items()"
1094 ]
1095 },
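{
"cell_type": "markdown",
"metadata": {},
"source": [
"由于 Python 3 中这些方法返回的是视图对象,需要列表时可以用 `list()` 进行转换:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"list(a.keys())"
]
},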
1096 {
1097 "cell_type": "markdown",
1098 "metadata": {},
1099 "source": [
1100 "## 1.7 元组\n",
1101 "\n",
1102 "元组`Tuple`也是个有序序列,但是元组是不可变的,用`()`生成。"
1103 ]
1104 },
1105 {
1106 "cell_type": "code",
1107 "execution_count": null,
1108 "metadata": {},
1109 "outputs": [],
1110 "source": [
1111 "# 生成元组\n",
1112 "a = ()\n",
1113 "type(a)"
1114 ]
1115 },
1116 {
1117 "cell_type": "markdown",
1118 "metadata": {},
1119 "source": [
1120 "生成只含有单个元素的元组时,采用下列方式定义:"
1121 ]
1122 },
1123 {
1124 "cell_type": "code",
1125 "execution_count": null,
1126 "metadata": {},
1127 "outputs": [],
1128 "source": [
1129 "# 生成元组\n",
1130 "a = (1,)\n",
1131 "type(a)"
1132 ]
1133 },
1134 {
1135 "cell_type": "markdown",
1136 "metadata": {},
1137 "source": [
1138 "元组是**不可变**的,修改元组元素时会报错:"
1139 ]
1140 },
1141 {
1142 "cell_type": "code",
1143 "execution_count": null,
1144 "metadata": {},
1145 "outputs": [],
1146 "source": [
1147 "a = (10, 11, 12, 13, 14)\n",
1148     "a[0] = 1"
1150 ]
1151 },
1152 {
1153 "cell_type": "markdown",
1154 "metadata": {},
1155 "source": [
1156 "## 1.8 集合"
1157 ]
1158 },
1159 {
1160 "cell_type": "markdown",
1161 "metadata": {},
1162 "source": [
1163     "之前看到的列表和字符串都是一种有序序列,而集合 `set` 是一种无序的容器。\n",
1164     "\n",
1165     "因为集合是无序的,所以当集合中存在两个同样的元素的时候,**Python** 只会保存其中的一个(唯一性);同时,集合中放入的元素必须是可哈希的不可变对象,例如数字、字符串、元组(确定性)。"
1166 ]
1167 },
1168 {
1169 "cell_type": "markdown",
1170 "metadata": {},
1171 "source": [
1172     "可以用 `set()` 函数来显式地生成空集合:"
1173 ]
1174 },
1175 {
1176 "cell_type": "code",
1177 "execution_count": null,
1178 "metadata": {},
1179 "outputs": [],
1180 "source": [
1181 "a = set()\n",
1182 "type(a)"
1183 ]
1184 },
1185 {
1186 "cell_type": "markdown",
1187 "metadata": {},
1188 "source": [
1189 "也可以使用一个列表来初始化一个集合:"
1190 ]
1191 },
1192 {
1193 "cell_type": "code",
1194 "execution_count": null,
1195 "metadata": {},
1196 "outputs": [],
1197 "source": [
1198 "a = set([1, 2, 3, 1])\n",
1199 "a"
1200 ]
1201 },
1202 {
1203 "cell_type": "markdown",
1204 "metadata": {},
1205 "source": [
1206 "集合会自动**去除重复元素** `1`。\n",
1207 "\n",
1208 "可以看到,集合中的元素是用大括号`{}`包含起来的,这意味着可以用`{}`的形式来创建集合:"
1209 ]
1210 },
1211 {
1212 "cell_type": "code",
1213 "execution_count": null,
1214 "metadata": {},
1215 "outputs": [],
1216 "source": [
1217 "a = {1, 2, 3, 1}\n",
1218 "a"
1219 ]
1220 },
1221 {
1222 "cell_type": "markdown",
1223 "metadata": {},
1224 "source": [
1225 "但是创建空集合的时候只能用`set`来创建,因为在 **Python** 中`{}`创建的是一个空的字典:"
1226 ]
1227 },
1228 {
1229 "cell_type": "code",
1230 "execution_count": null,
1231 "metadata": {},
1232 "outputs": [],
1233 "source": [
1234 "s = {}\n",
1235 "type(s)"
1236 ]
1237 },
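{
"cell_type": "markdown",
"metadata": {},
"source": [
"集合支持并集、交集、差集等数学运算,分别对应 `|`、`&`、`-` 运算符:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"a = {1, 2, 3}\n",
"b = {2, 3, 4}\n",
"print (a | b)\n",
"print (a & b)\n",
"print (a - b)"
]
},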
1238 {
1239 "cell_type": "markdown",
1240 "metadata": {},
1241 "source": [
1242 "## 1.9 判断语句"
1243 ]
1244 },
1245 {
1246 "cell_type": "markdown",
1247 "metadata": {},
1248 "source": [
1249 "### 基本用法\n",
1250 "\n",
1251 "判断,基于一定的条件,决定是否要执行特定的一段代码,例如判断一个数是不是正数:"
1252 ]
1253 },
1254 {
1255 "cell_type": "code",
1256 "execution_count": null,
1257 "metadata": {},
1258 "outputs": [],
1259 "source": [
1260 "x = 0.5\n",
1261 "if x > 0:\n",
1262 " print (\"Hey!\")\n",
1263 " print (\"x is positive\")"
1264 ]
1265 },
1266 {
1267 "cell_type": "markdown",
1268 "metadata": {},
1269 "source": [
1270     "在这里,如果 `x > 0` 为 `False`(即 `x <= 0`),那么程序将不会执行这两条 `print` 语句。\n",
1271 "\n",
1272 "虽然都是用 `if` 关键词定义判断,但与 **C,Java** 等语言不同,**Python**不使用 `{}` 将 `if` 语句控制的区域包含起来。**Python** 使用的是缩进方法。同时,也不需要用 `()` 将判断条件括起来。\n",
1273 "\n",
1274 "上面例子中的这两条语句:\n",
1275 "\n",
1276 "```python\n",
1277 "print(\"Hey!\") \n",
1278 "print(\"x is positive\")\n",
1279 "\n",
1280 "```\n",
1281 "\n",
1282 "就叫做一个代码块,同一个代码块使用同样的缩进值,它们组成了这条 `if` 语句的主体。\n",
1283 "\n",
1284 "不同的缩进值表示不同的代码块,例如:\n",
1285 "\n",
1286 "`x > 0` 时:"
1287 ]
1288 },
1289 {
1290 "cell_type": "code",
1291 "execution_count": null,
1292 "metadata": {},
1293 "outputs": [],
1294 "source": [
1295 "x = 0.5\n",
1296 "if x > 0:\n",
1297 " print (\"Hey!\")\n",
1298 " print (\"x is positive\")\n",
1299 " print (\"This is still part of the block\")\n",
1300 "print (\"This isn't part of the block, and will always print.\")"
1301 ]
1302 },
1303 {
1304 "cell_type": "markdown",
1305 "metadata": {},
1306 "source": [
1307 "`x < 0` 时:"
1308 ]
1309 },
1310 {
1311 "cell_type": "code",
1312 "execution_count": null,
1313 "metadata": {},
1314 "outputs": [],
1315 "source": [
1316 "x = -0.5\n",
1317 "if x > 0:\n",
1318 " print (\"Hey!\")\n",
1319 " print (\"x is positive\")\n",
1320 " print (\"This is still part of the block\")\n",
1321 "print (\"This isn't part of the block, and will always print.\")\n"
1322 ]
1323 },
1324 {
1325 "cell_type": "markdown",
1326 "metadata": {},
1327 "source": [
1328 "在这两个例子中,最后一句并不是 `if` 语句中的内容,所以不管条件满不满足,它都会被执行。\n",
1329 "\n",
1330 "一个完整的 `if` 结构通常如下所示(注意:条件后的 `:` 是必须要的,缩进值需要一样):\n",
1331 " \n",
1332 " if <condition 1>:\n",
1333 " <statement 1>\n",
1334 " <statement 2>\n",
1335 " elif <condition 2>: \n",
1336 " <statements>\n",
1337 " else:\n",
1338 " <statements>\n",
1339 "\n",
1340 "当条件 1 被满足时,执行 `if` 下面的语句,当条件 1 不满足的时候,转到 `elif` ,看它的条件 2 满不满足,满足执行 `elif` 下面的语句,不满足则执行 `else` 下面的语句。\n",
1341 "\n",
1342 "对于上面的例子进行扩展:"
1343 ]
1344 },
1345 {
1346 "cell_type": "code",
1347 "execution_count": null,
1348 "metadata": {},
1349 "outputs": [],
1350 "source": [
1351 "x = 0\n",
1352 "if x > 0:\n",
1353 " print (\"x is positive\")\n",
1354 "elif x == 0:\n",
1355 " print (\"x is zero\")\n",
1356 "else:\n",
1357 " print (\"x is negative\")\n"
1358 ]
1359 },
1360 {
1361 "cell_type": "markdown",
1362 "metadata": {},
1363 "source": [
1364 "`elif` 的个数没有限制,可以是1个或者多个,也可以没有。\n",
1365 "\n",
1366 "`else` 最多只有1个,也可以没有。\n",
1367 "\n",
1368 "可以使用 `and` , `or` , `not` 等关键词结合多个判断条件:"
1369 ]
1370 },
1371 {
1372 "cell_type": "code",
1373 "execution_count": null,
1374 "metadata": {},
1375 "outputs": [],
1376 "source": [
1377 "x = 10\n",
1378 "y = -5\n",
1379 "x > 0 and y < 0"
1380 ]
1381 },
1382 {
1383 "cell_type": "code",
1384 "execution_count": null,
1385 "metadata": {},
1386 "outputs": [],
1387 "source": [
1388 "not x > 0"
1389 ]
1390 },
1391 {
1392 "cell_type": "code",
1393 "execution_count": null,
1394 "metadata": {},
1395 "outputs": [],
1396 "source": [
1397 "x < 0 or y < 0"
1398 ]
1399 },
1400 {
1401 "cell_type": "markdown",
1402 "metadata": {},
1403 "source": [
1404 "这里使用这个简单的例子,假如想判断一个年份是不是闰年,按照闰年的定义,这里只需要判断这个年份是不是能被 `4` 整除,但是不能被 `100` 整除,或者正好被 `400` 整除:"
1405 ]
1406 },
1407 {
1408 "cell_type": "code",
1409 "execution_count": null,
1410 "metadata": {},
1411 "outputs": [],
1412 "source": [
1413 "year = 1900\n",
1414 "if year % 400 == 0:\n",
1415 " print (\"This is a leap year!\")\n",
1416 "# 两个条件都满足才执行\n",
1417 "elif year % 4 == 0 and year % 100 != 0:\n",
1418 " print (\"This is a leap year!\")\n",
1419 "else:\n",
1420 " print (\"This is not a leap year.\")\n"
1421 ]
1422 },
1423 {
1424 "cell_type": "markdown",
1425 "metadata": {},
1426 "source": [
1427 "### 判断条件为 `False` 情况总结:"
1428 ]
1429 },
1430 {
1431 "cell_type": "markdown",
1432 "metadata": {},
1433 "source": [
1434 "**Python** 不仅仅可以使用布尔型变量作为条件,它可以直接在 `if` 中使用任何表达式作为条件:\n",
1435 "\n",
1436 "大部分表达式的值都会被当作 `True`,但以下表达式值会被当作 `False`:\n",
1437 "\n",
1438 "- False\n",
1439 "- None\n",
1440 "- 0\n",
1441 "- 空字符串,空列表,空字典,空集合"
1442 ]
1443 },
1444 {
1445 "cell_type": "code",
1446 "execution_count": null,
1447 "metadata": {},
1448 "outputs": [],
1449 "source": [
1450 "mylist = [3, 1, 4, 1, 5, 9]\n",
1451 "if mylist:\n",
1452 " print (\"The first element is:\", mylist[0])\n",
1453 "else:\n",
1454 " print (\"There is no first element.\")"
1455 ]
1456 },
1457 {
1458 "cell_type": "markdown",
1459 "metadata": {},
1460 "source": [
1461 "修改为空列表:"
1462 ]
1463 },
1464 {
1465 "cell_type": "code",
1466 "execution_count": null,
1467 "metadata": {},
1468 "outputs": [],
1469 "source": [
1470 "mylist = []\n",
1471 "if mylist:\n",
1472 " print (\"The first element is:\", mylist[0])\n",
1473 "else:\n",
1474 " print (\"There is no first element.\")\n"
1475 ]
1476 },
1477 {
1478 "cell_type": "markdown",
1479 "metadata": {},
1480 "source": [
1481 "## 1.10 循环\n",
1482 "\n",
1483 "循环的作用在于将一段代码重复执行多次。\n",
1484 "\n",
1485 "### `while` 循环"
1486 ]
1487 },
1488 {
1489 "cell_type": "markdown",
1490 "metadata": {},
1491 "source": [
1492     "    while <condition>:\n",
1493     "        <statements>\n",
1494     "**Python** 会循环执行 `<statements>`,直到 `<condition>` 不满足为止。\n",
1495     "\n",
1496     "例如,计算数字 `0` 到 `99` 的和:"
1497 ]
1498 },
1499 {
1500 "cell_type": "code",
1501 "execution_count": null,
1502 "metadata": {},
1503 "outputs": [],
1504 "source": [
1505 "i = 0\n",
1506 "\n",
1507 "# 求和结果\n",
1508 "total = 0\n",
1509 "\n",
1510 "# 循环条件\n",
1511 "while i < 100:\n",
1512 " \n",
1513 " # 求和累加\n",
1514 " total += i\n",
1515 " \n",
1516 " # 变量递增\n",
1517 " i += 1\n",
1518 " \n",
1519 "# 打印结果\n",
1520 "print (total)"
1521 ]
1522 },
1523 {
1524 "cell_type": "markdown",
1525 "metadata": {},
1526 "source": [
1527 "之前提到,空容器会被当成 `False` ,因此可以用 `while` 循环来读取容器中的所有元素:"
1528 ]
1529 },
1530 {
1531 "cell_type": "code",
1532 "execution_count": null,
1533 "metadata": {},
1534 "outputs": [],
1535 "source": [
1536 "plays = ['Hamlet', 'Macbeth', 'King Lear']\n",
1537 "while plays:\n",
1538 " play = plays.pop()\n",
1539 " print ('Perform', play)"
1540 ]
1541 },
1542 {
1543 "cell_type": "markdown",
1544 "metadata": {},
1545 "source": [
1546 "循环每次从 `plays` 中弹出一个元素,一直到 `plays` 为空为止。"
1547 ]
1548 },
1549 {
1550 "cell_type": "markdown",
1551 "metadata": {},
1552 "source": [
1553 "### `for` 循环 \n",
1554 "\n",
1555 " for <variable> in <sequence>:\n",
1556 " <indented block of code>\n",
1557 "\n",
1558     "`for` 循环会依次遍历 `<sequence>` 中的所有元素。\n",
1559 "\n",
1560 "上一个例子可以改写成如下形式:"
1561 ]
1562 },
1563 {
1564 "cell_type": "code",
1565 "execution_count": null,
1566 "metadata": {},
1567 "outputs": [],
1568 "source": [
1569 "plays = ['Hamlet', 'Macbeth', 'King Lear']\n",
1570 "for play in plays:\n",
1571 " print ('Perform', play)"
1572 ]
1573 },
1574 {
1575 "cell_type": "markdown",
1576 "metadata": {},
1577 "source": [
1578 "使用 `for` 循环时,注意尽量不要改变 `plays` 的值,否则可能会产生意想不到的结果。\n",
1579 "\n",
1580 "之前的求和也可以通过 `for` 循环来实现:"
1581 ]
1582 },
1583 {
1584 "cell_type": "code",
1585 "execution_count": null,
1586 "metadata": {},
1587 "outputs": [],
1588 "source": [
1589 "total = 0\n",
1590 "for i in range(100):\n",
1591 " total += i\n",
1592 "print (total)"
1593 ]
1594 },
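{
"cell_type": "markdown",
"metadata": {},
"source": [
"这里的 `range(n)` 会依次生成 `0` 到 `n-1` 的整数;它也可以用 `range(start, stop, step)` 的形式指定起点、终点和步长:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print (list(range(5)))\n",
"print (list(range(2, 10, 3)))"
]
},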
1595 {
1596 "cell_type": "markdown",
1597 "metadata": {},
1598 "source": [
1599 "### `continue` 语句\n",
1600 "\n",
1601     "遇到 `continue` 的时候,程序会跳过本次循环中剩余的语句,直接开始下一次循环。\n",
1602 "\n",
1603 "例如在循环中忽略一些特定的值:"
1604 ]
1605 },
1606 {
1607 "cell_type": "code",
1608 "execution_count": null,
1609 "metadata": {},
1610 "outputs": [],
1611 "source": [
1612 "values = [7, 6, 4, 7, 19, 2, 1]\n",
1613 "for i in values:\n",
1614 " if i % 2 != 0:\n",
1615 " # 忽略奇数\n",
1616 " continue\n",
1617 " print (i/2)"
1618 ]
1619 },
1620 {
1621 "cell_type": "markdown",
1622 "metadata": {},
1623 "source": [
1624 "### `break` 语句 \n",
1625 "\n",
1626 "遇到 `break` 的时候,程序会跳出循环,不管循环条件是不是满足:"
1627 ]
1628 },
1629 {
1630 "cell_type": "code",
1631 "execution_count": null,
1632 "metadata": {},
1633 "outputs": [],
1634 "source": [
1635 "command_list = ['start',\n",
1636 " 'process',\n",
1637 " 'process',\n",
1638 " 'process',\n",
1639 " 'stop',\n",
1640 " 'start',\n",
1641 " 'process',\n",
1642 " 'stop']\n",
1643 "while command_list:\n",
1644 " command = command_list.pop(0)\n",
1645 " if command == 'stop':\n",
1646 " break\n",
1647 " print(command)"
1648 ]
1649 },
1650 {
1651 "cell_type": "markdown",
1652 "metadata": {},
1653 "source": [
1654 "在遇到第一个 `'stop'` 之后,程序跳出循环。"
1655 ]
1656 },
1657 {
1658 "cell_type": "markdown",
1659 "metadata": {},
1660 "source": [
1661 "### `else` 语句\n",
1662 "\n",
1663 "与 `if` 一样, `while` 和 `for` 循环后面也可以跟着 `else` 语句。\n",
1664 "\n",
1665     "- 当循环正常结束(没有被 `break` 中断)时, `else` 被执行;\n",
1666     "- 当循环被 `break` 结束时, `else` 不执行。\n",
1667 "\n",
1668 "不执行 `else` 语句:"
1669 ]
1670 },
1671 {
1672 "cell_type": "code",
1673 "execution_count": null,
1674 "metadata": {},
1675 "outputs": [],
1676 "source": [
1677 "values = [7, 6, 4, 7, 19, 2, 1]\n",
1678 "for x in values:\n",
1679 " if x <= 10:\n",
1680 " print ('Found:', x)\n",
1681 " break\n",
1682 "else:\n",
1683 " print ('All values greater than 10')\n"
1684 ]
1685 },
1686 {
1687 "cell_type": "markdown",
1688 "metadata": {},
1689 "source": [
1690 "A case where `else` is executed:"
1691 ]
1692 },
1693 {
1694 "cell_type": "code",
1695 "execution_count": null,
1696 "metadata": {},
1697 "outputs": [],
1698 "source": [
1699 "values = [11, 12, 13, 100]\n",
1700 "for x in values:\n",
1701 " if x <= 10:\n",
1702 " print ('Found:', x)\n",
1703 " break\n",
1704 "else:\n",
1705 " print ('All values greater than 10')\n"
1706 ]
1707 }
1708 ],
1709 "metadata": {
1710 "kernelspec": {
1711 "display_name": "Python 3",
1712 "language": "python",
1713 "name": "python3"
1714 },
1715 "language_info": {
1716 "codemirror_mode": {
1717 "name": "ipython",
1718 "version": 3
1719 },
1720 "file_extension": ".py",
1721 "mimetype": "text/x-python",
1722 "name": "python",
1723 "nbconvert_exporter": "python",
1724 "pygments_lexer": "ipython3",
1725 "version": "3.5.2"
1726 }
1727 },
1728 "nbformat": 4,
1729 "nbformat_minor": 2
1730 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 2. Python 进阶"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "## 2.1 List Comprehensions"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "A loop can be used to generate a list:"
21 ]
22 },
23 {
24 "cell_type": "code",
25 "execution_count": null,
26 "metadata": {},
27 "outputs": [],
28 "source": [
29 "a = [x for x in range(4)]\n",
30 "a"
31 ]
32 },
33 {
34 "cell_type": "markdown",
35 "metadata": {},
36 "source": [
37 "Using `if` inside the comprehension:"
38 ]
39 },
40 {
41 "cell_type": "code",
42 "execution_count": null,
43 "metadata": {},
44 "outputs": [],
45 "source": [
46 "a = [x for x in range(3,10) if x % 2 == 0]\n",
47 "a"
48 ]
49 },
50 {
51 "cell_type": "markdown",
52 "metadata": {},
53 "source": [
54 "Two `for` loops:"
55 ]
56 },
57 {
58 "cell_type": "code",
59 "execution_count": null,
60 "metadata": {},
61 "outputs": [],
62 "source": [
63 "a = [(x,y) for x in range(1,3) for y in range(1,3)]\n",
64 "a"
65 ]
66 },
67 {
68 "cell_type": "markdown",
69 "metadata": {},
70 "source": [
71 "Three `for` loops:"
72 ]
73 },
74 {
75 "cell_type": "code",
76 "execution_count": null,
77 "metadata": {},
78 "outputs": [],
79 "source": [
80 "a = [(x,y,z) for x in range(1,3) for y in range(1,3) for z in range(1,3)]\n",
81 "a"
82 ]
83 },
84 {
85 "cell_type": "markdown",
86 "metadata": {},
87 "source": [
88 "Comprehensions can also build sets and dictionaries.  \n",
89 "**Dictionary comprehension**:"
90 ]
91 },
92 {
93 "cell_type": "code",
94 "execution_count": null,
95 "metadata": {},
96 "outputs": [],
97 "source": [
98 "values = [10, 21, 4, 7, 12]\n",
99 "square_dict = {x: x**2 for x in values if x <= 10}\n",
100 "print(square_dict)"
101 ]
102 },
103 {
104 "cell_type": "markdown",
105 "metadata": {},
106 "source": [
107 "**Set comprehension**:"
108 ]
109 },
110 {
111 "cell_type": "code",
112 "execution_count": null,
113 "metadata": {},
114 "outputs": [],
115 "source": [
116 "values = [10, 21, 4, 7, 12]\n",
117 "square_set = {x**2 for x in values if x <= 10}\n",
118 "print(square_set)"
119 ]
120 },
121 {
122 "cell_type": "markdown",
123 "metadata": {},
124 "source": [
125 "## 2.2 Functions\n",
126 "\n",
127 "### Defining a function\n",
128 "\n",
129 "A function usually takes input parameters and returns a value.\n",
130 "\n",
131 "It performs one specific task and is relatively independent of the surrounding code."
132 ]
133 },
134 {
135 "cell_type": "code",
136 "execution_count": null,
137 "metadata": {},
138 "outputs": [],
139 "source": [
140 "def add(x, y):\n",
141 " \"\"\"Add two numbers\"\"\"\n",
142 " a = x + y\n",
143 " return a"
144 ]
145 },
146 {
147 "cell_type": "markdown",
148 "metadata": {},
149 "source": [
150 "A function definition has the following features:\n",
151 "- the `def` keyword starts the definition;\n",
152 "- `def` is followed by the function name and a parenthesized parameter list, with parameters separated by `,`; the `def foo():` form is mandatory even when the parameter list is empty;\n",
153 "- indentation delimits the function body;\n",
154 "- the `docstring`, a string wrapped in `\"\"\"`, documents the function's purpose and may be omitted;\n",
155 "- `return` returns a value; if it is omitted, the function returns `None`."
156 ]
157 },
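{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The docstring travels with the function object itself; as a small sketch, it can be read back through the `__doc__` attribute (or interactively via `help`):\n",
  "\n",
  "```python\n",
  "def add(x, y):\n",
  "    \"\"\"Add two numbers\"\"\"\n",
  "    return x + y\n",
  "\n",
  "print(add.__doc__)\n",
  "```"
 ]
},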
158 {
159 "cell_type": "markdown",
160 "metadata": {},
161 "source": [
162 "### Calling a function\n",
163 "\n",
164 "To call a function, pass concrete values for its parameters.\n",
165 "\n",
166 "**Python** does not restrict parameter types, so arguments of different types can be used:"
167 ]
168 },
169 {
170 "cell_type": "code",
171 "execution_count": null,
172 "metadata": {},
173 "outputs": [],
174 "source": [
175 "print (add(2, 3))\n",
176 "print (add('foo', 'bar'))"
177 ]
178 },
179 {
180 "cell_type": "markdown",
181 "metadata": {},
182 "source": [
183 "In this example, if the two arguments cannot be added together, **Python** raises an error:"
184 ]
185 },
186 {
187 "cell_type": "code",
188 "execution_count": null,
189 "metadata": {},
190 "outputs": [],
191 "source": [
192 "print (add(2, \"foo\"))"
193 ]
194 },
195 {
196 "cell_type": "markdown",
197 "metadata": {},
198 "source": [
199 "Passing the wrong number of arguments also raises an error:"
200 ]
201 },
202 {
203 "cell_type": "code",
204 "execution_count": null,
205 "metadata": {},
206 "outputs": [],
207 "source": [
208 "print(add(1, 2, 3))"
209 ]
210 },
211 {
212 "cell_type": "markdown",
213 "metadata": {},
214 "source": [
215 "**Python** offers two ways to pass arguments: positionally, as above, or by keyword, naming each parameter explicitly:"
216 ]
217 },
218 {
219 "cell_type": "code",
220 "execution_count": null,
221 "metadata": {},
222 "outputs": [],
223 "source": [
224 "print (add(x=2, y=3))\n",
225 "print (add(y=\"foo\", x=\"bar\"))"
226 ]
227 },
228 {
229 "cell_type": "markdown",
230 "metadata": {},
231 "source": [
232 "The two styles can be mixed:"
233 ]
234 },
235 {
236 "cell_type": "code",
237 "execution_count": null,
238 "metadata": {},
239 "outputs": [],
240 "source": [
241 "print (add(2, y=3))"
242 ]
243 },
244 {
245 "cell_type": "markdown",
246 "metadata": {},
247 "source": [
248 "### Default parameter values\n",
249 "\n",
250 "Parameters can be given default values in the function definition, for example:"
251 ]
252 },
253 {
254 "cell_type": "code",
255 "execution_count": null,
256 "metadata": {},
257 "outputs": [],
258 "source": [
259 "def quad(x, a=1, b=0, c=0):\n",
260 " return a*x**2 + b*x + c"
261 ]
262 },
263 {
264 "cell_type": "markdown",
265 "metadata": {},
266 "source": [
267 "Parameters with default values can be omitted:"
268 ]
269 },
270 {
271 "cell_type": "code",
272 "execution_count": null,
273 "metadata": {},
274 "outputs": [],
275 "source": [
276 "print (quad(2.0))"
277 ]
278 },
279 {
280 "cell_type": "markdown",
281 "metadata": {},
282 "source": [
283 "Default values can be overridden:"
284 ]
285 },
286 {
287 "cell_type": "code",
288 "execution_count": null,
289 "metadata": {},
290 "outputs": [],
291 "source": [
292 "print (quad(2.0, b=3))"
293 ]
294 },
295 {
296 "cell_type": "code",
297 "execution_count": null,
298 "metadata": {},
299 "outputs": [],
300 "source": [
301 "print (quad(2.0, 2, c=4))"
302 ]
303 },
304 {
305 "cell_type": "markdown",
306 "metadata": {},
307 "source": [
308 "This mixes positional and keyword arguments; the second `2` is passed to `a`.\n",
309 "\n",
310 "Note that when mixing the two styles, the same parameter must not be assigned more than once, otherwise an error is raised, for example:"
311 ]
312 },
313 {
314 "cell_type": "code",
315 "execution_count": null,
316 "metadata": {},
317 "outputs": [],
318 "source": [
319 "print (quad(2.0, 2, a=2))"
320 ]
321 },
322 {
323 "cell_type": "markdown",
324 "metadata": {},
325 "source": [
326 "### Accepting a variable number of arguments\n",
327 "\n",
328 "The following pattern lets a function accept any number of arguments:"
329 ]
330 },
331 {
332 "cell_type": "code",
333 "execution_count": null,
334 "metadata": {},
335 "outputs": [],
336 "source": [
337 "def add(x, *args):\n",
338 " total = x\n",
339 " for arg in args:\n",
340 " total += arg\n",
341 " return total"
342 ]
343 },
344 {
345 "cell_type": "markdown",
346 "metadata": {},
347 "source": [
348 "Here `*args` collects all positional arguments after the first one into a tuple."
349 ]
350 },
351 {
352 "cell_type": "code",
353 "execution_count": null,
354 "metadata": {},
355 "outputs": [],
356 "source": [
357 "print (add(1, 2, 3, 4))\n",
358 "print (add(1, 2))\n"
359 ]
360 },
361 {
362 "cell_type": "markdown",
363 "metadata": {},
364 "source": [
365 "A function defined this way cannot take keyword arguments; to accept keywords, use:"
366 ]
367 },
368 {
369 "cell_type": "code",
370 "execution_count": null,
371 "metadata": {},
372 "outputs": [],
373 "source": [
374 "def add(x, **kwargs):\n",
375 " total = x\n",
376 " for arg, value in kwargs.items():\n",
377 " print (\"adding %s=%s\"%(arg,value))\n",
378 " total += value\n",
379 " return total\n"
380 ]
381 },
382 {
383 "cell_type": "markdown",
384 "metadata": {},
385 "source": [
386 "Here `**kwargs` collects the keyword arguments into a dictionary of name-value pairs."
387 ]
388 },
389 {
390 "cell_type": "code",
391 "execution_count": null,
392 "metadata": {},
393 "outputs": [],
394 "source": [
395 "print (add(10, y=11, z=12, w=13))"
396 ]
397 },
398 {
399 "cell_type": "markdown",
400 "metadata": {},
401 "source": [
402 "This example accepts any number of positional and keyword arguments:"
403 ]
404 },
405 {
406 "cell_type": "code",
407 "execution_count": null,
408 "metadata": {},
409 "outputs": [],
410 "source": [
411 "def foo(*args, **kwargs):\n",
412 " print (args, kwargs)\n",
413 "\n",
414 "foo(2, 3, x='bar', z=10)\n"
415 ]
416 },
417 {
418 "cell_type": "markdown",
419 "metadata": {},
420 "source": [
421 "The order matters: positional arguments (`args`) must come before keyword arguments (`kwargs`)."
422 ]
423 },
424 {
425 "cell_type": "markdown",
426 "metadata": {},
427 "source": [
428 "### Returning multiple values\n",
429 "\n",
430 "A function can return multiple values:"
431 ]
432 },
433 {
434 "cell_type": "code",
435 "execution_count": null,
436 "metadata": {},
437 "outputs": [],
438 "source": [
439 "def divid(a, b):\n",
440 "    \"\"\"\n",
441 "    Division\n",
442 "    :param a: number, the dividend\n",
443 "    :param b: number, the divisor\n",
444 "    :return: the quotient and the remainder\n",
445 "    \"\"\"\n",
446 "    quotient = a // b\n",
447 "    remainder = a % b\n",
448 "    return quotient, remainder\n",
449 "\n",
450 "quotient, remainder = divid(7,4)\n",
451 "print(quotient, remainder)"
452 ]
453 },
454 {
455 "cell_type": "markdown",
456 "metadata": {},
457 "source": [
458 "In fact, **Python** packs the two returned values into a tuple:"
459 ]
460 },
461 {
462 "cell_type": "code",
463 "execution_count": null,
464 "metadata": {},
465 "outputs": [],
466 "source": [
467 "print(divid(7,4))"
468 ]
469 },
470 {
471 "cell_type": "markdown",
472 "metadata": {},
473 "source": [
474 "Because this tuple holds two values, we can write\n",
475 "\n",
476 "    quotient, remainder = divid(7,4)\n",
477 "\n",
478 "to assign both at once.\n",
479 "\n",
480 "Lists support the same kind of unpacking:"
481 ]
482 },
483 {
484 "cell_type": "code",
485 "execution_count": null,
486 "metadata": {},
487 "outputs": [],
488 "source": [
489 "a, b, c = [1, 2, 3]\n",
490 "print (a, b, c)"
491 ]
492 },
493 {
494 "cell_type": "markdown",
495 "metadata": {},
496 "source": [
497 "Not only can return values be packed into a tuple; arguments can also be passed in as a tuple this way:"
498 ]
499 },
500 {
501 "cell_type": "code",
502 "execution_count": null,
503 "metadata": {},
504 "outputs": [],
505 "source": [
506 "def divid(a, b):\n",
507 "    \"\"\"\n",
508 "    Division\n",
509 "    :param a: number, the dividend\n",
510 "    :param b: number, the divisor\n",
511 "    :return: the quotient and the remainder\n",
512 "    \"\"\"\n",
513 "    quotient = a // b\n",
514 "    remainder = a % b\n",
515 "    return quotient, remainder\n",
516 "\n",
517 "z = (7,4)\n",
518 "print(divid(*z))"
519 ]
520 },
521 {
522 "cell_type": "markdown",
523 "metadata": {},
524 "source": [
525 "The `*` here is essential.\n",
526 "\n",
527 "Arguments can likewise be passed in through a dictionary:"
528 ]
529 },
530 {
531 "cell_type": "code",
532 "execution_count": null,
533 "metadata": {},
534 "outputs": [],
535 "source": [
536 "def divid(a, b):\n",
537 "    \"\"\"\n",
538 "    Division\n",
539 "    :param a: number, the dividend\n",
540 "    :param b: number, the divisor\n",
541 "    :return: the quotient and the remainder\n",
542 "    \"\"\"\n",
543 "    quotient = a // b\n",
544 "    remainder = a % b\n",
545 "    return quotient, remainder\n",
546 "\n",
547 "z = {'a':7,'b':4}\n",
548 "print(divid(**z))"
549 ]
550 },
551 {
552 "cell_type": "markdown",
553 "metadata": {},
554 "source": [
555 "### Generating sequences with `map`"
556 ]
557 },
558 {
559 "cell_type": "markdown",
560 "metadata": {},
561 "source": [
562 "Its usage is:\n",
563 " \n",
564 "    map(aFun, aSeq)\n",
565 "\n",
566 "which applies the function `aFun` to every element of the sequence `aSeq`, whatever the sequence's original type; in Python 3 it returns a lazy `map` iterator, which can be converted with `list`.\n",
567 "\n",
568 "When the function takes several parameters, `map` accepts several sequences and passes their corresponding elements as the arguments:"
569 ]
570 },
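{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A minimal single-sequence sketch (the names here are illustrative):\n",
  "\n",
  "```python\n",
  "values = [1, 2, 3, 4]\n",
  "squares = list(map(lambda x: x**2, values))\n",
  "print(squares)  # [1, 4, 9, 16]\n",
  "```"
 ]
},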
571 {
572 "cell_type": "code",
573 "execution_count": null,
574 "metadata": {},
575 "outputs": [],
576 "source": [
577 "def divid(a, b):\n",
578 "    \"\"\"\n",
579 "    Division\n",
580 "    :param a: number, the dividend\n",
581 "    :param b: number, the divisor\n",
582 "    :return: the quotient and the remainder\n",
583 "    \"\"\"\n",
584 "    quotient = a // b\n",
585 "    remainder = a % b\n",
586 "    return quotient, remainder\n",
587 "\n",
588 "a = (10, 6, 7)\n",
589 "b = [2, 5, 3]\n",
590 "print (list(map(divid,a,b)))"
591 ]
592 },
593 {
594 "cell_type": "markdown",
595 "metadata": {},
596 "source": [
597 "## 2.3 Modules and Packages"
598 ]
599 },
600 {
601 "cell_type": "markdown",
602 "metadata": {},
603 "source": [
604 "### Modules\n",
605 "\n",
606 "**Python** treats every file ending in `.py` as a **Python** code file. Consider the script `ex1.py` below:"
607 ]
608 },
609 {
610 "cell_type": "code",
611 "execution_count": null,
612 "metadata": {},
613 "outputs": [],
614 "source": [
615 "%%writefile ex1.py\n",
616 "\n",
617 "PI = 3.1416\n",
618 "\n",
619 "def sum(lst):\n",
620 "    \"\"\"\n",
621 "    Sum all elements of the sequence lst\n",
622 "    :param lst: a sequence, e.g. [1,2,3]\n",
623 "    :return: the total of all elements in lst\n",
624 "    \"\"\"\n",
625 "    \n",
626 "    # take the first element of lst\n",
627 "    tot = lst[0]\n",
628 "    \n",
629 "    # loop over the remaining elements of lst\n",
630 "    for value in lst[1:]:\n",
631 "        tot = tot + value\n",
632 "    return tot\n",
633 "\n",
634 "w = [0, 1, 2, 3]\n",
635 "print (sum(w), PI)\n",
636 "\n"
637 ]
638 },
639 {
640 "cell_type": "markdown",
641 "metadata": {},
642 "source": [
643 "It can be executed:"
644 ]
645 },
646 {
647 "cell_type": "code",
648 "execution_count": null,
649 "metadata": {},
650 "outputs": [],
651 "source": [
652 "%run ex1.py"
653 ]
654 },
655 {
656 "cell_type": "markdown",
657 "metadata": {},
658 "source": [
659 "The script can also serve as a module, loaded and executed with the `import` keyword (this requires `ex1.py` to be in the current working directory):"
660 ]
661 },
662 {
663 "cell_type": "code",
664 "execution_count": null,
665 "metadata": {},
666 "outputs": [],
667 "source": [
668 "import ex1"
669 ]
670 },
671 {
672 "cell_type": "code",
673 "execution_count": null,
674 "metadata": {},
675 "outputs": [],
676 "source": [
677 "ex1"
678 ]
679 },
680 {
681 "cell_type": "markdown",
682 "metadata": {},
683 "source": [
684 "On import, **Python** executes everything in the module once.\n",
685 "\n",
686 "All the variables defined in `ex1.py` are loaded into the current environment, but they must be inspected or modified through the module name:\n",
687 "\n",
688 "    ex1.variable_name\n",
689 "\n",
690 "as illustrated below:"
691 ]
692 },
693 {
694 "cell_type": "code",
695 "execution_count": null,
696 "metadata": {},
697 "outputs": [],
698 "source": [
699 "print (ex1.PI)"
700 ]
701 },
702 {
703 "cell_type": "code",
704 "execution_count": null,
705 "metadata": {},
706 "outputs": [],
707 "source": [
708 "ex1.PI = 3.141592653\n",
709 "print (ex1.PI)"
710 ]
711 },
712 {
713 "cell_type": "markdown",
714 "metadata": {},
715 "source": [
716 "Functions in the module are called in the same way:\n",
717 "\n",
718 "    ex1.function_name\n",
719 "\n",
720 "for example:"
721 ]
722 },
723 {
724 "cell_type": "code",
725 "execution_count": null,
726 "metadata": {},
727 "outputs": [],
728 "source": [
729 "print (ex1.sum([2, 3, 4]))"
730 ]
731 },
732 {
733 "cell_type": "markdown",
734 "metadata": {},
735 "source": [
736 "For efficiency, **Python** loads a module only once; importing an already-loaded module again does nothing, even if the module's content has changed.\n",
737 "\n",
738 "For example, re-importing `ex1` here does not run the `print` statement in `ex1.py`:"
739 ]
740 },
741 {
742 "cell_type": "code",
743 "execution_count": null,
744 "metadata": {},
745 "outputs": [],
746 "source": [
747 "import ex1"
748 ]
749 },
750 {
751 "cell_type": "markdown",
752 "metadata": {},
753 "source": [
754 "To force a reload, use `reload`, for example:"
755 ]
756 },
757 {
758 "cell_type": "code",
759 "execution_count": null,
760 "metadata": {},
761 "outputs": [],
762 "source": [
763 "from importlib import reload\n",
764 "reload(ex1)"
765 ]
766 },
767 {
768 "cell_type": "markdown",
769 "metadata": {},
770 "source": [
771 "Remove the generated file:"
772 ]
773 },
774 {
775 "cell_type": "code",
776 "execution_count": null,
777 "metadata": {},
778 "outputs": [],
779 "source": [
780 "import os\n",
781 "os.remove('ex1.py')\n"
782 ]
783 },
784 {
785 "cell_type": "markdown",
786 "metadata": {},
787 "source": [
788 "### The `__name__` attribute\n",
789 "\n",
790 "Sometimes we want a `.py` file to be usable both as a script and as a module; the `__name__` attribute makes this possible.\n",
791 "\n",
792 "`__name__` equals `'__main__'` only when the file is run as a script, so we can write:\n"
793 ]
794 },
795 {
796 "cell_type": "code",
797 "execution_count": null,
798 "metadata": {},
799 "outputs": [],
800 "source": [
801 "%%writefile ex2.py\n",
802 "\n",
803 "PI = 3.1416\n",
804 "\n",
805 "def sum(lst):\n",
806 " \"\"\" Sum the values in a list\n",
807 " \"\"\"\n",
808 " tot = 0\n",
809 " for value in lst:\n",
810 " tot = tot + value\n",
811 " return tot\n",
812 "\n",
813 "def add(x, y):\n",
814 " \" Add two values.\"\n",
815 " a = x + y\n",
816 " return a\n",
817 "\n",
818 "def test():\n",
819 " w = [0,1,2,3]\n",
820 " assert(sum(w) == 6)\n",
821 " print ('test passed.')\n",
822 "\n",
823 "if __name__ == '__main__':\n",
824 " test()\n"
825 ]
826 },
827 {
828 "cell_type": "markdown",
829 "metadata": {},
830 "source": [
831 "Run the file:"
832 ]
833 },
834 {
835 "cell_type": "code",
836 "execution_count": null,
837 "metadata": {},
838 "outputs": [],
839 "source": [
840 "%run ex2.py"
841 ]
842 },
843 {
844 "cell_type": "markdown",
845 "metadata": {},
846 "source": [
847 "When imported as a module, `test()` is not executed:"
848 ]
849 },
850 {
851 "cell_type": "code",
852 "execution_count": null,
853 "metadata": {},
854 "outputs": [],
855 "source": [
856 "import ex2"
857 ]
858 },
859 {
860 "cell_type": "markdown",
861 "metadata": {},
862 "source": [
863 "But its variables are available:"
864 ]
865 },
866 {
867 "cell_type": "code",
868 "execution_count": null,
869 "metadata": {},
870 "outputs": [],
871 "source": [
872 "ex2.PI\n"
873 ]
874 },
875 {
876 "cell_type": "markdown",
877 "metadata": {},
878 "source": [
879 "A module can be given an alias on import for convenience:"
880 ]
881 },
882 {
883 "cell_type": "code",
884 "execution_count": null,
885 "metadata": {},
886 "outputs": [],
887 "source": [
888 "import ex2 as e2\n",
889 "e2.PI"
890 ]
891 },
892 {
893 "cell_type": "markdown",
894 "metadata": {},
895 "source": [
896 "### Other import forms\n",
897 "\n",
898 "Individual names can be imported from a module:"
899 ]
900 },
901 {
902 "cell_type": "code",
903 "execution_count": null,
904 "metadata": {},
905 "outputs": [],
906 "source": [
907 "from ex2 import add, PI"
908 ]
909 },
910 {
911 "cell_type": "markdown",
912 "metadata": {},
913 "source": [
914 "After `from`, `add` and `PI` can be used directly:"
915 ]
916 },
917 {
918 "cell_type": "code",
919 "execution_count": null,
920 "metadata": {},
921 "outputs": [],
922 "source": [
923 "add(2, 3)"
924 ]
925 },
926 {
927 "cell_type": "markdown",
928 "metadata": {},
929 "source": [
930 "Or import every name with `*`:"
931 ]
932 },
933 {
934 "cell_type": "code",
935 "execution_count": null,
936 "metadata": {},
937 "outputs": [],
938 "source": [
939 "from ex2 import *\n",
940 "add(3, 4.5)"
941 ]
942 },
943 {
944 "cell_type": "markdown",
945 "metadata": {},
946 "source": [
947 "This form is discouraged: if you do not know exactly what is being imported, it may silently shadow existing functions.\n",
948 "\n",
949 "Remove the file:"
950 ]
951 },
952 {
953 "cell_type": "code",
954 "execution_count": null,
955 "metadata": {},
956 "outputs": [],
957 "source": [
958 "import os\n",
959 "os.remove('ex2.py')"
960 ]
961 },
962 {
963 "cell_type": "markdown",
964 "metadata": {},
965 "source": [
966 "### Packages\n",
967 "\n",
968 "Suppose we have a folder like this:\n",
969 "\n",
970 "foo/\n",
971 "- `__init__.py` \n",
972 "- `bar.py` (defines func)\n",
973 "- `baz.py` (defines zap)\n",
974 "\n",
975 "This makes `foo` a package, and its contents can be imported with:\n",
976 "\n",
977 "```python \n",
978 "\n",
979 "from foo.bar import func\n",
980 "from foo.baz import zap\n",
981 "\n",
982 "```\n",
983 "\n",
984 "`bar` and `baz` are `.py` files inside the `foo` folder.\n",
985 "\n",
986 "Importing a package requires that:\n",
987 "- the folder `foo` is on **Python**'s search path;\n",
988 "- `__init__.py` marks `foo` as a package; it may be an empty file."
989 ]
990 },
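{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "A minimal sketch that builds such a package on disk and imports from it, assuming the current working directory is on the search path (`func` returning 42 is an illustrative stand-in):\n",
  "\n",
  "```python\n",
  "import os\n",
  "import shutil\n",
  "\n",
  "# create the package folder with an empty __init__.py\n",
  "os.makedirs('foo', exist_ok=True)\n",
  "with open('foo/__init__.py', 'w') as f:\n",
  "    f.write('')\n",
  "with open('foo/bar.py', 'w') as f:\n",
  "    f.write('def func():\\n    return 42\\n')\n",
  "\n",
  "from foo.bar import func\n",
  "print(func())  # 42\n",
  "\n",
  "# remove the generated package (including any __pycache__)\n",
  "shutil.rmtree('foo')\n",
  "```"
 ]
},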
991 {
992 "cell_type": "markdown",
993 "metadata": {},
994 "source": [
995 "## 2.4 Exceptions\n",
996 "\n",
997 "Errors are unavoidable when writing code; even syntactically correct code can run into other problems."
998 ]
999 },
1000 {
1001 "cell_type": "markdown",
1002 "metadata": {},
1003 "source": [
1004 "Consider this code:\n",
1005 "\n",
1006 "```python \n",
1007 "\n",
1008 "import math\n",
1009 "\n",
1010 "while True:\n",
1011 "    text = input('> ')\n",
1012 "    if text[0] == 'q':\n",
1013 "        break\n",
1014 "    x = float(text)\n",
1015 "    y = math.log10(x)\n",
1016 "    print (\"log10({0}) = {1}\".format(x, y))\n",
1017 "```\n",
1018 "\n",
1019 "It reads from the command line; when the input is a number it computes and prints its logarithm, and it stops when the input starts with `q`.\n",
1020 "\n",
1021 "At first glance this looks fine, but when we enter 0 or a negative number:"
1022 ]
1023 },
1024 {
1025 "cell_type": "code",
1026 "execution_count": null,
1027 "metadata": {},
1028 "outputs": [],
1029 "source": [
1030 "import math\n",
1031 "\n",
1032 "while True:\n",
1033 " text = input('> ')\n",
1034 " if text[0] == 'q':\n",
1035 " break\n",
1036 " x = float(text)\n",
1037 " y = math.log10(x)\n",
1038 " print (\"log10({0}) = {1}\".format(x, y))\n"
1039 ]
1040 },
1041 {
1042 "cell_type": "markdown",
1043 "metadata": {},
1044 "source": [
1045 "`log10` raises an error because it does not accept non-positive values.\n",
1046 "\n",
1047 "Once an error is raised, the program stops. To keep the program running and catch the exception, we can use a `try/except` statement."
1048 ]
1049 },
1050 {
1051 "cell_type": "markdown",
1052 "metadata": {},
1053 "source": [
1054 "```python\n",
1055 "\n",
1056 "import math\n",
1057 "\n",
1058 "while True:\n",
1059 "    try:\n",
1060 "        text = input('> ')\n",
1061 "        if text[0] == 'q':\n",
1062 "            break\n",
1063 "        x = float(text)\n",
1064 "        y = math.log10(x)\n",
1065 "        print (\"log10({0}) = {1}\".format(x, y))\n",
1066 "    except ValueError:\n",
1067 "        print (\"the value must be greater than 0\")\n",
1068 "```"
1069 ]
1070 },
1071 {
1072 "cell_type": "markdown",
1073 "metadata": {},
1074 "source": [
1075 "Once an exception occurs inside the `try` block, the rest of the block is skipped and **Python** looks for a matching `except` clause; if one is found its body is executed, otherwise the exception propagates.\n",
1076 "\n",
1077 "In the example above, `try` raises a `ValueError` and the `except` clause matches it, so the exception is caught and the program continues:"
1078 ]
1079 },
1080 {
1081 "cell_type": "code",
1082 "execution_count": null,
1083 "metadata": {},
1084 "outputs": [],
1085 "source": [
1086 "import math\n",
1087 "\n",
1088 "while True:\n",
1089 " try:\n",
1090 " text = input('> ')\n",
1091 " if text[0] == 'q':\n",
1092 " break\n",
1093 " x = float(text)\n",
1094 " y = math.log10(x)\n",
1095 " print (\"log10({0}) = {1}\".format(x, y))\n",
1096 " except ValueError:\n",
1097 " print (\"the value must be greater than 0\")\n"
1098 ]
1099 },
1100 {
1101 "cell_type": "markdown",
1102 "metadata": {},
1103 "source": [
1104 "### Catching different error types"
1105 ]
1106 },
1107 {
1108 "cell_type": "markdown",
1109 "metadata": {},
1110 "source": [
1111 "``` python\n",
1112 "import math\n",
1113 "\n",
1114 "while True:\n",
1115 "    try:\n",
1116 "        text = input('> ')\n",
1117 "        if text[0] == 'q':\n",
1118 "            break\n",
1119 "        x = float(text)\n",
1120 "        y = 1 / math.log10(x)\n",
1121 "        print (\"log10({0}) = {1}\".format(x, y))\n",
1122 "    except ValueError:\n",
1123 "        print (\"the value must be greater than 0\")\n",
1124 "```\n",
1125 "\n",
1126 "Suppose we change `y` to `1 / math.log10(x)` and then enter `1`:"
1127 ]
1128 },
1129 {
1130 "cell_type": "code",
1131 "execution_count": null,
1132 "metadata": {},
1133 "outputs": [],
1134 "source": [
1135 "import math\n",
1136 "\n",
1137 "while True:\n",
1138 " try:\n",
1139 " text = input('> ')\n",
1140 " if text[0] == 'q':\n",
1141 " break\n",
1142 " x = float(text)\n",
1143 " y = 1 / math.log10(x)\n",
1144 " print (\"log10({0}) = {1}\".format(x, y))\n",
1145 " except ValueError:\n",
1146 " print (\"the value must be greater than 0\")\n"
1147 ]
1148 },
1149 {
1150 "cell_type": "markdown",
1151 "metadata": {},
1152 "source": [
1153 "Because our `except` clauses do not include `ZeroDivisionError`, the exception propagates. There are two ways to handle this."
1154 ]
1155 },
1156 {
1157 "cell_type": "markdown",
1158 "metadata": {},
1159 "source": [
1160 "### Catching all exceptions\n",
1161 "\n",
1162 "Change the `except` target to the `Exception` class to catch all exceptions:"
1163 ]
1164 },
1165 {
1166 "cell_type": "code",
1167 "execution_count": null,
1168 "metadata": {},
1169 "outputs": [],
1170 "source": [
1171 "import math\n",
1172 "\n",
1173 "while True:\n",
1174 " try:\n",
1175 " text = input('> ')\n",
1176 " if text[0] == 'q':\n",
1177 " break\n",
1178 " x = float(text)\n",
1179 " y = 1 / math.log10(x)\n",
1180 " print (\"1 / log10({0}) = {1}\".format(x, y))\n",
1181 " except Exception:\n",
1182 " print (\"invalid value\")\n"
1183 ]
1184 },
1185 {
1186 "cell_type": "markdown",
1187 "metadata": {},
1188 "source": [
1189 "### Naming the specific exceptions\n",
1190 "\n",
1191 "Here we add `ZeroDivisionError` to the `except` clause:"
1192 ]
1193 },
1194 {
1195 "cell_type": "code",
1196 "execution_count": null,
1197 "metadata": {},
1198 "outputs": [],
1199 "source": [
1200 "import math\n",
1201 "\n",
1202 "while True:\n",
1203 " try:\n",
1204 " text = input('> ')\n",
1205 " if text[0] == 'q':\n",
1206 " break\n",
1207 " x = float(text)\n",
1208 " y = 1 / math.log10(x)\n",
1209 " print (\"1 / log10({0}) = {1}\".format(x, y))\n",
1210 " except (ValueError, ZeroDivisionError):\n",
1211 " print (\"invalid value\")\n"
1212 ]
1213 },
1214 {
1215 "cell_type": "markdown",
1216 "metadata": {},
1217 "source": [
1218 "Or handle them separately:"
1219 ]
1220 },
1221 {
1222 "cell_type": "code",
1223 "execution_count": null,
1224 "metadata": {},
1225 "outputs": [],
1226 "source": [
1227 "import math\n",
1228 "\n",
1229 "while True:\n",
1230 " try:\n",
1231 " text = input('> ')\n",
1232 " if text[0] == 'q':\n",
1233 " break\n",
1234 " x = float(text)\n",
1235 " y = 1 / math.log10(x)\n",
1236 " print (\"1 / log10({0}) = {1}\".format(x, y))\n",
1237 " except ValueError:\n",
1238 " print (\"the value must be greater than 0\")\n",
1239 " except ZeroDivisionError:\n",
1240 " print (\"the value must not be 1\")\n"
1241 ]
1242 },
1243 {
1244 "cell_type": "markdown",
1245 "metadata": {},
1246 "source": [
1247 "The two approaches can also be combined, with `Exception` catching everything else:"
1248 ]
1249 },
1250 {
1251 "cell_type": "code",
1252 "execution_count": null,
1253 "metadata": {},
1254 "outputs": [],
1255 "source": [
1256 "import math\n",
1257 "\n",
1258 "while True:\n",
1259 " try:\n",
1260 " text = input('> ')\n",
1261 " if text[0] == 'q':\n",
1262 " break\n",
1263 " x = float(text)\n",
1264 " y = 1 / math.log10(x)\n",
1265 " print (\"1 / log10({0}) = {1}\".format(x, y))\n",
1266 " except ValueError:\n",
1267 " print (\"the value must be greater than 0\")\n",
1268 " except ZeroDivisionError:\n",
1269 " print (\"the value must not be 1\")\n",
1270 " except Exception:\n",
1271 " print (\"unexpected error\")\n"
1272 ]
1273 },
1274 {
1275 "cell_type": "markdown",
1276 "metadata": {},
1277 "source": [
1278 "### Inspecting the exception\n",
1279 "\n",
1280 "In the example above, entering a string that cannot be converted to a float also prints `the value must be greater than 0`, which does not reflect what actually happened."
1281 ]
1282 },
1283 {
1284 "cell_type": "markdown",
1285 "metadata": {},
1286 "source": [
1287 "To report the actual cause, we examine the caught `ValueError` itself:"
1288 ]
1289 },
1290 {
1291 "cell_type": "code",
1292 "execution_count": null,
1293 "metadata": {},
1294 "outputs": [],
1295 "source": [
1296 "import math\n",
1297 "\n",
1298 "while True:\n",
1299 " try:\n",
1300 " text = input('> ')\n",
1301 " if text[0] == 'q':\n",
1302 " break\n",
1303 " x = float(text)\n",
1304 " y = 1 / math.log10(x)\n",
1305 " print (\"1 / log10({0}) = {1}\".format(x, y))\n",
1306 "    except ValueError as exc:\n",
1307 "        if str(exc) == \"math domain error\":\n",
1308 "            print (\"the value must be greater than 0\")\n",
1309 "        else:\n",
1310 "            print (\"could not convert '%s' to float\" % text)\n",
1311 "    except ZeroDivisionError:\n",
1312 "        print (\"the value must not be 1\")\n",
1313 "    except Exception as exc:\n",
1314 "        print (\"unexpected error:\", str(exc))\n"
1315 ]
1316 },
1317 {
1318 "cell_type": "markdown",
1319 "metadata": {},
1320 "source": [
1321 "Here we also display the message of any other exception we catch.\n",
1322 "\n",
1323 "`str(exc)` gives the description attached to the exception; for example, for\n",
1324 "\n",
1325 "    ValueError: could not convert string to float: a\n",
1326 "\n",
1327 "the message is \n",
1328 "\n",
1329 "    could not convert string to float: a\n",
1330 "\n",
1331 "`except Exception` catches `Exception` and all classes derived from it, but not every exception derives from `Exception`, so some cases would slip through. A more general form is:\n",
1332 "\n",
1333 "```python\n",
1334 "try:\n",
1335 "    pass\n",
1336 "except:\n",
1337 "    pass\n",
1338 "```\n",
1339 "\n",
1340 "A bare `except` with no type catches everything, but this form is not recommended."
1341 ]
1342 },
1343 {
1344 "cell_type": "markdown",
1345 "metadata": {},
1346 "source": [
1347 "### finally\n",
1348 "\n",
1349 "A `try/except` block can also take an optional `finally` clause.\n",
1350 "\n",
1351 "The `finally` block always runs, whether or not the `try` block raised an exception, and it runs before any exception propagates; this makes it a good place for safety guarantees such as ensuring an open file gets closed."
1352 ]
1353 },
1354 {
1355 "cell_type": "code",
1356 "execution_count": null,
1357 "metadata": {},
1358 "outputs": [],
1359 "source": [
1360 "try:\n",
1361 " print (1)\n",
1362 "finally:\n",
1363 " print ('finally was called.')"
1364 ]
1365 },
1366 {
1367 "cell_type": "markdown",
1368 "metadata": {},
1369 "source": [
1370 "It runs before the exception propagates:"
1371 ]
1372 },
1373 {
1374 "cell_type": "code",
1375 "execution_count": null,
1376 "metadata": {},
1377 "outputs": [],
1378 "source": [
1379 "try:\n",
1380 " print (1 / 0)\n",
1381 "finally:\n",
1382 " print ('finally was called.')"
1383 ]
1384 },
1385 {
1386 "cell_type": "markdown",
1387 "metadata": {},
1388 "source": [
1389 "If the exception is caught, `finally` runs last:"
1390 ]
1391 },
1392 {
1393 "cell_type": "code",
1394 "execution_count": null,
1395 "metadata": {},
1396 "outputs": [],
1397 "source": [
1398 "try:\n",
1399 " print (1 / 0)\n",
1400 "except ZeroDivisionError:\n",
1401 " print ('divide by 0.')\n",
1402 "finally:\n",
1403 " print ('finally was called.')\n"
1404 ]
1405 },
1406 {
1407 "cell_type": "markdown",
1408 "metadata": {},
1409 "source": [
1410 "The exception-handling flow is summarized in this diagram:\n",
1411 "\n",
1412 "<img src=\"https://www.runoob.com/wp-content/uploads/2019/07/try_except_else_finally.png\" width=600px/>\n",
1413 "\n"
1414 ]
1415 },
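{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "The full flow in the diagram, including the `else` clause that runs only when no exception was raised, can be seen in one combined sketch:\n",
  "\n",
  "```python\n",
  "def safe_div(a, b):\n",
  "    try:\n",
  "        result = a / b\n",
  "    except ZeroDivisionError:\n",
  "        print('divide by 0.')\n",
  "        result = None\n",
  "    else:\n",
  "        print('no exception.')\n",
  "    finally:\n",
  "        print('finally was called.')\n",
  "    return result\n",
  "\n",
  "print(safe_div(4, 2))  # 2.0\n",
  "print(safe_div(1, 0))  # None\n",
  "```"
 ]
},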
1416 {
1417 "cell_type": "markdown",
1418 "metadata": {},
1419 "source": [
1420 "## 2.5 Warnings\n",
1421 "\n",
1422 "When a problem should be brought to the user's attention without stopping the program, a warning can be issued.\n",
1423 "\n",
1424 "First import the warnings module:"
1425 ]
1426 },
1427 {
1428 "cell_type": "code",
1429 "execution_count": null,
1430 "metadata": {},
1431 "outputs": [],
1432 "source": [
1433 "import warnings"
1434 ]
1435 },
1436 {
1437 "cell_type": "markdown",
1438 "metadata": {},
1439 "source": [
1440 "Where needed, call the `warn` function from `warnings`:\n",
1441 "\n",
1442 "    warn(msg, WarningType = UserWarning)"
1443 ]
1444 },
1445 {
1446 "cell_type": "code",
1447 "execution_count": null,
1448 "metadata": {},
1449 "outputs": [],
1450 "source": [
1451 "def month_warning(m):\n",
1452 " if not 1<= m <= 12:\n",
1453 " msg = \"month (%d) is not between 1 and 12\" % m\n",
1454 " warnings.warn(msg, RuntimeWarning)\n",
1455 "\n",
1456 "month_warning(13)\n"
1457 ]
1458 },
1459 {
1460 "cell_type": "markdown",
1461 "metadata": {},
1462 "source": [
1463 "To ignore a particular category of warning, use the `filterwarnings` function from `warnings`:\n",
1464 "\n",
1465 "    filterwarnings(action, category)\n",
1466 "\n",
1467 "Setting `action` to `'ignore'` suppresses warnings of that category:"
1468 ]
1469 },
1470 {
1471 "cell_type": "code",
1472 "execution_count": null,
1473 "metadata": {},
1474 "outputs": [],
1475 "source": [
1476 "warnings.filterwarnings(action = 'ignore', category = RuntimeWarning)\n",
1477 "\n",
1478 "month_warning(13)"
1479 ]
1480 },
1481 {
1482 "cell_type": "markdown",
1483 "metadata": {},
1484 "source": [
1485 "## 2.6 File I/O\n",
1486 "\n",
1487 "Write a test file:"
1488 ]
1489 },
1490 {
1491 "cell_type": "code",
1492 "execution_count": null,
1493 "metadata": {},
1494 "outputs": [],
1495 "source": [
1496 "%%writefile test.txt\n",
1497 "this is a test file.\n",
1498 "hello world!\n",
1499 "python is good!\n",
1500 "today is a good day.\n"
1501 ]
1502 },
1503 {
1504 "cell_type": "markdown",
1505 "metadata": {},
1506 "source": [
1507 "### Reading files\n",
1508 "\n",
1509 "Use the `open` function with the file name as a string argument to read a file:\n"
1510 ]
1511 },
1512 {
1513 "cell_type": "code",
1514 "execution_count": null,
1515 "metadata": {},
1516 "outputs": [],
1517 "source": [
1518 "f = open('test.txt')"
1519 ]
1520 },
1521 {
1522 "cell_type": "markdown",
1523 "metadata": {},
1524 "source": [
1525 "By default the file is opened for reading; if it does not exist, an error is raised:"
1526 ]
1527 },
1528 {
1529 "cell_type": "code",
1530 "execution_count": null,
1531 "metadata": {},
1532 "outputs": [],
1533 "source": [
1534 "f = open('test1.txt')"
1535 ]
1536 },
1537 {
1538 "cell_type": "markdown",
1539 "metadata": {},
1540 "source": [
1541 "The `read` method reads in the entire contents of the file:"
1542 ]
1543 },
1544 {
1545 "cell_type": "code",
1546 "execution_count": null,
1547 "metadata": {},
1548 "outputs": [],
1549 "source": [
1550 "text = f.read()\n",
1551 "print (text)"
1552 ]
1553 },
1554 {
1555 "cell_type": "markdown",
1556 "metadata": {},
1557 "source": [
1558 "Files can also be read line by line; `readlines` returns a list with one element per line:"
1559 ]
1560 },
1561 {
1562 "cell_type": "code",
1563 "execution_count": null,
1564 "metadata": {},
1565 "outputs": [],
1566 "source": [
1567 "f = open('test.txt')\n",
1568 "lines = f.readlines()\n",
1569 "print (lines)\n"
1570 ]
1571 },
1572 {
1573 "cell_type": "markdown",
1574 "metadata": {},
1575 "source": [
1576 "使用完文件之后,需要将文件关闭。"
1577 ]
1578 },
1579 {
1580 "cell_type": "code",
1581 "execution_count": null,
1582 "metadata": {},
1583 "outputs": [],
1584 "source": [
1585 "f.close()"
1586 ]
1587 },
1588 {
1589 "cell_type": "markdown",
1590 "metadata": {},
1591 "source": [
1592 "事实上,我们可以将 `f` 放在一个循环中,得到它每一行的内容:"
1593 ]
1594 },
1595 {
1596 "cell_type": "code",
1597 "execution_count": null,
1598 "metadata": {},
1599 "outputs": [],
1600 "source": [
1601 "f = open('test.txt')\n",
1602 "for line in f:\n",
1603 " print (line)\n",
1604 "f.close()\n"
1605 ]
1606 },
1607 {
1608 "cell_type": "markdown",
1609 "metadata": {},
1610 "source": [
1611 "删除刚才创建的文件:"
1612 ]
1613 },
1614 {
1615 "cell_type": "code",
1616 "execution_count": null,
1617 "metadata": {},
1618 "outputs": [],
1619 "source": [
1620 "import os\n",
1621 "os.remove('test.txt')"
1622 ]
1623 },
1624 {
1625 "cell_type": "markdown",
1626 "metadata": {},
1627 "source": [
1628 "### 写文件"
1629 ]
1630 },
1631 {
1632 "cell_type": "markdown",
1633 "metadata": {},
1634 "source": [
1635 "我们使用 `open` 函数的写入模式来写文件:"
1636 ]
1637 },
1638 {
1639 "cell_type": "code",
1640 "execution_count": null,
1641 "metadata": {},
1642 "outputs": [],
1643 "source": [
1644 "f = open('myfile.txt', 'w')\n",
1645 "f.write('hello world!')\n",
1646 "f.close()"
1647 ]
1648 },
1649 {
1650 "cell_type": "markdown",
1651 "metadata": {},
1652 "source": [
1653 "使用 `w` 模式时,如果文件不存在会被创建,我们可以查看是否真的写入成功:"
1654 ]
1655 },
1656 {
1657 "cell_type": "code",
1658 "execution_count": null,
1659 "metadata": {},
1660 "outputs": [],
1661 "source": [
1662 "print (open('myfile.txt').read())"
1663 ]
1664 },
1665 {
1666 "cell_type": "markdown",
1667 "metadata": {},
1668 "source": [
1669 "如果文件已经存在, `w` 模式会覆盖之前写的所有内容:"
1670 ]
1671 },
1672 {
1673 "cell_type": "code",
1674 "execution_count": null,
1675 "metadata": {},
1676 "outputs": [],
1677 "source": [
1678 "f = open('myfile.txt', 'w')\n",
1679 "f.write('another hello world!')\n",
1680 "f.close()\n",
1681 "print (open('myfile.txt').read())\n"
1682 ]
1683 },
1684 {
1685 "cell_type": "markdown",
1686 "metadata": {},
1687 "source": [
1688 "除了写入模式,还有追加模式 `a` ,追加模式不会覆盖之前已经写入的内容,而是在之后继续写入:"
1689 ]
1690 },
1691 {
1692 "cell_type": "code",
1693 "execution_count": null,
1694 "metadata": {},
1695 "outputs": [],
1696 "source": [
1697 "f = open('myfile.txt', 'a')\n",
1698 "f.write('... and more')\n",
1699 "f.close()\n",
1700 "print (open('myfile.txt').read())\n"
1701 ]
1702 },
1703 {
1704 "cell_type": "markdown",
1705 "metadata": {},
1706 "source": [
1707 "写入结束之后一定要将文件关闭,否则可能出现内容没有完全写入文件中的情况。\n",
1708 "\n",
1709 "还可以使用读写模式 `w+`:"
1710 ]
1711 },
1712 {
1713 "cell_type": "code",
1714 "execution_count": null,
1715 "metadata": {},
1716 "outputs": [],
1717 "source": [
1718 "f = open('myfile.txt', 'w+')\n",
1719 "f.write('hello world!')\n",
1720 "f.seek(6)\n",
1721 "print (f.read())\n",
1722 "f.close()\n"
1723 ]
1724 },
1725 {
1726 "cell_type": "markdown",
1727 "metadata": {},
1728 "source": [
1729 "这里 `f.seek(6)` 移动到文件的第6个字符处,然后 `f.read()` 读出剩下的内容。"
1730 ]
1731 },
1732 {
1733 "cell_type": "markdown",
1734 "metadata": {},
1735 "source": [
1736 "删除刚才创建的文件:"
1737 ]
1738 },
1739 {
1740 "cell_type": "code",
1741 "execution_count": null,
1742 "metadata": {},
1743 "outputs": [],
1744 "source": [
1745 "import os\n",
1746 "os.remove('myfile.txt')\n"
1747 ]
1748 },
1749 {
1750 "cell_type": "markdown",
1751 "metadata": {},
1752 "source": [
1753 "### 关闭文件\n",
1754 "\n",
1755 "在 **Python** 中,如果一个打开的文件不再被其他变量引用时,它会自动关闭这个文件。\n",
1756 "\n",
"所以正常情况下,忘记调用文件的 `close` 方法通常不会造成问题,解释器回收文件对象时会自动关闭它。\n",
1758 "\n",
1759 "关闭文件可以保证内容已经被写入文件,而不关闭可能会出现意想不到的结果:"
1760 ]
1761 },
1762 {
1763 "cell_type": "code",
1764 "execution_count": null,
1765 "metadata": {},
1766 "outputs": [],
1767 "source": [
1768 "f = open('newfile.txt','w')\n",
1769 "f.write('hello world')\n",
1770 "g = open('newfile.txt', 'r')\n",
1771 "print (repr(g.read()))\n"
1772 ]
1773 },
1774 {
1775 "cell_type": "markdown",
1776 "metadata": {},
1777 "source": [
1778 "虽然这里写了内容,但是在关闭之前,这个内容并没有被写入磁盘。\n",
1779 "\n",
1780 "使用循环写入的内容也并不完整:"
1781 ]
1782 },
1783 {
1784 "cell_type": "code",
1785 "execution_count": null,
1786 "metadata": {},
1787 "outputs": [],
1788 "source": [
1789 "f = open('newfile.txt','w')\n",
1790 "for i in range(30):\n",
1791 " f.write('hello world: ' + str(i) + '\\n')\n",
1792 "\n",
1793 "g = open('newfile.txt', 'r')\n",
1794 "print (g.read())\n",
1795 "f.close()\n",
1796 "g.close()\n"
1797 ]
1798 },
1799 {
1800 "cell_type": "code",
1801 "execution_count": null,
1802 "metadata": {},
1803 "outputs": [],
1804 "source": [
1805 "import os\n",
1806 "os.remove('newfile.txt')\n"
1807 ]
1808 },
1809 {
1810 "cell_type": "markdown",
1811 "metadata": {},
1812 "source": [
1813 "出现异常时候的读写:"
1814 ]
1815 },
1816 {
1817 "cell_type": "code",
1818 "execution_count": null,
1819 "metadata": {},
1820 "outputs": [],
1821 "source": [
1822 "f = open('newfile.txt','w')\n",
1823 "for i in range(30):\n",
1824 " x = 1.0 / (i - 10)\n",
1825 " f.write('hello world: ' + str(i) + '\\n')\n"
1826 ]
1827 },
1828 {
1829 "cell_type": "markdown",
1830 "metadata": {},
1831 "source": [
1832 "查看已有内容:"
1833 ]
1834 },
1835 {
1836 "cell_type": "code",
1837 "execution_count": null,
1838 "metadata": {},
1839 "outputs": [],
1840 "source": [
1841 "g = open('newfile.txt', 'r')\n",
1842 "print (g.read())\n",
1843 "f.close()\n",
1844 "g.close()\n"
1845 ]
1846 },
1847 {
1848 "cell_type": "markdown",
1849 "metadata": {},
1850 "source": [
"可以看到,出现异常的时候,磁盘的写入并没有完成。为此我们可以使用 `try/except/finally` 块来关闭文件:`finally` 保证无论是否发生异常,文件都会被关闭,从而完成所有写入。"
1852 ]
1853 },
1854 {
1855 "cell_type": "code",
1856 "execution_count": null,
1857 "metadata": {},
1858 "outputs": [],
1859 "source": [
1860 "f = open('newfile.txt','w')\n",
1861 "try:\n",
1862 " for i in range(30):\n",
1863 " x = 1.0 / (i - 10)\n",
1864 " f.write('hello world: ' + str(i) + '\\n')\n",
1865 "except Exception:\n",
1866 " print(\"something bad happened\")\n",
1867 "finally:\n",
1868 " f.close()\n"
1869 ]
1870 },
1871 {
1872 "cell_type": "code",
1873 "execution_count": null,
1874 "metadata": {},
1875 "outputs": [],
1876 "source": [
1877 "g = open('newfile.txt', 'r')\n",
1878 "print(g.read())\n",
1879 "g.close()\n"
1880 ]
1881 },
1882 {
1883 "cell_type": "markdown",
1884 "metadata": {},
1885 "source": [
1886 "### with 方法\n",
1887 "\n",
1888 "事实上,**Python** 提供了更安全的方法,当 `with` 块的内容结束后,**Python** 会自动调用它的`close` 方法,确保读写的安全:"
1889 ]
1890 },
1891 {
1892 "cell_type": "code",
1893 "execution_count": null,
1894 "metadata": {},
1895 "outputs": [],
1896 "source": [
1897 "with open('newfile.txt','w') as f:\n",
1898 " for i in range(30):\n",
1899 " x = 1.0 / (i - 10)\n",
1900 " f.write('hello world: ' + str(i) + '\\n')\n"
1901 ]
1902 },
1903 {
1904 "cell_type": "markdown",
1905 "metadata": {},
1906 "source": [
"与 `try/except/finally` 效果相同,但更简单。"
1908 ]
1909 },
1910 {
1911 "cell_type": "code",
1912 "execution_count": null,
1913 "metadata": {},
1914 "outputs": [],
1915 "source": [
1916 "g = open('newfile.txt', 'r')\n",
1917 "print(g.read())\n",
1918 "g.close()\n"
1919 ]
1920 },
1921 {
1922 "cell_type": "markdown",
1923 "metadata": {},
1924 "source": [
"所以,写文件的时候要确保文件被正确关闭。\n",
1926 "\n",
1927 "删除刚才创建的文件:"
1928 ]
1929 },
1930 {
1931 "cell_type": "code",
1932 "execution_count": null,
1933 "metadata": {},
1934 "outputs": [],
1935 "source": [
1936 "import os\n",
1937 "os.remove('newfile.txt')\n"
1938 ]
1939 },
1940 {
1941 "cell_type": "markdown",
1942 "metadata": {},
1943 "source": [
1944 "## 2.7 CSV 文件和 csv 模块\n",
1945 "\n",
1946 "标准库中有自带的 `csv` 模块处理 `csv` 格式的文件:"
1947 ]
1948 },
1949 {
1950 "cell_type": "code",
1951 "execution_count": null,
1952 "metadata": {},
1953 "outputs": [],
1954 "source": [
1955 "import csv"
1956 ]
1957 },
1958 {
1959 "cell_type": "markdown",
1960 "metadata": {},
1961 "source": [
1962 "### 读 csv 文件\n",
1963 "\n",
1964 "假设我们有这样的一个文件:"
1965 ]
1966 },
1967 {
1968 "cell_type": "code",
1969 "execution_count": null,
1970 "metadata": {},
1971 "outputs": [],
1972 "source": [
"%%writefile data.csv\n",
1974 "\"alpha 1\", 100, -1.443\n",
"\"beta 3\", 12, -0.0934\n",
1976 "\"gamma 3a\", 192, -0.6621\n",
1977 "\"delta 2a\", 15, -4.515\n"
1978 ]
1979 },
1980 {
1981 "cell_type": "markdown",
1982 "metadata": {},
1983 "source": [
1984 "打开这个文件,并产生一个文件 reader:"
1985 ]
1986 },
1987 {
1988 "cell_type": "code",
1989 "execution_count": null,
1990 "metadata": {},
1991 "outputs": [],
1992 "source": [
1993 "# 打开 data.csv 文件\n",
1994 "fp = open(\"data.csv\")\n",
1995 "\n",
"# 产生文件 reader\n",
1997 "r = csv.reader(fp)\n",
1998 "\n",
1999 "# 可以按行迭代数据\n",
2000 "for row in r:\n",
2001 " print (row)\n",
2002 "\n",
2003 "# 关闭文件\n",
2004 "fp.close()\n"
2005 ]
2006 },
2007 {
2008 "cell_type": "markdown",
2009 "metadata": {},
2010 "source": [
2011 "默认数据内容都被当作字符串处理,不过可以自己进行处理:"
2012 ]
2013 },
2014 {
2015 "cell_type": "code",
2016 "execution_count": null,
2017 "metadata": {},
2018 "outputs": [],
2019 "source": [
2020 "data = []\n",
2021 "\n",
2022 "with open('data.csv') as fp:\n",
2023 " r = csv.reader(fp)\n",
2024 " for row in r:\n",
2025 " data.append([row[0], int(row[1]), float(row[2])])\n",
2026 "\n",
2027 "data\n"
2028 ]
2029 },
2030 {
2031 "cell_type": "markdown",
2032 "metadata": {},
2033 "source": [
2034 "清除刚刚创建的文件:"
2035 ]
2036 },
2037 {
2038 "cell_type": "code",
2039 "execution_count": null,
2040 "metadata": {},
2041 "outputs": [],
2042 "source": [
2043 "import os\n",
2044 "os.remove('data.csv')"
2045 ]
2046 },
2047 {
2048 "cell_type": "markdown",
2049 "metadata": {},
2050 "source": [
2051 "### 写 csv 文件"
2052 ]
2053 },
2054 {
2055 "cell_type": "markdown",
2056 "metadata": {},
2057 "source": [
"可以使用 `csv.writer` 写入文件,相应地,传入的应该是以写方式打开的文件。在 **Python** 3 中打开文件时要指定 `newline=''`,防止写入多余的空行(**Python** 2 中则使用二进制模式 `'wb'`):"
2059 ]
2060 },
2061 {
2062 "cell_type": "code",
2063 "execution_count": null,
2064 "metadata": {},
2065 "outputs": [],
2066 "source": [
2067 "data = [('one', 1, 1.5), ('two', 2, 8.0)]\n",
"with open('out.csv', 'w', newline='') as fp:\n",
2069 " w = csv.writer(fp)\n",
2070 " w.writerows(data)\n"
2071 ]
2072 },
2073 {
2074 "cell_type": "markdown",
2075 "metadata": {},
2076 "source": [
2077 "显示结果:"
2078 ]
2079 },
2080 {
2081 "cell_type": "code",
2082 "execution_count": null,
2083 "metadata": {},
2084 "outputs": [],
2085 "source": [
2086 "! cat 'out.csv'"
2087 ]
2088 }
2089 ],
2090 "metadata": {
2091 "kernelspec": {
2092 "display_name": "Python 3",
2093 "language": "python",
2094 "name": "python3"
2095 },
2096 "language_info": {
2097 "codemirror_mode": {
2098 "name": "ipython",
2099 "version": 3
2100 },
2101 "file_extension": ".py",
2102 "mimetype": "text/x-python",
2103 "name": "python",
2104 "nbconvert_exporter": "python",
2105 "pygments_lexer": "ipython3",
2106 "version": "3.7.4"
2107 }
2108 },
2109 "nbformat": 4,
2110 "nbformat_minor": 2
2111 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 3. 机器学习常用的包"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "## 3.1 `NumPy`"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "<img src=\"http://imgbed.momodel.cn/1200px_NumPy_logo.svg.png\" width=300>\n"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "`NumPy(Numerical Python)`是一个开源的 **Python** 科学计算库,用于快速处理任意维度的数组。\n",
28 "\n",
29 "`NumPy` 支持常见的数组和矩阵操作。\n",
30 "\n",
"对于同样的数值计算任务,使用 `NumPy` 比直接使用 **Python** 要简洁得多。\n",
32 "\n",
33 "`NumPy` 使用 `ndarray` 对象来处理多维数组,该对象是一个快速而灵活的大数据容器。\n"
34 ]
35 },
36 {
37 "cell_type": "markdown",
38 "metadata": {},
39 "source": [
40 "### `ndarray` 介绍"
41 ]
42 },
43 {
44 "cell_type": "markdown",
45 "metadata": {},
46 "source": [
47 "`NumPy` 提供了一个`N` 维数组类型 `ndarray`,它描述了**相同类型**的 `items` 的集合。\n",
48 " \n",
49 "|语文|数学|英语|政治|体育|\n",
50 "|--|--|--|--|--|\n",
51 "|80|89|86|67|79|\n",
52 "|78|97|89|76|81|\n",
53 "\n",
54 "用 `ndarray` 进行存储:"
55 ]
56 },
57 {
58 "cell_type": "code",
59 "execution_count": null,
60 "metadata": {},
61 "outputs": [],
62 "source": [
63 "import numpy as np\n",
64 "\n",
65 "# 创建ndarray\n",
"score = np.array([[80, 89, 86, 67, 79],[78, 97, 89, 76, 81]])\n",
67 "\n",
68 "# 打印结果\n",
69 "score\n"
70 ]
71 },
72 {
73 "cell_type": "markdown",
74 "metadata": {},
75 "source": [
76 "### `ndarray` 的属性 \n",
77 "数组属性反映了数组本身固有的信息。\n",
78 "\n",
79 "|属性名字|\t属性解释|\n",
80 "|--|--|\n",
81 "|ndarray.shape|\t数组维度的元组|\n",
82 "|ndarray.ndim|\t数组维数|\n",
83 "|ndarray.size|\t数组中的元素数量|\n",
84 "|ndarray.itemsize|\t一个数组元素的长度(字节)|\n",
85 "|ndarray.dtype|\t数组元素的类型|\n",
86 "\n",
87 "\n",
88 "\n"
89 ]
90 },
91 {
92 "cell_type": "markdown",
93 "metadata": {},
94 "source": [
95 "+ `shape`:数组形状"
96 ]
97 },
98 {
99 "cell_type": "code",
100 "execution_count": null,
101 "metadata": {},
102 "outputs": [],
103 "source": [
104 "import numpy as np\n",
105 "\n",
"# 创建不同形状的数组\n",
108 "a = np.array([[1,2,3],[4,5,6]])\n",
109 "b = np.array([1,2,3,4])\n",
110 "c = np.array([\n",
111 " [\n",
112 " [1,2,3],[4,5,6]\n",
113 " ],\n",
114 " [\n",
115 " [1,2,3],[4,5,6]\n",
116 " ]\n",
117 "])\n",
118 "\n",
119 "# 分别打印出形状\n",
120 "print(a.shape)\n",
121 "print(b.shape)\n",
122 "print(c.shape)\n"
123 ]
124 },
125 {
126 "cell_type": "markdown",
127 "metadata": {},
128 "source": [
129 "+ `ndim`:数组维数"
130 ]
131 },
132 {
133 "cell_type": "code",
134 "execution_count": null,
135 "metadata": {},
136 "outputs": [],
137 "source": [
138 "import numpy as np\n",
139 "\n",
140 "# 创建不同形状的数组\n",
141 "a = np.array([[1,2,3],[4,5,6]])\n",
142 "b = np.array([1,2,3,4])\n",
143 "c = np.array([[[1,2,3],[4,5,6]], [[1,2,3],[4,5,6]]])\n",
144 "\n",
145 "# 分别打印出维数\n",
146 "print(a.ndim)\n",
147 "print(b.ndim)\n",
148 "print(c.ndim)\n"
149 ]
150 },
151 {
152 "cell_type": "markdown",
153 "metadata": {},
154 "source": [
155 "+ `size`:数组元素数量"
156 ]
157 },
158 {
159 "cell_type": "code",
160 "execution_count": null,
161 "metadata": {},
162 "outputs": [],
163 "source": [
164 "import numpy as np\n",
165 "\n",
166 "# 创建不同形状的数组\n",
167 "a = np.array([[1,2,3],[4,5,6]])\n",
168 "b = np.array([1,2,3,4])\n",
169 "c = np.array([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])\n",
170 "\n",
171 "# 分别打印出数组元素数量\n",
172 "print(a.size)\n",
173 "print(b.size)\n",
174 "print(c.size)\n"
175 ]
176 },
177 {
178 "cell_type": "markdown",
179 "metadata": {},
180 "source": [
181 "+ `itemsize`:数组元素的长度"
182 ]
183 },
184 {
185 "cell_type": "code",
186 "execution_count": null,
187 "metadata": {},
188 "outputs": [],
189 "source": [
190 "import numpy as np\n",
191 "\n",
192 "# 创建不同形状的数组\n",
193 "a = np.array([[1,2,3],[4,5,6]])\n",
194 "b = np.array([1,2,3,4])\n",
195 "c = np.array([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,60]]])\n",
196 "\n",
"# 分别打印出每个数组元素的字节长度\n",
198 "print(a.itemsize)\n",
199 "print(b.itemsize)\n",
200 "print(c.itemsize)\n"
201 ]
202 },
203 {
204 "cell_type": "markdown",
205 "metadata": {},
206 "source": [
207 "+ `dtype`:数组元素的类型"
208 ]
209 },
210 {
211 "cell_type": "code",
212 "execution_count": null,
213 "metadata": {},
214 "outputs": [],
215 "source": [
216 "import numpy as np\n",
217 "\n",
218 "# 创建不同形状的数组\n",
219 "a = np.array([[1,2,3],[4,5,6]])\n",
220 "b = np.array([1,2,3,4])\n",
221 "c = np.array([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6.0]]])\n",
222 "\n",
"# 分别打印出数组元素的类型\n",
224 "print(a.dtype)\n",
225 "print(b.dtype)\n",
226 "print(c.dtype)\n"
227 ]
228 },
229 {
230 "cell_type": "markdown",
231 "metadata": {},
232 "source": [
233 "### `ndarray` 的类型\n",
234 "\n",
235 "|名称|\t描述|\t简写|\n",
236 "|--|--|--|\n",
237 "|np.bool|\t用一个字节存储的布尔类型(True或False)|\t'b'|\n",
"|np.int8|\t整数,-128 至 127|\t'i1'|\n",
"|np.int16|\t整数,-32768 至 32767|\t'i2'|\n",
"|np.int32|\t整数,$-2^{31}$ 至 $2^{31} - 1$\t|'i4'|\n",
"|np.int64|\t整数,$-2^{63}$ 至 $2^{63} - 1$\t|'i8'|\n",
"|np.uint8|\t无符号整数,0 至 255|\t'u1'|\n",
243 "|np.uint16\t|无符号整数,0 至 65535|\t'u2'|\n",
244 "|np.uint32|\t无符号整数,0 至 $2^{32} - 1$\t|'u4'|\n",
245 "|np.uint64|\t无符号整数,0 至 $2^{64} - 1$ |'u8'|\n",
246 "|np.float16\t|半精度浮点数:16位,正负号1位,指数5位,精度10位\t|'f2'|\n",
247 "|np.float32\t|单精度浮点数:32位,正负号1位,指数8位,精度23位\t|'f4'|\n",
248 "|np.float64\t|双精度浮点数:64位,正负号1位,指数11位,精度52位\t|'f8'|\n",
249 "|np.complex64\t|复数,分别用两个32位浮点数表示实部和虚部\t|'c8'|\n",
250 "|np.complex128\t|复数,分别用两个64位浮点数表示实部和虚部\t|'c16'|\n",
251 "|np.object_\t|python对象\t|'O'|\n",
252 "|np.string_\t|字符串\t|'S'|\n",
253 "|np.unicode_\t|unicode类型\t|'U'|\n",
254 "\n",
255 "**注意:创建数组的时候指定类型**"
256 ]
257 },
258 {
259 "cell_type": "code",
260 "execution_count": null,
261 "metadata": {},
262 "outputs": [],
263 "source": [
264 "import numpy as np\n",
265 "\n",
266 "# 创建数组时指定类型为 np.float32\n",
267 "a = np.array([[1, 2, 3],[4, 5, 6]], dtype=np.float32)\n",
268 "\n",
269 "# 创建数组时未指定类型\n",
270 "b = np.array([[1, 2, 3],[4, 5, 6]])\n",
271 "\n",
272 "# 打印结果\n",
273 "print(\"数组a:\\n%s,\\n数据类型:%s\"%(a,a.dtype))\n",
274 "print(\"数组b:\\n%s,\\n数据类型:%s\"%(b,b.dtype))\n"
275 ]
276 },
277 {
278 "cell_type": "markdown",
279 "metadata": {},
280 "source": [
281 "### 基本操作"
282 ]
283 },
284 {
285 "cell_type": "markdown",
286 "metadata": {},
287 "source": [
"#### 生成 `0` 和 `1` 数组的常见方法\n",
289 "\n",
290 "+ 生成 `0` 的数组"
291 ]
292 },
293 {
294 "cell_type": "code",
295 "execution_count": null,
296 "metadata": {},
297 "outputs": [],
298 "source": [
299 "import numpy as np\n",
300 "\n",
301 "zero = np.zeros([3, 4])\n",
302 "zero\n"
303 ]
304 },
305 {
306 "cell_type": "markdown",
307 "metadata": {},
308 "source": [
309 "+ 生成 `1` 的数组"
310 ]
311 },
312 {
313 "cell_type": "code",
314 "execution_count": null,
315 "metadata": {},
316 "outputs": [],
317 "source": [
318 "one = np.ones([3,4])\n",
319 "one\n"
320 ]
321 },
322 {
323 "cell_type": "markdown",
324 "metadata": {},
325 "source": [
326 "+ 生成对角数组(对角线的地方是 `1`,其余地方是 `0`)"
327 ]
328 },
329 {
330 "cell_type": "code",
331 "execution_count": null,
332 "metadata": {},
333 "outputs": [],
334 "source": [
335 "eyes = np.eye(10,5)\n",
336 "eyes\n"
337 ]
338 },
339 {
340 "cell_type": "markdown",
341 "metadata": {},
342 "source": [
343 "+ 创建方阵对角矩阵"
344 ]
345 },
346 {
347 "cell_type": "code",
348 "execution_count": null,
349 "metadata": {},
350 "outputs": [],
351 "source": [
"# np.eye(N) 只传入一个参数时,生成 N×N 的方阵\n",
353 "eyes1 = np.eye(5)\n",
354 "eyes1\n"
355 ]
356 },
357 {
358 "cell_type": "markdown",
359 "metadata": {},
360 "source": [
361 "#### 从现有数组生成"
362 ]
363 },
364 {
365 "cell_type": "code",
366 "execution_count": null,
367 "metadata": {},
368 "outputs": [],
369 "source": [
370 "a = [[1,2,3],[4,5,6]]\n",
371 "\n",
372 "# 从现有的数组当中创建\n",
373 "a1 = np.array(a)\n",
374 "a\n"
375 ]
376 },
377 {
378 "cell_type": "code",
379 "execution_count": null,
380 "metadata": {},
381 "outputs": [],
382 "source": [
383 "a1\n"
384 ]
385 },
386 {
387 "cell_type": "markdown",
388 "metadata": {},
389 "source": [
390 "#### 生成固定范围的数组"
391 ]
392 },
393 {
394 "cell_type": "code",
395 "execution_count": null,
396 "metadata": {},
397 "outputs": [],
398 "source": [
"# linspace(起点, 终点, 元素个数):生成等间隔数组,包含终点\n",
400 "a = np.linspace(0, 90, 10)\n",
401 "a\n"
402 ]
403 },
404 {
405 "cell_type": "code",
406 "execution_count": null,
407 "metadata": {},
408 "outputs": [],
409 "source": [
"# arange(起点, 终点, 步长):生成等差数组,不包含终点\n",
411 "b = np.arange(0, 90, 10)\n",
412 "b\n"
413 ]
414 },
415 {
416 "cell_type": "markdown",
417 "metadata": {},
418 "source": [
419 "#### 形状修改"
420 ]
421 },
422 {
423 "cell_type": "code",
424 "execution_count": null,
425 "metadata": {},
426 "outputs": [],
427 "source": [
428 "from numpy import array\n",
429 "a = array([[ 0, 1, 2, 3, 4, 5],\n",
430 " [10,11,12,13,14,15],\n",
431 " [20,21,22,23,24,25],\n",
432 " [30,31,32,33,34,35]])\n",
433 "a.shape\n"
434 ]
435 },
436 {
437 "cell_type": "code",
438 "execution_count": null,
439 "metadata": {},
440 "outputs": [],
441 "source": [
442 "# 在转换形状的时候,一定要注意数组的元素匹配\n",
443 "# 只是将形状进行了修改,但并没有将行列进行转换\n",
444 "b = a.reshape([3,8])\n",
445 "b\n"
446 ]
447 },
448 {
449 "cell_type": "code",
450 "execution_count": null,
451 "metadata": {},
452 "outputs": [],
453 "source": [
"# 数组的形状被修改为 (2, 12),-1 表示该维度的大小由元素总数自动计算\n",
455 "c = a.reshape([-1,12])\n",
456 "c\n"
457 ]
458 },
459 {
460 "cell_type": "code",
461 "execution_count": null,
462 "metadata": {},
463 "outputs": [],
464 "source": [
465 "d = a.T\n",
466 "d.shape\n"
467 ]
468 },
469 {
470 "cell_type": "markdown",
471 "metadata": {},
472 "source": [
473 "#### 类型修改"
474 ]
475 },
476 {
477 "cell_type": "code",
478 "execution_count": null,
479 "metadata": {},
480 "outputs": [],
481 "source": [
482 "arr = np.array([[[1, 2, 3], [4, 5, 6]], [[12, 3, 34], [5, 6, 7]]])\n",
483 "arr.dtype\n"
484 ]
485 },
486 {
487 "cell_type": "code",
488 "execution_count": null,
489 "metadata": {},
490 "outputs": [],
491 "source": [
492 "arr.astype(np.float32)\n"
493 ]
494 },
495 {
496 "cell_type": "markdown",
497 "metadata": {},
498 "source": [
499 "#### 数组去重"
500 ]
501 },
502 {
503 "cell_type": "code",
504 "execution_count": null,
505 "metadata": {},
506 "outputs": [],
507 "source": [
508 "arr = np.array([[1, 2, 3, 4],[3, 4, 5, 6]])\n",
509 "np.unique(arr)\n"
510 ]
511 },
512 {
513 "cell_type": "markdown",
514 "metadata": {},
515 "source": [
516 "### 数组运算\n",
517 "\n",
518 "数组的算术运算是元素级别的操作,新的数组被创建并且被结果填充。"
519 ]
520 },
521 {
522 "cell_type": "markdown",
523 "metadata": {},
524 "source": [
525 "运算|函数\n",
526 "--- | --- \n",
527 "`a + b` | `add(a,b)`\n",
528 "`a - b` | `subtract(a,b)`\n",
529 "`a * b` | `multiply(a,b)`\n",
530 "`a / b` | `divide(a,b)`\n",
531 "`a ** b` | `power(a,b)`\n",
532 "`a % b` | `remainder(a,b)`\n",
533 "\n",
534 "以乘法为例,数组与标量相乘,相当于数组的每个元素乘以这个标量:"
535 ]
536 },
537 {
538 "cell_type": "code",
539 "execution_count": null,
540 "metadata": {},
541 "outputs": [],
542 "source": [
543 "import numpy as np\n",
544 "a = np.array([1,2,3,4])\n",
545 "a * 3\n"
546 ]
547 },
548 {
549 "cell_type": "markdown",
550 "metadata": {},
551 "source": [
552 "数组逐元素相乘:"
553 ]
554 },
555 {
556 "cell_type": "code",
557 "execution_count": null,
558 "metadata": {},
559 "outputs": [],
560 "source": [
561 "a = np.array([1,2])\n",
562 "b = np.array([3,4])\n",
563 "a * b\n"
564 ]
565 },
566 {
567 "cell_type": "markdown",
568 "metadata": {},
569 "source": [
"使用对应的函数可以得到相同的结果:"
571 ]
572 },
573 {
574 "cell_type": "code",
575 "execution_count": null,
576 "metadata": {},
577 "outputs": [],
578 "source": [
579 "np.multiply(a, b)\n"
580 ]
581 },
582 {
583 "cell_type": "markdown",
584 "metadata": {},
585 "source": [
586 "函数还可以接受第三个参数,表示将结果存入第三个参数中:"
587 ]
588 },
589 {
590 "cell_type": "code",
591 "execution_count": null,
592 "metadata": {},
593 "outputs": [],
594 "source": [
595 "np.multiply(a, b, a)\n"
596 ]
597 },
598 {
599 "cell_type": "code",
600 "execution_count": null,
601 "metadata": {},
602 "outputs": [],
603 "source": [
604 "a\n"
605 ]
606 },
607 {
608 "cell_type": "markdown",
609 "metadata": {},
610 "source": [
611 "### 矩阵 \n",
612 "使用 `mat` 方法将 `2` 维数组转化为矩阵:"
613 ]
614 },
615 {
616 "cell_type": "code",
617 "execution_count": null,
618 "metadata": {},
619 "outputs": [],
620 "source": [
621 "import numpy as np\n",
622 "a = np.array([[1,2,4],\n",
623 " [2,5,3],\n",
624 " [7,8,9]])\n",
625 "A = np.mat(a)\n",
626 "A\n"
627 ]
628 },
629 {
630 "cell_type": "code",
631 "execution_count": null,
632 "metadata": {},
633 "outputs": [],
634 "source": [
635 "# 也可以使用 **Matlab** 的语法传入一个字符串来生成矩阵:\n",
636 "A = np.mat('1,2,4;2,5,3;7,8,9')\n",
637 "A\n"
638 ]
639 },
640 {
641 "cell_type": "markdown",
642 "metadata": {},
643 "source": [
644 "矩阵与向量的乘法:"
645 ]
646 },
647 {
648 "cell_type": "code",
649 "execution_count": null,
650 "metadata": {},
651 "outputs": [],
652 "source": [
653 "x = np.array([[1], [2], [3]])\n",
654 "x\n"
655 ]
656 },
657 {
658 "cell_type": "code",
659 "execution_count": null,
660 "metadata": {},
661 "outputs": [],
662 "source": [
663 "A*x\n"
664 ]
665 },
666 {
667 "cell_type": "code",
668 "execution_count": null,
669 "metadata": {},
670 "outputs": [],
671 "source": [
672 "b = np.array([[1,2],\n",
673 " [3,4],\n",
674 " [5,6]])\n",
675 "B = np.mat(b)\n",
676 "A*B\n"
677 ]
678 },
679 {
680 "cell_type": "markdown",
681 "metadata": {},
682 "source": [
683 "`A.I` 表示 `A` 矩阵的逆矩阵:"
684 ]
685 },
686 {
687 "cell_type": "code",
688 "execution_count": null,
689 "metadata": {},
690 "outputs": [],
691 "source": [
692 "A.I\n"
693 ]
694 },
695 {
696 "cell_type": "markdown",
697 "metadata": {},
698 "source": [
699 "矩阵指数表示矩阵连乘:"
700 ]
701 },
702 {
703 "cell_type": "code",
704 "execution_count": null,
705 "metadata": {},
706 "outputs": [],
707 "source": [
708 "A ** 4\n"
709 ]
710 },
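{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "补充说明:对普通的 `ndarray`,也可以直接使用 `@` 运算符或 `np.dot` 函数做矩阵乘法,不必先转换为 `mat`。下面是一个简单的演示(沿用上面矩阵 `A` 对应的数组):"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import numpy as np\n",
  "\n",
  "a = np.array([[1,2,4],\n",
  "              [2,5,3],\n",
  "              [7,8,9]])\n",
  "x = np.array([[1], [2], [3]])\n",
  "\n",
  "# @ 对二维数组执行矩阵乘法,等价于 np.dot\n",
  "print(a @ x)\n",
  "print(np.dot(a, x))\n"
 ]
},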
711 {
712 "cell_type": "markdown",
713 "metadata": {},
714 "source": [
715 "### 统计函数"
716 ]
717 },
718 {
719 "cell_type": "markdown",
720 "metadata": {},
721 "source": [
722 "|方法|作用|\n",
723 "|--|--|\n",
724 "|`a.sum(axis=None)`|求和|\n",
725 "|`a.prod(axis=None)`|求积|\n",
726 "|`a.min(axis=None)`|最小值|\n",
727 "|`a.max(axis=None)`|最大值|\n",
728 "|`a.argmin(axis=None)`|最小值索引|\n",
729 "|`a.argmax(axis=None)`|最大值索引|\n",
730 "|`a.ptp(axis=None)`|最大值减最小值|\n",
731 "|`a.mean(axis=None)`|平均值|\n",
732 "|`a.std(axis=None)`|标准差|\n",
733 "|`a.var(axis=None)`|方差|"
734 ]
735 },
736 {
737 "cell_type": "code",
738 "execution_count": null,
739 "metadata": {
740 "code_folding": []
741 },
742 "outputs": [],
743 "source": [
744 "from numpy import array\n",
745 "a = array([[1,2,3],\n",
746 " [4,5,6]])\n",
747 "a\n"
748 ]
749 },
750 {
751 "cell_type": "markdown",
752 "metadata": {},
753 "source": [
754 "求所有元素的和:"
755 ]
756 },
757 {
758 "cell_type": "code",
759 "execution_count": null,
760 "metadata": {},
761 "outputs": [],
762 "source": [
"np.sum(a)\n"
764 ]
765 },
766 {
767 "cell_type": "code",
768 "execution_count": null,
769 "metadata": {},
770 "outputs": [],
771 "source": [
772 "a.sum()\n"
773 ]
774 },
775 {
776 "cell_type": "markdown",
777 "metadata": {},
778 "source": [
779 "**指定求和的维度**:\n",
780 "沿着第一维求和"
781 ]
782 },
783 {
784 "cell_type": "code",
785 "execution_count": null,
786 "metadata": {},
787 "outputs": [],
788 "source": [
789 "np.sum(a, axis=0)\n"
790 ]
791 },
792 {
793 "cell_type": "code",
794 "execution_count": null,
795 "metadata": {},
796 "outputs": [],
797 "source": [
798 "a.sum(axis=0)\n"
799 ]
800 },
801 {
802 "cell_type": "markdown",
803 "metadata": {},
804 "source": [
805 "沿着第二维求和:"
806 ]
807 },
808 {
809 "cell_type": "code",
810 "execution_count": null,
811 "metadata": {},
812 "outputs": [],
813 "source": [
814 "np.sum(a, axis=1)\n"
815 ]
816 },
817 {
818 "cell_type": "code",
819 "execution_count": null,
820 "metadata": {},
821 "outputs": [],
822 "source": [
823 "a.sum(axis=1)\n"
824 ]
825 },
826 {
827 "cell_type": "markdown",
828 "metadata": {},
829 "source": [
830 "沿着最后一维求和:"
831 ]
832 },
833 {
834 "cell_type": "code",
835 "execution_count": null,
836 "metadata": {},
837 "outputs": [],
838 "source": [
839 "np.sum(a, axis=-1)\n"
840 ]
841 },
842 {
843 "cell_type": "code",
844 "execution_count": null,
845 "metadata": {},
846 "outputs": [],
847 "source": [
848 "a.sum(axis=-1)\n"
849 ]
850 },
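{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
  "表中的其余统计方法用法与 `sum` 类似,同样可以通过 `axis` 指定维度,下面给出一个简单的演示:"
 ]
},
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "import numpy as np\n",
  "\n",
  "a = np.array([[1,2,3],\n",
  "              [4,5,6]])\n",
  "\n",
  "print(a.mean())          # 所有元素的平均值:3.5\n",
  "print(a.mean(axis=0))    # 沿第一维求平均:[2.5 3.5 4.5]\n",
  "print(a.argmax())        # 最大值在展平后数组中的索引:5\n",
  "print(a.ptp(axis=1))     # 每行最大值减最小值:[2 2]\n"
 ]
},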
851 {
852 "cell_type": "markdown",
853 "metadata": {},
854 "source": [
855 "### 比较和逻辑函数"
856 ]
857 },
858 {
859 "cell_type": "markdown",
860 "metadata": {},
861 "source": [
862 "运算|函数|\n",
863 "--- | --- \n",
864 "`==` | `equal`\n",
865 "`!=` | `not_equal`\n",
866 "`>` | `greater`\n",
867 "`>=` | `greater_equal`\n",
868 "`<` | `less`\n",
869 "`<=` | `less_equal`\n",
870 "`^` | `bitwise_xor`\n",
871 "`~` | `invert`\n",
872 "`>>` | `right_shift`\n",
873 "`<<` | `left_shift`\n",
874 "\n",
"对于数组元素,我们可以直接使用运算符进行逐元素比较,比如判断数组中元素是否大于某个数:"
876 ]
877 },
878 {
879 "cell_type": "code",
880 "execution_count": null,
881 "metadata": {},
882 "outputs": [],
883 "source": [
884 "from numpy import array\n",
885 "a = array([[ 0, 1, 2, 3, 4, 5],\n",
886 " [10,11,12,13,14,15],\n",
887 " [20,21,22,23,24,25],\n",
888 " [30,31,32,33,34,35]])\n",
889 "\n",
890 "a > 10\n"
891 ]
892 },
893 {
894 "cell_type": "code",
895 "execution_count": null,
896 "metadata": {},
897 "outputs": [],
898 "source": [
"# 将数组中大于 10 的元素赋值为 -10\n",
900 "a[a > 10] = -10\n",
901 "a\n"
902 ]
903 },
904 {
905 "cell_type": "markdown",
906 "metadata": {},
907 "source": [
"但是当数组元素较多时,逐个查看输出结果会很麻烦,这时我们可以使用 `np.all()` 函数,直接判断数组中所有对应的元素是否都满足条件。假如判断某个区间的值是否全都大于 `20`:"
909 ]
910 },
911 {
912 "cell_type": "code",
913 "execution_count": null,
914 "metadata": {},
915 "outputs": [],
916 "source": [
917 "from numpy import array\n",
918 "a = array([[ 0, 1, 2, 3, 4, 5],\n",
919 " [10,11,12,13,14,15],\n",
920 " [20,21,22,23,24,25],\n",
921 " [30,31,32,33,34,35]])\n",
922 "\n",
923 "a[1:3,1:3]\n"
924 ]
925 },
926 {
927 "cell_type": "code",
928 "execution_count": null,
929 "metadata": {},
930 "outputs": [],
931 "source": [
932 "np.all(a[1:4,1:3] > 20)\n"
933 ]
934 },
935 {
936 "cell_type": "markdown",
937 "metadata": {},
938 "source": [
939 "比如判断数组某个区间的元素是否存在大于 `20`:"
940 ]
941 },
942 {
943 "cell_type": "code",
944 "execution_count": null,
945 "metadata": {},
946 "outputs": [],
947 "source": [
948 "np.any(a[1:4,1:3] > 20)\n"
949 ]
950 },
951 {
952 "cell_type": "markdown",
953 "metadata": {},
954 "source": [
955 "### 广播机制"
956 ]
957 },
958 {
959 "cell_type": "markdown",
960 "metadata": {},
961 "source": [
962 "正常的加法:"
963 ]
964 },
965 {
966 "cell_type": "code",
967 "execution_count": null,
968 "metadata": {},
969 "outputs": [],
970 "source": [
971 "a = np.array([[ 0, 0, 0],\n",
972 " [10,10,10],\n",
973 " [20,20,20],\n",
974 " [30,30,30]])\n",
975 "b = np.array([[ 0, 1, 2],\n",
976 " [ 0, 1, 2],\n",
977 " [ 0, 1, 2],\n",
978 " [ 0, 1, 2]])\n",
979 "a + b\n"
980 ]
981 },
982 {
983 "cell_type": "markdown",
984 "metadata": {},
985 "source": [
986 "将 `b` 的值变成二维的 `[0,1,2]` 之后的加法:"
987 ]
988 },
989 {
990 "cell_type": "code",
991 "execution_count": null,
992 "metadata": {},
993 "outputs": [],
994 "source": [
995 "b = np.array([[0,1,2],[0,1,2]])\n",
996 "print(b.shape)\n",
997 "a + b\n"
998 ]
999 },
1000 {
1001 "cell_type": "markdown",
1002 "metadata": {},
1003 "source": [
"两个 `ndarray` 执行的是对应元素的运算,广播机制的功能是为了方便不同形状的 `ndarray`(`NumPy` 库的核心数据结构)进行数学运算。\n",
"\n",
"当操作两个数组时,`NumPy` 会从最后一维开始逐个比较它们的 `shape`,对应维度满足以下任一条件时,两个数组才能进行数组与数组的运算:\n",
"\n",
"+ 对应维度相等\n",
"+ 对应维度中有一个为 `1`(该维度会被广播扩展)"
1010 ]
1011 },
1012 {
1013 "cell_type": "code",
1014 "execution_count": null,
1015 "metadata": {},
1016 "outputs": [],
1017 "source": [
1018 "# 将 `b` 的值变成一维的 `[0,1,2]` 之后的加法:\n",
1019 "b = np.array([0,1,2])\n",
1020 "\n",
1021 "a + b\n"
1022 ]
1023 },
1024 {
1025 "cell_type": "markdown",
1026 "metadata": {},
1027 "source": [
"结果一样:虽然两个数组的维数不一样,但是 `NumPy` 检测到 `b` 的形状与 `a` 的最后一维匹配,于是自动将 `b` 沿第一维扩展成相同的形状。\n",
1029 "\n",
1030 "对于更高维度,这样的扩展依然有效。 \n",
1031 "\n",
1032 "如果我们再将 `a` 变成一个列向量呢?"
1033 ]
1034 },
1035 {
1036 "cell_type": "code",
1037 "execution_count": null,
1038 "metadata": {},
1039 "outputs": [],
1040 "source": [
1041 "a = np.array([0,10,20,30])\n",
1042 "a.shape = 4,1\n",
1043 "a\n"
1044 ]
1045 },
1046 {
1047 "cell_type": "code",
1048 "execution_count": null,
1049 "metadata": {},
1050 "outputs": [],
1051 "source": [
1052 "a + b\n"
1053 ]
1054 },
1055 {
1056 "cell_type": "code",
1057 "execution_count": null,
1058 "metadata": {},
1059 "outputs": [],
1060 "source": [
1061 "# 此时a,b的维度分别为\n",
1062 "a.shape, b.shape\n"
1063 ]
1064 },
1065 {
1066 "cell_type": "markdown",
1067 "metadata": {},
1068 "source": [
1069 "可以看到,虽然两者的维度并不相同,但是 `Numpy` 还是根据两者的维度,自动将它们进行扩展然后进行计算。\n",
1070 "\n",
"匹配从最后一维开始逐维进行,对应维度相等或其中一个为 `1` 即可匹配,维度为 `1` 的轴会被扩展。因此对于以下情况,`NumPy` 都会进行相应的匹配:\n",
"\n",
"A|B|Result\n",
"---|---|---\n",
"3d array: 256 x 256 x 3 | 1d array: 3 | 3d array: 256 x 256 x 3\n",
"4d array: 8 x 1 x 6 x 1 | 3d array: 7 x 1 x 5 | 4d array: 8 x 7 x 6 x 5\n",
"3d array: 5 x 4 x 3 | 1d array: 1 | 3d array: 5 x 4 x 3\n",
"3d array: 15 x 4 x 13 | 3d array: 15 x 1 x 13 | 3d array: 15 x 4 x 13\n",
1079 "2d array: 4 x 1 | 1d array: 3 | 2d array: 4 x 3\n",
1080 "\n",
1081 "匹配成功后,`Numpy` 会进行运算得到相应的结果。\n",
1082 "\n",
1083 "当然,如果相应的维度不匹配,那么 `Numpy` 会报错:`ValueError`。"
1084 ]
1085 },
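  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面用一个小演示说明这种报错(演示用,假设前面已执行 `import numpy as np`):形状 `(4,)` 与 `(3,)` 的最后一维既不相等、也没有一方为 `1`,因此无法广播:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 广播失败的演示:形状 (4,) 与 (3,) 无法匹配,抛出 ValueError\n",
    "try:\n",
    "    np.array([0, 10, 20, 30]) + np.array([0, 1, 2])\n",
    "except ValueError as e:\n",
    "    print('ValueError:', e)\n"
   ]
  },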
1086 {
1087 "cell_type": "markdown",
1088 "metadata": {},
1089 "source": [
1090 "### `IO` 操作"
1091 ]
1092 },
1093 {
1094 "cell_type": "markdown",
1095 "metadata": {},
1096 "source": [
1097     "`NumPy` 及相关库可以读写多种格式的文件,如下表所示:\n",
1098 "\n",
1099 "文件格式|使用的包|函数\n",
1100 "----|----|----\n",
1101 "txt | numpy | loadtxt, genfromtxt, fromfile, savetxt, tofile\n",
1102 "csv | csv | reader, writer\n",
1103 "Matlab | scipy.io | loadmat, savemat\n",
1104 "hdf | pytables, h5py| \n",
1105 "NetCDF | netCDF4, scipy.io.netcdf | netCDF4.Dataset, scipy.io.netcdf.netcdf_file\n",
1106     "\n",
    "文件格式|使用的包|备注\n",
    "----|----|----\n",
1107 "wav | scipy.io.wavfile | 音频文件\n",
1108 "jpeg,png,...| PIL, scipy.misc.pilutil | 图像文件\n",
1109 "fits | pyfits | 天文图像\n"
1110 ]
1111 },
1112 {
1113 "cell_type": "markdown",
1114 "metadata": {},
1115 "source": [
1116 "`savetxt` 可以将数组写入文件,默认使用科学计数法的形式保存:"
1117 ]
1118 },
1119 {
1120 "cell_type": "code",
1121 "execution_count": null,
1122 "metadata": {},
1123 "outputs": [],
1124 "source": [
1125 "import numpy as np\n",
1126 "\n",
1127 "data = np.array([[1,2],\n",
1128 " [3,4]])\n",
1129 "\n",
1130 "# 保存文件\n",
1131 "np.savetxt('out.txt', data)\n"
1132 ]
1133 },
1134 {
1135 "cell_type": "code",
1136 "execution_count": null,
1137 "metadata": {},
1138 "outputs": [],
1139 "source": [
1140 "# 读取文件\n",
1141 "with open('out.txt') as f:\n",
1142 " for line in f:\n",
1143 " print(line)\n"
1144 ]
1145 },
1146 {
1147 "cell_type": "code",
1148 "execution_count": null,
1149 "metadata": {},
1150 "outputs": [],
1151 "source": [
1152 "# 读取文件\n",
1153 "np.loadtxt('out.txt')\n"
1154 ]
1155 },
1156 {
1157 "cell_type": "markdown",
1158 "metadata": {},
1159 "source": [
1160     "## 3.2 `Pandas`"
1161 ]
1162 },
1163 {
1164 "cell_type": "markdown",
1165 "metadata": {},
1166 "source": [
1167 "<img src=\"https://pandas.pydata.org/_static/pandas_logo.png\" width=300/>\n",
1168 "\n",
1169     "+ `Pandas` 是基于 `NumPy` 的一种工具,是为了解决数据分析任务而创建的\n",
1170     "+ `Pandas` 纳入了大量库和一些标准的数据模型,提供了高效操作大型数据集所需的工具\n",
1171     "+ `Pandas` 提供了大量能使我们快速便捷地处理数据的函数与方法\n",
1172     "+ `Pandas` 是 **Python** 成为强大而高效的数据分析环境的重要因素之一\n"
1173 ]
1174 },
1175 {
1176 "cell_type": "code",
1177 "execution_count": null,
1178 "metadata": {},
1179 "outputs": [],
1180 "source": [
1181 "import pandas as pd\n",
1182 "import numpy as np"
1183 ]
1184 },
1185 {
1186 "cell_type": "markdown",
1187 "metadata": {},
1188 "source": [
1189 "### 产生 `Pandas` 对象"
1190 ]
1191 },
1192 {
1193 "cell_type": "markdown",
1194 "metadata": {},
1195 "source": [
1196 "`pandas` 主要有两种基本的数据结构:\n",
1197 "\n",
1198 "- `Series`\n",
1199 " - `Series` 是带索引的一维数组,可存储整数、浮点数、字符串、**Python** 对象等类型的数据。\n",
1200 "- `DataFrame`\n",
1201 " - `DataFrame` 是由多种类型的列构成的二维标签数据结构,类似于 `Excel` 、`SQL` 表,或 `Series` 对象构成的字典。`DataFrame` 是最常用的 `Pandas` 对象。\n"
1202 ]
1203 },
1204 {
1205 "cell_type": "code",
1206 "execution_count": null,
1207 "metadata": {},
1208 "outputs": [],
1209 "source": [
1210 "# 生成 series\n",
1211 "s = pd.Series([1,3,5,np.nan,6,8])\n",
1212 "\n",
1213 "print(s)"
1214 ]
1215 },
1216 {
1217 "cell_type": "code",
1218 "execution_count": null,
1219 "metadata": {},
1220 "outputs": [],
1221 "source": [
1222 "# 生成 dataframe \n",
1223 "dates = pd.date_range('20200101', periods=15)\n",
1224 "\n",
1225 "df = pd.DataFrame(np.random.randn(15,4), index=dates, columns=list('ABCD'))\n",
1226 "\n",
1227 "df"
1228 ]
1229 },
1230 {
1231 "cell_type": "markdown",
1232 "metadata": {},
1233 "source": [
1234     "默认情况下,如果不指定 `index` 和 `columns` 参数,那么它们的值将用从 `0` 开始的整数替代。"
1235 ]
1236 },
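  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面用一个小例子演示默认的整数标签(演示用,数据为随机数):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 不指定 index 和 columns 时,行、列标签默认为从 0 开始的整数\n",
    "pd.DataFrame(np.random.randn(3, 2))\n"
   ]
  },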
1237 {
1238 "cell_type": "markdown",
1239 "metadata": {},
1240 "source": [
1241 "写入 `csv` 文件:"
1242 ]
1243 },
1244 {
1245 "cell_type": "code",
1246 "execution_count": null,
1247 "metadata": {},
1248 "outputs": [],
1249 "source": [
1250 "df.to_csv('foo.csv')"
1251 ]
1252 },
1253 {
1254 "cell_type": "markdown",
1255 "metadata": {},
1256 "source": [
1257 "读取 `csv` 文件:"
1258 ]
1259 },
1260 {
1261 "cell_type": "code",
1262 "execution_count": null,
1263 "metadata": {},
1264 "outputs": [],
1265 "source": [
1266 "df1 = pd.read_csv('foo.csv',index_col=0)\n",
1267 "df1.head()"
1268 ]
1269 },
1270 {
1271 "cell_type": "markdown",
1272 "metadata": {},
1273 "source": [
1274 "`head` 和 `tail` 方法可以分别查看最前面几行和最后面几行的数据(默认为 `5`):"
1275 ]
1276 },
1277 {
1278 "cell_type": "code",
1279 "execution_count": null,
1280 "metadata": {},
1281 "outputs": [],
1282 "source": [
1283 "df1.tail(10)"
1284 ]
1285 },
1286 {
1287 "cell_type": "markdown",
1288 "metadata": {},
1289 "source": [
1290     "了解更多 `Pandas` 内容,可以参考:https://pandas.pydata.org/pandas-docs/stable/getting_started/index.html"
1291 ]
1292 },
1293 {
1294 "cell_type": "markdown",
1295 "metadata": {},
1296 "source": [
1297 "## 3.3 `Matplotlib`"
1298 ]
1299 },
1300 {
1301 "cell_type": "markdown",
1302 "metadata": {},
1303 "source": [
1304 "<img src=\"https://matplotlib.org/_static/logo2.svg\" width=300/>"
1305 ]
1306 },
1307 {
1308 "cell_type": "markdown",
1309 "metadata": {},
1310 "source": [
1311 "简单来说,`Matplotlib` 是 **Python** 的一个绘图库。它包含了大量的工具,你可以使用这些工具创建各种图形,包括简单的散点图,正弦曲线,甚至是三维图形。\n",
1312 "\n",
1313 "**Python** 科学计算社区经常使用它完成数据可视化的工作。"
1314 ]
1315 },
1316 {
1317 "cell_type": "code",
1318 "execution_count": null,
1319 "metadata": {},
1320 "outputs": [],
1321 "source": [
1322 "%matplotlib inline\n",
1323 "\n",
1324 "import matplotlib.pyplot as plt\n",
1325 "import numpy as np"
1326 ]
1327 },
1328 {
1329 "cell_type": "markdown",
1330 "metadata": {},
1331 "source": [
1332 "### 画一个简单的图形"
1333 ]
1334 },
1335 {
1336 "cell_type": "code",
1337 "execution_count": null,
1338 "metadata": {},
1339 "outputs": [],
1340 "source": [
1341 "# 简单的绘图\n",
1342 "x = np.linspace(0, 2 * np.pi, 50)\n",
1343 "\n",
1344 "# 如果没有第一个参数 x,图形的 x 坐标默认为数组的索引\n",
1345 "plt.plot(x, np.sin(x)) \n",
1346 "\n",
1347 "# 显示图形\n",
1348 "plt.show() \n"
1349 ]
1350 },
1351 {
1352 "cell_type": "markdown",
1353 "metadata": {},
1354 "source": [
1355 "### 在一张图上绘制两条曲线"
1356 ]
1357 },
1358 {
1359 "cell_type": "code",
1360 "execution_count": null,
1361 "metadata": {},
1362 "outputs": [],
1363 "source": [
1364 "x = np.linspace(0, 2 * np.pi, 50)\n",
1365 "plt.plot(x, np.sin(x),\n",
1366 " x, np.cos(x))\n",
1367 "plt.show()"
1368 ]
1369 },
1370 {
1371 "cell_type": "markdown",
1372 "metadata": {},
1373 "source": [
1374 "### 自定义曲线的外观"
1375 ]
1376 },
1377 {
1378 "cell_type": "code",
1379 "execution_count": null,
1380 "metadata": {},
1381 "outputs": [],
1382 "source": [
1383 "x = np.linspace(0, 2 * np.pi, 50)\n",
1384 "plt.plot(x, np.sin(x), 'r-^',\n",
1385 " x, np.cos(x), 'g--')\n",
1386 "plt.show()\n"
1387 ]
1388 },
1389 {
1390 "cell_type": "markdown",
1391 "metadata": {},
1392 "source": [
1393 "- **颜色**: \n",
1394 " - 蓝色 - 'b' \n",
1395 " - 绿色 - 'g' \n",
1396 " - 红色 - 'r' \n",
1397 " - 青色 - 'c' \n",
1398 " - 品红 - 'm' \n",
1399 " - 黄色 - 'y' \n",
1400 " - 黑色 - 'k'('b'代表蓝色,所以这里用黑色的最后一个字母) \n",
1401 " - 白色 - 'w'\n",
1402 "\n",
1403 "- 线: \n",
1404 " - 直线 - '-' \n",
1405 " - 虚线 - '--' \n",
1406 " - 点线 - ':' \n",
1407 " - 点划线 - '-.'\n",
1408 "\n",
1409 "- 常用点标记:\n",
1410 " - 点 - '.' \n",
1411 " - 像素 - ',' \n",
1412 " - 圆 - 'o' \n",
1413 " - 方形 - 's' \n",
1414 " - 三角形 - '^' \n",
1415 " \n",
1416 "可以在[这里](http://matplotlib.org/api/markers_api.html)查看更多的样式"
1417 ]
1418 },
1419 {
1420 "cell_type": "markdown",
1421 "metadata": {},
1422 "source": [
1423 "### 使用子图"
1424 ]
1425 },
1426 {
1427 "cell_type": "markdown",
1428 "metadata": {},
1429 "source": [
1430 "使用子图可以在一个窗口绘制多张图。在调用 `plot()` 函数之前需要先调用 `subplot()` 函数。该函数的第一个参数代表子图的总行数,第二个参数代表子图的总列数,第三个参数代表活跃区域。"
1431 ]
1432 },
1433 {
1434 "cell_type": "code",
1435 "execution_count": null,
1436 "metadata": {},
1437 "outputs": [],
1438 "source": [
1439 "x = np.linspace(0, 2 * np.pi, 50)\n",
1440 "plt.subplot(2, 1, 1) # (行,列,活跃区)\n",
1441 "plt.plot(x, np.sin(x), 'r')\n",
1442 "plt.subplot(2, 1, 2)\n",
1443 "plt.plot(x, np.cos(x), 'g')\n",
1444 "plt.show()\n"
1445 ]
1446 },
1447 {
1448 "cell_type": "markdown",
1449 "metadata": {},
1450 "source": [
1451 "### 散点图"
1452 ]
1453 },
1454 {
1455 "cell_type": "markdown",
1456 "metadata": {},
1457 "source": [
1458 "散点图是一堆离散点的集合。用 `Matplotlib` 画散点图也同样非常简单。只需要调用 `scatter()` 函数并传入两个分别代表 `x` 坐标和 `y` 坐标的数组即可。"
1459 ]
1460 },
1461 {
1462 "cell_type": "code",
1463 "execution_count": null,
1464 "metadata": {},
1465 "outputs": [],
1466 "source": [
1467 "# 简单的散点图\n",
1468 "x = np.linspace(0, 2 * np.pi, 50)\n",
1469 "y = np.sin(x)\n",
1470 "plt.scatter(x,y)\n",
1471 "plt.show()\n"
1472 ]
1473 },
1474 {
1475 "cell_type": "markdown",
1476 "metadata": {},
1477 "source": [
1478 "### 调整点的大小和颜色"
1479 ]
1480 },
1481 {
1482 "cell_type": "markdown",
1483 "metadata": {},
1484 "source": [
1485 "可以给每个点赋予不同的大小"
1486 ]
1487 },
1488 {
1489 "cell_type": "code",
1490 "execution_count": null,
1491 "metadata": {},
1492 "outputs": [],
1493 "source": [
1494 "x = np.random.rand(100)\n",
1495 "y = np.random.rand(100)\n",
1496 "size = np.random.rand(100) * 50\n",
1497 "plt.scatter(x, y, size)\n",
1498 "plt.show()\n"
1499 ]
1500 },
1501 {
1502 "cell_type": "markdown",
1503 "metadata": {},
1504 "source": [
1505 "也可以给每个点赋予不同颜色。"
1506 ]
1507 },
1508 {
1509 "cell_type": "code",
1510 "execution_count": null,
1511 "metadata": {},
1512 "outputs": [],
1513 "source": [
1514 "x = np.random.rand(100)\n",
1515 "y = np.random.rand(100)\n",
1516 "size = np.random.rand(100) * 50\n",
1517 "color = np.random.rand(100)\n",
1518 "plt.scatter(x, y, size, color)\n",
1519 "plt.colorbar()\n",
1520 "plt.show()\n"
1521 ]
1522 },
1523 {
1524 "cell_type": "markdown",
1525 "metadata": {},
1526 "source": [
1527 "### 直方图"
1528 ]
1529 },
1530 {
1531 "cell_type": "markdown",
1532 "metadata": {},
1533 "source": [
1534     "使用 `hist()` 函数可以非常方便地创建直方图。第二个参数代表分段(bin)的个数,分段越多,图形上的数据条就越多。"
1535 ]
1536 },
1537 {
1538 "cell_type": "code",
1539 "execution_count": null,
1540 "metadata": {},
1541 "outputs": [],
1542 "source": [
1543 "x = np.random.randn(1000)\n",
1544 "plt.hist(x, 50)\n",
1545 "plt.show()\n"
1546 ]
1547 },
1548 {
1549 "cell_type": "markdown",
1550 "metadata": {},
1551 "source": [
1552 "### 标题,标签和图例"
1553 ]
1554 },
1555 {
1556 "cell_type": "markdown",
1557 "metadata": {},
1558 "source": [
1559 "当需要快速创建图形时,你可能不需要为图形添加标签。但是当构建需要展示的图形时,你就需要添加标题,标签和图例。"
1560 ]
1561 },
1562 {
1563 "cell_type": "code",
1564 "execution_count": null,
1565 "metadata": {},
1566 "outputs": [],
1567 "source": [
1568 "x = np.linspace(0, 2 * np.pi, 50)\n",
1569 "plt.plot(x, np.sin(x), 'r-x', label='Sin(x)')\n",
1570 "plt.plot(x, np.cos(x), 'g-^', label='Cos(x)')\n",
1571 "\n",
1572 "# 展示图例\n",
1573 "plt.legend()\n",
1574 "\n",
1575 "# 给 x 轴添加标签\n",
1576 "plt.xlabel('Rads')\n",
1577 "\n",
1578 "# 给 y 轴添加标签\n",
1579 "plt.ylabel('Amplitude')\n",
1580 "\n",
1581 "# 添加图形标题\n",
1582 "plt.title('Sin and Cos Waves')\n",
1583 "\n",
1584 "plt.show()\n"
1585 ]
1586 },
1587 {
1588 "cell_type": "markdown",
1589 "metadata": {},
1590 "source": [
1591 "### 图片保存"
1592 ]
1593 },
1594 {
1595 "cell_type": "code",
1596 "execution_count": null,
1597 "metadata": {},
1598 "outputs": [],
1599 "source": [
1600 "fruits = ['apple', 'orange', 'pear']\n",
1601 "sales = [100,250,300]\n",
1602 "plt.pie(sales, labels=fruits)\n",
1603 "plt.savefig('pie.jpg')\n",
1604 "plt.show()\n"
1605 ]
1606 },
1607 {
1608 "cell_type": "markdown",
1609 "metadata": {},
1610 "source": [
1611 "可以在这里查看更多的[图例](https://matplotlib.org/gallery.html)。"
1612 ]
1613 },
1614 {
1615 "cell_type": "markdown",
1616 "metadata": {},
1617 "source": [
1618 "### Seaborn"
1619 ]
1620 },
1621 {
1622 "cell_type": "markdown",
1623 "metadata": {},
1624 "source": [
1625     "`Seaborn` 基于 `matplotlib`,可以快速地绘制一些统计图表。"
1626 ]
1627 },
1628 {
1629 "cell_type": "code",
1630 "execution_count": null,
1631 "metadata": {},
1632 "outputs": [],
1633 "source": [
1634 "import seaborn as sns\n",
1635 "import pandas as pd\n",
1636 "sns.set()\n",
1637 "iris = pd.read_csv(\"iris.csv\")\n",
1638 "sns.jointplot(x=\"sepal_length\", y=\"petal_length\", data=iris)"
1639 ]
1640 },
1641 {
1642 "cell_type": "code",
1643 "execution_count": null,
1644 "metadata": {},
1645 "outputs": [],
1646 "source": [
1647 "sns.pairplot(data=iris, hue=\"species\")\n"
1648 ]
1649 },
1650 {
1651 "cell_type": "markdown",
1652 "metadata": {},
1653 "source": [
1654 "可以在这里查看更多的[示例](https://seaborn.pydata.org/tutorial.html)。"
1655 ]
1656 },
1657 {
1658 "cell_type": "markdown",
1659 "metadata": {},
1660 "source": [
1661 "## 3.4 `Scikit-learn`"
1662 ]
1663 },
1664 {
1665 "cell_type": "markdown",
1666 "metadata": {},
1667 "source": [
1668 "<img src=\"http://imgbed.momodel.cn/scikitlearn.png\" width=300 />\n",
1669 "\n",
1670 "+ **Python** 语言的机器学习工具\n",
1671 "+ `Scikit-learn` 包括许多知名的机器学习算法的实现\n",
1672     "+ `Scikit-learn` 文档完善,容易上手,`API` 丰富"
1673 ]
1674 },
1675 {
1676 "cell_type": "markdown",
1677 "metadata": {},
1678 "source": [
1679 "### 机器学习算法"
1680 ]
1681 },
1682 {
1683 "cell_type": "markdown",
1684 "metadata": {},
1685 "source": [
1686 "**机器学习算法是一类从数据中自动分析获得规律,并利用规律对未知数据进行预测的算法**。\n"
1687 ]
1688 },
1689 {
1690 "cell_type": "markdown",
1691 "metadata": {},
1692 "source": [
1693 "<img src=\"http://imgbed.momodel.cn/q2nay75zew.png\" width=800>\n",
1694 "\n",
1695     "从图中可以看到,机器学习库 `sklearn` 的算法主要有四类:分类、回归、聚类、降维。其中:\n",
1696 "\n",
1697 "+ 常用的回归:线性、决策树、`SVM`、`KNN` ; \n",
1698 " 集成回归:随机森林、`Adaboost`、`GradientBoosting`、`Bagging`、`ExtraTrees` \n",
1699 "+ 常用的分类:线性、决策树、`SVM`、`KNN`,朴素贝叶斯; \n",
1700 " 集成分类:随机森林、`Adaboost`、`GradientBoosting`、`Bagging`、`ExtraTrees` \n",
1701 "+ 常用聚类:`k` 均值(`K-means`)、层次聚类(`Hierarchical clustering`)、`DBSCAN` \n",
1702 "+ 常用降维:`LinearDiscriminantAnalysis`、`PCA`   \n",
1703 "\n",
1704     "这个流程图中:蓝色圆圈是判断条件,绿色方框是可供选择的算法,我们可以根据自己的数据特征和任务目标找到一条合适的操作路线。"
1705 ]
1706 },
1707 {
1708 "cell_type": "markdown",
1709 "metadata": {},
1710 "source": [
1711 "### `sklearn` 数据集"
1712 ]
1713 },
1714 {
1715 "cell_type": "markdown",
1716 "metadata": {},
1717 "source": [
1718 "+ `sklearn.datasets.load_*()`\n",
1719 " + 获取小规模数据集,数据包含在 `datasets` 里\n",
1720 "+ `sklearn.datasets.fetch_*(data_home=None)`\n",
1721 " + 获取大规模数据集,需要从网络上下载,函数的第一个参数是 `data_home`,表示数据集下载的目录,默认是 `/scikit_learn_data/`\n",
1722 " \n",
1723 "`sklearn` 常见的数据集如下:\n",
1724 "\n",
1725 "||数据集名称|调用方式|适用算法|数据规模|\n",
1726 "|--|--|--|--|--|\n",
1727 "|小数据集|波士顿房价|load_boston()|回归|506\\*13|\n",
1728 "|小数据集|鸢尾花数据集|load_iris()|分类|150\\*4|\n",
1729 "|小数据集|糖尿病数据集|\tload_diabetes()|\t回归\t|442\\*10|\n",
1730     "|小数据集|手写数字数据集|\tload_digits()|\t分类|\t1797\\*64|\n",
1731     "|大数据集|Olivetti脸部图像数据集|\tfetch_olivetti_faces()|\t降维|\t400\\*64\\*64|\n",
1732 "|大数据集|新闻分类数据集|\tfetch_20newsgroups()|\t分类|-|\t \n",
1733 "|大数据集|带标签的人脸数据集|\tfetch_lfw_people()|\t分类、降维|-|\t \n",
1734 "|大数据集|路透社新闻语料数据集|\tfetch_rcv1()|\t分类|\t804414\\*47236|"
1735 ]
1736 },
1737 {
1738 "cell_type": "code",
1739 "execution_count": null,
1740 "metadata": {},
1741 "outputs": [],
1742 "source": [
1743 "from sklearn.datasets import load_iris\n",
1744 "# 获取鸢尾花数据集\n",
1745 "iris = load_iris()\n",
1746 "print(\"鸢尾花数据集的返回值:\\n\", iris.keys())"
1747 ]
1748 },
1749 {
1750 "cell_type": "markdown",
1751 "metadata": {},
1752 "source": [
1753 "### 数据预处理\n",
1754 "\n",
1755 "通过**一些转换函数**将特征数据转换成**更加适合算法模型**的特征数据过程。常见的有数据标准化、数据二值化、标签编码、独热编码等。"
1756 ]
1757 },
1758 {
1759 "cell_type": "code",
1760 "execution_count": null,
1761 "metadata": {},
1762 "outputs": [],
1763 "source": [
1764 "# 导入内建数据集\n",
1765 "from sklearn.datasets import load_iris\n",
1766 "\n",
1767 "# 获取鸢尾花数据集\n",
1768 "iris = load_iris()\n",
1769 "\n",
1770 "# 获得ndarray格式的变量X和标签y\n",
1771 "X = iris.data\n",
1772 "y = iris.target\n",
1773 "\n",
1774 "# 获得数据维度\n",
1775 "n_samples, n_features = iris.data.shape\n",
1776 "\n",
1777 "print(n_samples, n_features)"
1778 ]
1779 },
1780 {
1781 "cell_type": "markdown",
1782 "metadata": {},
1783 "source": [
1784 "#### 数据标准化\n",
1785 "\n",
1786 "数据标准化和归一化是将数据映射到一个小的浮点数范围内,以便模型能快速收敛。\n",
1787 "\n",
1788     "标准化有多种方式,常用的一种是 min-max 标准化(对象名为 `MinMaxScaler`),该方法使数据落到 `[0,1]` 区间:\n",
1789 "\n",
1790 "$x^{'}=\\frac{x-x_{min}}{x_{max} - x_{min}}$"
1791 ]
1792 },
1793 {
1794 "cell_type": "code",
1795 "execution_count": null,
1796 "metadata": {},
1797 "outputs": [],
1798 "source": [
1799 "# min-max标准化\n",
1800 "from sklearn.preprocessing import MinMaxScaler\n",
1801 "\n",
1802 "sc = MinMaxScaler()\n",
1803 "sc.fit(X)\n",
1804 "results = sc.transform(X)\n",
1805 "print(\"放缩前:\",X[1])\n",
1806 "print(\"放缩后:\",results[1])\n"
1807 ]
1808 },
1809 {
1810 "cell_type": "markdown",
1811 "metadata": {},
1812 "source": [
1813     "另一种是 Z-score 标准化(对象名为 `StandardScaler`),该方法将数据变换为均值为 `0`、标准差为 `1` 的分布:\n",
1814 "\n",
1815 "$x^{'}=\\frac{x-\\overline {X}}{S}$"
1816 ]
1817 },
1818 {
1819 "cell_type": "code",
1820 "execution_count": null,
1821 "metadata": {},
1822 "outputs": [],
1823 "source": [
1824 "# Z-score标准化\n",
1825 "from sklearn.preprocessing import StandardScaler\n",
1826 "\n",
1827 "#将fit和transform组合执行\n",
1828 "results = StandardScaler().fit_transform(X) \n",
1829 "\n",
1830 "print(\"放缩前:\",X[1])\n",
1831 "print(\"放缩后:\",results[1])"
1832 ]
1833 },
1834 {
1835 "cell_type": "markdown",
1836 "metadata": {},
1837 "source": [
1838     "归一化(对象名为 `Normalizer`,默认为 `L2` 范数归一化)将每个样本缩放到单位范数:\n",
1839 "\n",
1840 "$x^{'}=\\frac{x}{\\sqrt{\\sum_{j}^{m}x_{j}^2}}$"
1841 ]
1842 },
1843 {
1844 "cell_type": "code",
1845 "execution_count": null,
1846 "metadata": {},
1847 "outputs": [],
1848 "source": [
1849 "# 归一化\n",
1850 "from sklearn.preprocessing import Normalizer\n",
1851 "\n",
1852 "results = Normalizer().fit_transform(X)\n",
1853 "\n",
1854 "print(\"放缩前:\",X[1])\n",
1855 "print(\"放缩后:\",results[1])"
1856 ]
1857 },
1858 {
1859 "cell_type": "markdown",
1860 "metadata": {},
1861 "source": [
1862 "#### 数据二值化\n",
1863 "\n",
1864     "使用阈值过滤器将数值特征转化为 `0/1`,即为二值化。使用 `Binarizer` 对象实现数据的二值化:"
1865 ]
1866 },
1867 {
1868 "cell_type": "code",
1869 "execution_count": null,
1870 "metadata": {},
1871 "outputs": [],
1872 "source": [
1873 "# 二值化,阈值设置为3\n",
1874 "from sklearn.preprocessing import Binarizer\n",
1875 "\n",
1876 "results = Binarizer(threshold=3).fit_transform(X)\n",
1877 "\n",
1878 "print(\"处理前:\",X[1])\n",
1879 "print(\"处理后:\",results[1])"
1880 ]
1881 },
1882 {
1883 "cell_type": "markdown",
1884 "metadata": {},
1885 "source": [
1886 "#### 标签编码\n",
1887 "\n",
1888 "使用 LabelEncoder 将不连续的数值或文本变量转化为有序的数值型变量:\n"
1889 ]
1890 },
1891 {
1892 "cell_type": "code",
1893 "execution_count": null,
1894 "metadata": {},
1895 "outputs": [],
1896 "source": [
1897 "# 标签编码\n",
1898 "from sklearn.preprocessing import LabelEncoder\n",
1899 "LabelEncoder().fit_transform(['apple','pear','orange','banana'])"
1900 ]
1901 },
1902 {
1903 "cell_type": "markdown",
1904 "metadata": {},
1905 "source": [
1906 "#### 独热编码\n",
1907 "\n",
1908     "对于无序的离散型特征,其数值大小并没有意义,需要对其进行 one-hot 编码,将该特征的 `m` 个可能取值转化为 `m` 个二值特征。可以利用 `OneHotEncoder` 对象实现:"
1909 ]
1910 },
1911 {
1912 "cell_type": "code",
1913 "execution_count": null,
1914 "metadata": {},
1915 "outputs": [],
1916 "source": [
1917 "# 独热编码\n",
1918 "from sklearn.preprocessing import OneHotEncoder\n",
1919 "\n",
1920 "results = OneHotEncoder().fit_transform(y.reshape(-1,1)).toarray()\n",
1921 "\n",
1922     "print(\"处理前:\",y[1])\n",
1923 "print(\"处理后:\",results[1])\n"
1924 ]
1925 },
1926 {
1927 "cell_type": "markdown",
1928 "metadata": {},
1929 "source": [
1930 "### 数据集的划分\n",
1931 "\n",
1932 "机器学习一般的数据集会划分为两个部分:\n",
1933 "+ 训练数据:用于训练,构建模型\n",
1934 "+ 测试数据:在模型检验时使用,用于评估模型是否有效\n",
1935 "\n",
1936 "<br>\n",
1937 "\n",
1938 "划分比例:\n",
1939 "+ 训练集:70% 80% 75%\n",
1940     "+ 测试集:30% 20% 25%\n",
1941 "\n",
1942 "<br>\n",
1943     "`sklearn.model_selection.train_test_split(arrays, *options)`\n",
1944     "\n",
    " + `x`:数据集的特征值\n",
    " + `y`:数据集的标签值\n",
    " + `test_size`:如果是浮点数,表示测试集样本占比;如果是整数,表示测试集样本的数量。\n",
    " + `random_state`:随机数种子,不同的种子会产生不同的随机采样结果;相同的种子采样结果相同。\n",
    " + `return`:依次返回训练集特征值 `x_train`、测试集特征值 `x_test`、训练集目标值 `y_train`、测试集目标值 `y_test`。\n"
1949 ]
1950 },
1951 {
1952 "cell_type": "code",
1953 "execution_count": null,
1954 "metadata": {},
1955 "outputs": [],
1956 "source": [
1957 "from sklearn.datasets import load_iris\n",
1958 "from sklearn.model_selection import train_test_split\n",
1959 "\n",
1960 "# 加载数据集\n",
1961 "iris = load_iris()\n",
1962 "\n",
1963 "# 对数据集进行分割\n",
1964 "# 训练集的特征值x_train 测试集的特征值x_test 训练集的目标值y_train 测试集的目标值y_test\n",
1965 "X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,test_size=0.3, random_state=22)\n",
1966 "\n",
1967 "print(\"x_train:\", X_train.shape)\n",
1968 "print(\"y_train:\", y_train.shape)\n",
1969 "print(\"x_test:\", X_test.shape)\n",
1970 "print(\"y_test:\", y_test.shape)\n"
1971 ]
1972 },
1973 {
1974 "cell_type": "markdown",
1975 "metadata": {},
1976 "source": [
1977 "### 定义模型"
1978 ]
1979 },
1980 {
1981 "cell_type": "markdown",
1982 "metadata": {},
1983 "source": [
1984 "#### 估计器(`Estimator`)\n",
1985 "估计器,很多时候可以直接理解成分类器,主要包含两个函数:\n",
1986 "\n",
1987 "+ `fit()`:训练算法,设置内部参数。接收训练集和类别两个参数。\n",
1988 "+ `predict()`:预测测试集类别,参数为测试集。\n",
1989 "\n",
1990 "大多数 `scikit-learn` 估计器接收和输出的数据格式均为 `NumPy`数组或类似格式。\n",
1991 "\n",
1992 "<br>\n",
1993 "\n",
1994 "#### 转换器(`Transformer`) \n",
1995 "转换器用于数据预处理和数据转换,主要是三个方法:\n",
1996 "\n",
1997 "+ `fit()`:训练算法,设置内部参数。\n",
1998 "+ `transform()`:数据转换。\n",
1999 "+ `fit_transform()`:合并 `fit` 和 `transform` 两个方法。\n",
2000 "\n",
2001 "<br>"
2002 ]
2003 },
2004 {
2005 "cell_type": "markdown",
2006 "metadata": {},
2007 "source": [
2008 "在 `scikit-learn` 中,所有模型都有同样的接口供调用。监督学习模型都具有以下的方法:\n",
2009 "+ `fit`:对数据进行拟合。\n",
2010 "+ `set_params`:设定模型参数。\n",
2011 "+ `get_params`:返回模型参数。\n",
2012 "+ `predict`:在指定的数据集上预测。\n",
2013 "+ `score`:返回预测器的得分。\n",
2014 "\n",
2015     "鸢尾花数据集对应一个分类任务,故以决策树模型为例,采用默认参数拟合模型,并对测试集进行预测。"
2016 ]
2017 },
2018 {
2019 "cell_type": "code",
2020 "execution_count": null,
2021 "metadata": {},
2022 "outputs": [],
2023 "source": [
2024 "# 决策树分类器\n",
2025 "from sklearn.tree import DecisionTreeClassifier\n",
2026 "\n",
2027 "# 定义模型\n",
2028 "model = DecisionTreeClassifier()\n",
2029 "\n",
2030 "# 训练模型\n",
2031 "model.fit(X_train, y_train)\n",
2032 "\n",
2033 "# 在测试集上预测\n",
2034 "model.predict(X_test)\n",
2035 "\n",
2036 "# 测试集上的得分(默认为准确率)\n",
2037 "model.score(X_test, y_test)\n"
2038 ]
2039 },
2040 {
2041 "cell_type": "markdown",
2042 "metadata": {},
2043 "source": [
2044 "`scikit-learn` 中所有模型的调用方式都类似。"
2045 ]
2046 },
2047 {
2048 "cell_type": "markdown",
2049 "metadata": {},
2050 "source": [
2051 "### 模型评估\n",
2052 "\n",
2053 "评估模型的常用方法为 `K` 折交叉验证,它将数据集划分为 `K` 个大小相近的子集(`K` 通常取 `10`),每次选择其中(`K-1`)个子集的并集做为训练集,余下的做为测试集,总共得到 `K` 组训练集&测试集,最终返回这 `K` 次测试结果的得分,取其均值可作为选定最终模型的指标。"
2054 ]
2055 },
2056 {
2057 "cell_type": "code",
2058 "execution_count": null,
2059 "metadata": {},
2060 "outputs": [],
2061 "source": [
2062 "# 交叉验证\n",
2063 "from sklearn.model_selection import cross_val_score\n",
2064 "cross_val_score(model, X, y, scoring=None, cv=10)\n"
2065 ]
2066 },
2067 {
2068 "cell_type": "markdown",
2069 "metadata": {},
2070 "source": [
2071 "注意:由于之前采用了 `train_test_split` 分割数据集,它默认对数据进行了洗牌,所以这里可以直接使用 `cv=10` 来进行 `10` 折交叉验证(`cross_val_score` 不会对数据进行洗牌)。如果之前未对数据进行洗牌,则要搭配使用 `KFold` 模块:"
2072 ]
2073 },
2074 {
2075 "cell_type": "code",
2076 "execution_count": null,
2077 "metadata": {},
2078 "outputs": [],
2079 "source": [
2080 "from sklearn.model_selection import KFold\n",
2081 "n_folds = 10\n",
2082     "kf = KFold(n_splits=n_folds, shuffle=True)\n",
2083     "cross_val_score(model, X, y, scoring=None, cv=kf)\n"
2084 ]
2085 },
2086 {
2087 "cell_type": "markdown",
2088 "metadata": {},
2089 "source": [
2090 "### 保存与加载模型\n",
2091 "\n",
2092 "在训练模型后可将模型保存,以免下次重复训练。保存与加载模型使用 `sklearn` 的 `joblib`:"
2093 ]
2094 },
2095 {
2096 "cell_type": "code",
2097 "execution_count": null,
2098 "metadata": {},
2099 "outputs": [],
2100 "source": [
2101     "from sklearn.externals import joblib  # 较新版本的 sklearn 已移除该接口,可改用 import joblib\n",
2102 "\n",
2103 "# 保存模型\n",
2104 "joblib.dump(model,'myModel.pkl')\n",
2105 "\n",
2106 "# 加载模型\n",
2107 "model=joblib.load('myModel.pkl')\n",
2108 "print(model)\n"
2109 ]
2110 },
2111 {
2112 "cell_type": "markdown",
2113 "metadata": {},
2114 "source": [
2115 "下面我们用一个小例子来展示如何使用 `sklearn` 工具包快速完成一个机器学习项目。"
2116 ]
2117 },
2118 {
2119 "cell_type": "markdown",
2120 "metadata": {},
2121 "source": [
2122 "### 采用逻辑回归模型实现鸢尾花分类\n",
2123 "\n",
2124 "\n",
2125 "**线性回归**\n",
2126 "\n",
2127 "在介绍逻辑回归之前先介绍一下线性回归,线性回归的主要思想是通过历史数据拟合出一条直线,因变量与自变量是线性关系,对新的数据用这条直线进行预测。 线性回归的公式如下:\n",
2128 "\n",
2129     "$y = w_{1}x_{1}+...+w_{n}x_{n}+b=w^{T}x+b$\n",
2130 "\n",
2131 "**逻辑回归**\n",
2132 "\n",
2133     "逻辑回归是一种广义的线性回归分析模型,是一种预测分析方法。虽然它名字里带“回归”,但实际上是一种分类学习方法。它不仅能预测出“类别”,还可以得到近似概率预测,这对许多需要利用概率辅助决策的任务很有用,例如预测一封 `email` 是垃圾邮件的概率。因变量可以是二分类的,也可以是多分类的。由于输出的是概率,除了分类外还可以做 `ranking model`。逻辑回归的应用场景很多,如点击率预测(`CTR`)、天气预测、一些电商的购物搭配推荐、一些电商的搜索排序基线等。\n",
2134 "\n",
2135 "`sigmoid` **函数**\n",
2136 "\n",
2137     "`Sigmoid` 函数呈 S 型曲线,它将任意实数 `z` 映射到 `(0,1)` 区间,输出值可以解释为概率。 \n",
2138 "$y = g(z)=\\frac{1}{1+e^{-z}}$ 其中:$z = w^{T}x+b$ \n",
2139 "\n",
2140 "\n",
2141 "**鸢尾花数据集**\n",
2142 "\n",
2143 "`sklearn.datasets.load_iris()`:加载并返回鸢尾花数据集\n",
2144 "\n",
2145     "`Iris` 鸢尾花卉数据集,是常用的分类实验数据集,由 `R.A. Fisher` 于 `1936` 年收集整理。其中包含 `3` 种植物种类,分别是山鸢尾(`setosa`)、变色鸢尾(`versicolor`)和维吉尼亚鸢尾(`virginica`),每类 `50` 个样本,共 `150` 个样本。 \n",
2146 "\n",
2147 "|变量名|\t变量解释|\t数据类型|\n",
2148 "|--|--|--|\n",
2149 "|sepal_length|\t花萼长度(单位cm)|\tnumeric|\n",
2150 "|sepal_width|\t花萼宽度(单位cm)|\tnumeric|\n",
2151 "|petal_length\t|花瓣长度(单位cm)|\tnumeric|\n",
2152 "|petal_width|\t花瓣宽度(单位cm)|\tnumeric|\n",
2153 "|species\t|种类\t|categorical|"
2154 ]
2155 },
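  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面的 `sigmoid` 公式可以用几行代码实现(演示用,`sigmoid` 为自定义函数):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sigmoid 函数的一个简单实现(演示用)\n",
    "import numpy as np\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1 / (1 + np.exp(-z))\n",
    "\n",
    "# z 很小时输出接近 0,z=0 时输出 0.5,z 很大时输出接近 1\n",
    "print(sigmoid(np.array([-10.0, 0.0, 10.0])))\n"
   ]
  },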
2156 {
2157 "cell_type": "markdown",
2158 "metadata": {},
2159 "source": [
2160 "#### 1.获取数据集及其信息"
2161 ]
2162 },
2163 {
2164 "cell_type": "code",
2165 "execution_count": null,
2166 "metadata": {},
2167 "outputs": [],
2168 "source": [
2169 "from sklearn.datasets import load_iris\n",
2170 "# 获取鸢尾花数据集\n",
2171 "iris = load_iris()\n",
2172 "print(\"鸢尾花数据集的返回值:\\n\", iris.keys())"
2173 ]
2174 },
2175 {
2176 "cell_type": "code",
2177 "execution_count": null,
2178 "metadata": {},
2179 "outputs": [],
2180 "source": [
2181 "print(\"鸢尾花的特征值:\\n\", iris[\"data\"][1])\n",
2182 "print(\"鸢尾花的目标值:\\n\", iris.target)\n",
2183 "print(\"鸢尾花特征的名字:\\n\", iris.feature_names)\n",
2184 "print(\"鸢尾花目标值的名字:\\n\", iris.target_names)"
2185 ]
2186 },
2187 {
2188 "cell_type": "code",
2189 "execution_count": null,
2190 "metadata": {},
2191 "outputs": [],
2192 "source": [
2193 "# 取出特征值\n",
2194 "X = iris.data\n",
2195 "y = iris.target"
2196 ]
2197 },
2198 {
2199 "cell_type": "markdown",
2200 "metadata": {},
2201 "source": [
2202 "#### 2.数据划分"
2203 ]
2204 },
2205 {
2206 "cell_type": "code",
2207 "execution_count": null,
2208 "metadata": {},
2209 "outputs": [],
2210 "source": [
2211 "# 2.数据划分\n",
2212 "from sklearn.model_selection import train_test_split\n",
2213 "X_train,X_test,Y_train,Y_test = train_test_split(X, y, test_size=0.1, random_state=0)\n"
2214 ]
2215 },
2216 {
2217 "cell_type": "markdown",
2218 "metadata": {},
2219 "source": [
2220 "#### 3.数据标准化"
2221 ]
2222 },
2223 {
2224 "cell_type": "code",
2225 "execution_count": null,
2226 "metadata": {},
2227 "outputs": [],
2228 "source": [
2229 "from sklearn.preprocessing import StandardScaler\n",
2230 "transfer = StandardScaler()\n",
2231 "X_train = transfer.fit_transform(X_train)\n",
2232 "X_test = transfer.transform(X_test)"
2233 ]
2234 },
2235 {
2236 "cell_type": "markdown",
2237 "metadata": {},
2238 "source": [
2239 "#### 4.模型构建"
2240 ]
2241 },
2242 {
2243 "cell_type": "code",
2244 "execution_count": null,
2245 "metadata": {},
2246 "outputs": [],
2247 "source": [
2248 "from sklearn.linear_model import LogisticRegression\n",
2249 "\n",
2250 "estimator = LogisticRegression(penalty='l2',solver='newton-cg',multi_class='multinomial')\n",
2251 "estimator.fit(X_train,Y_train)"
2252 ]
2253 },
2254 {
2255 "cell_type": "markdown",
2256 "metadata": {},
2257 "source": [
2258 "#### 5.模型评估"
2259 ]
2260 },
2261 {
2262 "cell_type": "code",
2263 "execution_count": null,
2264 "metadata": {},
2265 "outputs": [],
2266 "source": [
2267 "# 5.模型评估\n",
2268 "print(\"\\n得出来的权重:\", estimator.coef_)\n",
2269 "print(\"\\nLogistic Regression模型训练集的准确率:%.1f%%\" %(estimator.score(X_train, Y_train)*100))"
2270 ]
2271 },
2272 {
2273 "cell_type": "markdown",
2274 "metadata": {},
2275 "source": [
2276 "#### 6. 模型预测"
2277 ]
2278 },
2279 {
2280 "cell_type": "code",
2281 "execution_count": null,
2282 "metadata": {},
2283 "outputs": [],
2284 "source": [
2285 "from sklearn import metrics\n",
2286 "y_predict = estimator.predict(X_test)\n",
2287 "print(\"\\n预测结果为:\\n\", y_predict)\n",
2288 "print(\"\\n比对真实值和预测值:\\n\", y_predict == Y_test)\n",
2289 "\n",
2290 "# 预测的准确率\n",
2291 "accuracy = metrics.accuracy_score(Y_test, y_predict)\n",
2292 "print(\"\\nLogistic Regression 模型测试集的正确率:%.1f%%\" %(accuracy*100))"
2293 ]
2294 },
2295 {
2296 "cell_type": "markdown",
2297 "metadata": {},
2298 "source": [
2299 "#### 7.交叉验证"
2300 ]
2301 },
2302 {
2303 "cell_type": "code",
2304 "execution_count": null,
2305 "metadata": {},
2306 "outputs": [],
2307 "source": [
2308 "from sklearn.model_selection import cross_val_score\n",
2309 "import numpy as np\n",
2310     "scores = cross_val_score(estimator, X, y, scoring=None, cv=10) # cv 为交叉验证的折数\n",
2311     "print(\"\\n交叉验证的准确率:\",np.round(scores,2)) # 输出每一折的准确率\n",
2312     "print(\"\\n交叉验证准确率: %0.2f%% (+/- %0.2f%%)\" % (scores.mean()*100, scores.std()*2*100)) # 均值与 2 倍标准差(近似 95% 置信区间)\n"
2313 ]
2314 }
2315 ],
2316 "metadata": {
2317 "kernelspec": {
2318 "display_name": "Python 3",
2319 "language": "python",
2320 "name": "python3"
2321 },
2322 "language_info": {
2323 "codemirror_mode": {
2324 "name": "ipython",
2325 "version": 3
2326 },
2327 "file_extension": ".py",
2328 "mimetype": "text/x-python",
2329 "name": "python",
2330 "nbconvert_exporter": "python",
2331 "pygments_lexer": "ipython3",
2332 "version": "3.7.4"
2333 }
2334 },
2335 "nbformat": 4,
2336 "nbformat_minor": 2
2337 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# Keras 基础介绍"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "<img src=\"https://s3.amazonaws.com/keras.io/img/keras-logo-2018-large-1200.png\" width=300>\n",
14 "\n",
15 "Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow,CNTK 或者 Theano 作为后端运行。\n",
16 "\n",
17 "Keras 具有如下优点:\n",
18     "- 由于用户友好、高度模块化、可扩展,可以简单而快速地进行原型设计。\n",
19 "- 同时支持卷积神经网络和循环神经网络,以及两者的组合。\n",
20 "- 在 CPU 和 GPU 上无缝运行。"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "Keras 的核心数据结构是 model,一种组织网络层的方式。最简单的模型是 Sequential 顺序模型,它把多个网络层线性堆叠起来。"
28 ]
29 },
30 {
31 "cell_type": "markdown",
32 "metadata": {},
33 "source": [
34 "Sequential 模型如下所示:"
35 ]
36 },
37 {
38 "cell_type": "raw",
39 "metadata": {},
40 "source": [
41 "from keras.models import Sequential\n",
42 "\n",
43 "model = Sequential()"
44 ]
45 },
46 {
47 "cell_type": "markdown",
48 "metadata": {},
49 "source": [
50 "可以简单地使用 .add() 来堆叠模型:"
51 ]
52 },
53 {
54 "cell_type": "raw",
55 "metadata": {},
56 "source": [
57 "from keras.layers import Dense\n",
58 "\n",
59 "model.add(Dense(units=64, activation='relu', input_dim=100))\n",
60 "model.add(Dense(units=10, activation='softmax'))"
61 ]
62 },
63 {
64 "cell_type": "markdown",
65 "metadata": {},
66 "source": [
67 "在完成了模型的构建后, 可以使用 .compile() 来配置学习过程:"
68 ]
69 },
70 {
71 "cell_type": "raw",
72 "metadata": {},
73 "source": [
74 "model.compile(loss='categorical_crossentropy',\n",
75 " optimizer='sgd',\n",
76 " metrics=['accuracy'])"
77 ]
78 },
79 {
80 "cell_type": "markdown",
81 "metadata": {},
82 "source": [
83 "然后,就可以批量地在训练数据上进行迭代了。"
84 ]
85 },
86 {
87 "cell_type": "raw",
88 "metadata": {},
89 "source": [
90 "# x_train 和 y_train 是 Numpy 数组 -- 就像在 Scikit-Learn API 中一样。\n",
91 "model.fit(x_train, y_train, epochs=5, batch_size=32)"
92 ]
93 },
94 {
95 "cell_type": "markdown",
96 "metadata": {},
97 "source": [
98 "只需一行代码就能评估模型性能:"
99 ]
100 },
101 {
102 "cell_type": "raw",
103 "metadata": {},
104 "source": [
105 "loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128)"
106 ]
107 },
108 {
109 "cell_type": "markdown",
110 "metadata": {},
111 "source": [
112 "或者对新的数据生成预测:"
113 ]
114 },
115 {
116 "cell_type": "raw",
117 "metadata": {},
118 "source": [
119 "classes = model.predict(x_test, batch_size=128)"
120 ]
121 },
122 {
123 "cell_type": "markdown",
124 "metadata": {},
125 "source": [
126 "# MNIST Playground\n"
127 ]
128 },
129 {
130 "cell_type": "markdown",
131 "metadata": {},
132 "source": [
133     "Welcome to the MNIST Playground. Here you will learn about data augmentation and build a deep learning model that recognizes handwritten digits by dragging components on a control panel, with no code to write. Give it a try!\n",
134     "\n",
135     "*Run each cell below in order until the control panel appears; you can then click or drag to build and train a model.*"
136 ]
137 },
138 {
139 "cell_type": "code",
140 "execution_count": 1,
141 "metadata": {},
142 "outputs": [
143 {
144 "name": "stderr",
145 "output_type": "stream",
146 "text": [
147 "Using TensorFlow backend.\n"
148 ]
149 }
150 ],
151 "source": [
152     "# Import the required packages\n",
153 "\n",
154 "!mkdir -p ~/.keras/datasets\n",
155 "!cp ./mnist.npz ~/.keras/datasets/mnist.npz\n",
156 "\n",
157 "\n",
158 "%matplotlib inline\n",
159 "import time\n",
160 "import matplotlib.pyplot as plt\n",
161 "import numpy as np\n",
162 "import imgaug\n",
163 "import imgaug.augmenters as iaa\n",
164 "from ipywidgets import interact, interactive, fixed, interact_manual\n",
165 "from ipywidgets import IntSlider, FloatSlider, Dropdown, Checkbox, Button, Output, \\\n",
166 " SelectMultiple,Image, HBox, VBox,SelectionSlider, HTML\n",
167 "from IPython import display\n",
168 "import tensorflow as tf\n",
169 "from tensorflow.keras.datasets import mnist\n",
170 "from tensorflow import keras\n",
171 "from tensorflow.python.util import deprecation\n",
172 "from keras.layers import Input, Dense,Flatten\n",
173 "from keras.models import Model\n",
174 "from keras.callbacks import Callback\n",
175 "from keras.utils.np_utils import to_categorical\n",
176 "from keras.utils import model_to_dot\n",
177 "deprecation._PRINT_DEPRECATION_WARNINGS = False\n",
178 "sometimes = lambda aug: iaa.Sometimes(0.5, aug)"
179 ]
180 },
181 {
182 "cell_type": "markdown",
183 "metadata": {},
184 "source": [
185     "## Preparing the data\n",
186     "First we define an MNIST class that will manage the MNIST dataset."
187 ]
188 },
189 {
190 "cell_type": "code",
191 "execution_count": 2,
192 "metadata": {},
193 "outputs": [],
194 "source": [
195 "class MNIST:\n",
196     " \"\"\"A class for managing the MNIST dataset.\n",
197     " \n",
198     " Attributes:\n",
199     " x_train: feature data of the MNIST training set\n",
200     " y_train: label data of the MNIST training set\n",
201     " x_test: feature data of the MNIST test set\n",
202     " y_test: label data of the MNIST test set\n",
203     " train_size: number of samples in the training set\n",
204     " test_size: number of samples in the test set\n",
205     " aug_possibility: probability that augmentation is applied to the training data\n",
206     " seq: the data augmentation pipeline\n",
207 " \"\"\"\n",
208 " def __init__(self):\n",
209     " \"\"\"Initialize the MNIST class.\n",
210 " \"\"\"\n",
211 " (self.x_train, self.y_train), (\n",
212 " self.x_test, self.y_test) = mnist.load_data()\n",
213 " self.x_train = self.x_train.reshape(self.x_train.shape[0], 28, 28, 1)\n",
214 " self.x_test = self.x_test.reshape(self.x_test.shape[0], 28, 28, 1)\n",
215 " self.x_train = self.x_train.astype('float32')\n",
216 " self.x_test = self.x_test.astype('float32')\n",
217 " self.train_size = self.x_train.shape[0]\n",
218 " self.test_size = self.x_test.shape[0]\n",
219 " self.y_train = to_categorical(self.y_train, 10)\n",
220 " self.y_test = to_categorical(self.y_test, 10)\n",
221 " self.noise = 0\n",
222 " self.rotation = 0\n",
223 " self.aug_possibility = 0 \n",
224 " self.seq = None\n",
225 " \n",
226 " def update_noise(self, noise):\n",
227     " \"\"\"Set the amount of noise used in data augmentation.\n",
228     " :param noise: noise ratio\n",
229 " :return:\n",
230 " \"\"\"\n",
231 " self.noise = noise\n",
232 " self.update_aug_seq()\n",
233 " \n",
234 " def update_rotation(self, rotation):\n",
235     " \"\"\"Set the maximum rotation angle used in data augmentation.\n",
236     " :param rotation: maximum rotation angle in degrees\n",
237 " :return:\n",
238 " \"\"\"\n",
239 " self.rotation = rotation\n",
240 " self.update_aug_seq()\n",
241 " \n",
242 " def update_aug_possibility(self, possibility):\n",
243     " \"\"\"Set the probability that augmentation is applied.\n",
244     " :param possibility: probability of applying each augmenter\n",
245 " :return:\n",
246 " \"\"\"\n",
247 " self.aug_possibility = possibility\n",
248 " self.update_aug_seq()\n",
249 " \n",
250 " def update_aug_seq(self):\n",
251     " \"\"\"Rebuild the data augmentation pipeline.\n",
252     " :return: None\n",
253 " \"\"\"\n",
254 " if self.aug_possibility:\n",
255 " aug_seq = []\n",
256 " sometimes = lambda aug: iaa.Sometimes(self.aug_possibility, aug)\n",
257 " if self.noise:\n",
258 " aug_seq.append(sometimes(iaa.Salt(self.noise))) \n",
259 " if self.rotation:\n",
260 " aug_seq.append(sometimes(iaa.Affine(\n",
261 " rotate=(-self.rotation, self.rotation), \n",
262 " )))\n",
263 " self.seq = iaa.Sequential(aug_seq)\n",
264 " else:\n",
265 " self.seq = None\n",
266 "\n",
267 " def get_batch(self, batch_size, get_sample=False):\n",
268     " \"\"\"Yield one batch of training data (raw, or augmented if enabled).\n",
269     " :param batch_size: number of images per batch\n",
270     " :param get_sample: whether to return a fixed sample for display\n",
271 " :return:\n",
272 " \"\"\"\n",
273 " while True:\n",
274 " if get_sample:\n",
275 " randidx = [i for i in range(batch_size)]\n",
276 " else:\n",
277 " randidx = np.random.randint(self.train_size, size=batch_size)\n",
278 " epoch_x = self.x_train[randidx]\n",
279 " epoch_y = self.y_train[randidx]\n",
280 " if self.aug_possibility and (self.noise or self.rotation):\n",
281 " epoch_x = self.seq(images=epoch_x)\n",
282 " epoch_x /= 255.0\n",
283 " yield epoch_x, epoch_y\n",
284 "\n",
285 " def get_batch_test(self, batch_size):\n",
286     " \"\"\"Yield one batch of test data.\n",
287     " :param batch_size: number of images per batch\n",
288 " :return:\n",
289 " \"\"\"\n",
290 " while True:\n",
291 " randidx = np.random.randint(self.test_size, size=batch_size)\n",
292 " epoch_x = self.x_test[randidx]\n",
293 " epoch_y = self.y_test[randidx]\n",
294 " epoch_x /= 255.0\n",
295 " yield epoch_x, epoch_y\n",
296 "\n",
297 " def plot_images(self, show=True):\n",
298     " \"\"\"Plot a few sample images.\n",
299     " :param show: whether to display the plot\n",
300 " :return:\n",
301 " \"\"\"\n",
302 " sample_num = 9\n",
303 " imgs, _ = next(self.get_batch(sample_num,get_sample=True))\n",
304 " img_figure = plt.figure(1)\n",
305 " img_figure.set_figwidth(5)\n",
306 " img_figure.set_figheight(5)\n",
307 " for index in range(0, sample_num):\n",
308 " ax = plt.subplot(3, 3, index + 1)\n",
309 " ax.imshow(imgs[index].reshape(28, 28), cmap='gray')\n",
310 " plt.margins(0, 0)\n",
311 " img_figure.savefig('data_sample.jpg')\n",
312 " if not show:\n",
313 " plt.close(fig=img_figure)"
314 ]
315 },
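{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sketch of how this class is used: get_batch returns a generator, so one batch can be drawn with next()."
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"mnist_data = MNIST()\n",
"x_batch, y_batch = next(mnist_data.get_batch(32))\n",
"x_batch.shape  # (32, 28, 28, 1)\n",
"y_batch.shape  # (32, 10)"
]
},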
316 {
317 "cell_type": "markdown",
318 "metadata": {},
319 "source": [
320     "## Data augmentation"
321 ]
322 },
323 {
324 "cell_type": "code",
325 "execution_count": 3,
326 "metadata": {},
327 "outputs": [
328 {
329 "data": {
330 "image/png": "iVBORw0KGgoAAAANSUhEUgAAAToAAAEyCAYAAABqERwxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJzt3Xm0FMX1B/DvlQCiiAKaJ24sCkFUFNwQOUAUN1xwiagBBUPE4woeNeISl7iBJp7gLiqCygkhQQGNBAmCuCARDf4iiwIq8pRF3FhUCFq/P6bnvvseM8zW3dPd8/2c8867r2bpGq9TdHVVV4lzDkRESbZduStARBQ0NnRElHhs6Igo8djQEVHisaEjosRjQ0dEiceGjogSr6SGTkROFJEPRGSpiAzzq1JUOOYiGpiHaJJiJwyLSD0AHwI4DkA1gLcBnOecW+hf9SgfzEU0MA/R9bMSXnsEgKXOuY8AQETGA+gDIGtSRYS3YZRmrXNutwzlBeWCeSiZL3nwnsNclCZbLmoppeu6J4AV5u9qr6wWERksIvNEZF4Jx6KU5VnKc+aCefBV0XkAmAufZctFLaWc0eXFOTcKwCiA/3qVE/MQHcxF+Eo5o/sMwN7m7728MgofcxENzENEldLQvQ2grYi0FpEGAM4FMMWfalGBmItoYB4iquiuq3Nui4hcDmAagHoARjvnFvhWM8obcxENzEN0FT29pKiD8XpEqd5xzh1W6pswDyXzJQ8Ac+GDvHLBOyOIKPHY0BFR4rGhI6LEY0NHRIkX+IRholIceuihGl9++eUaX3DBBQCAp59+WsseeOABjd99990QakdxwTM6Iko8NnRElHgVOY+uXr16Gu+88845n2+7TDvssAMA4Be/+IWWXXbZZRr/8Y9/1Pi8887T+IcfftB4+PDhGt922235VhuokHl0hxxyiMavvPKKxk2aNNnm67799luNmzdv7n/FanAeXQGOPfZYjceNG6dxjx49NP7ggw+KfXvOoyMiAtjQEVEFSMyo6z777KNxgwYNNO7atSsAoFu3blq2yy67aHzWWWcVdbzq6mqN77//fo3POOMMjdevX6/xe++9p/Grr75a1DGT7IgjjtB44sSJGttLC/YyS/q/7ebNm7XMdle7dOmisR2Btc+Pm+7du2tsP+vzzz9fjurk7fDDD9f47bffLksdeEZHRInHho6IEi/WXddso3P5jKQW66effgIA3HTTTVq2YcMGje2o0sqVKzX++uuvNS5hhCn20qPWANC5c2eNn332WY1btGiR832WLFkCALjnnnu0bPz48Rq/8cYbGttc3X333QXWODp69uypcdu2bTWOYtd1u+1qzqFat26tccuWLTUWkfDqE9qRiIjKhA0dESVerLuun376qcZffvmlxsV2XefOnavxN998o/Evf/lLjdOjds8880xRx6h0jz32mMZ2QnWh0t3exo0ba5kdzbbdvI4dOxZ9nChJ398LAHPmzCljTXKzlx8uuugije0lisWLF4dWH57REVHisaEjosSLddf1q6++0vjaa6/V+JRTTtH4P//5D4Dak3qt+fPna3zcccdpvHHjRo0POOAAjYcMGVJCjStXermlk08+WcuyjbrZLugLL7ygsb2P+PPPPwdQk1+g9sj2Mccck/M4cWNHMqPuiSeeyFieHi0PW87/ciIyWkTWiMj7pqyZiEwXkSXe76bBVpM8rZiLSGAeYiaffyLGADixTtkwADOcc20BzPD+puCtBXMRBcxDzOTsujrnZotIqzrFfQD09OKxAGYBuM7HehVs0qRJGtvJw+l7Ig8++GAtGzRokMa2O2S7q9aCBTVbcw4ePLj0yhZvA4Cv6pRFLhdpdkL39OnTAdReasneuzp16lSN7WisXcrHTvxNd42++OILLbP3E6cndgO1u8t2knIJqxCHmof0qHFVVZUfbxeKbDMf0v8fhK3Ya3RVzr
n0tP9VALJmQEQGAyhr65BweeWCeQgcvxMRVvJghHPObWvxQOfcKACjgPAWGVy3bt1WZXZRRsvO8fnrX/+qsT0jiItt5SKsPLRr105jO0CU/hd+7dq1WmZvkRs7dqzG9pa6f/zjHxnjQjRq1Ejjq6++WuN+/foV9X65+P2d6N27N4DanyOK7Bmnve3L+uyzz8KqTi3FDuOsFpEWAOD9XuNflahAzEU0MA8RVmxDNwXAAC8eAGCyP9WhIjAX0cA8RFjOrquI/AWpi6y7ikg1gFsADAcwQUQGAVgOoG+QlfTDrbfeqrHdQs9e7O7Vq5fGL7/8cij1KlBrAHMQsVw0bNhQYzu4k+5yATWDQvY2pnnz5mkcVrfMLtBaglDzYPcnSbMDZFFhc2+7sR9++KHGdjHaMOUz6prthsRjs5RTcD7OshEIcxEu5iFm4jPVmoioSLG+BawQdo6cHWm1c6kef/xxjWfOnKmx7WI99NBDAGrPAat0nTp10th2V60+ffoA4H4ZfinH3gt2DuSJJ9bMl+7fvz8A4Pjjj8/4uttvv11juypQmHhGR0SJx4aOiBKvYrqu1rJlyzQeOHCgxk899ZTG559/fsZ4xx13BAA8/fTTWmYnvlai++67T2O7UojtpobdZbUrfcRx8ncuzZo1K+j59hbIdI7sLIO99tpLY7tdqJ1Ubf+bfv/99xqnF6zdtGmTlv3sZzVNyzvvvFNQXYPAMzoiSjw2dESUeBXZdbXsVnF2UUDbHTv22JrpUXfddReA2tu23XnnnRqX616+sNnFTe0qJXY0esqUKaHWybLdVVsnu9BqXKS7ifZzPProoxrfcMMNOd/D7puR7rpu2bJFy7777juNFy5cqPHo0aM1trMP7KWI1atXAwCqq6u1zE4AD3NviGx4RkdEiceGjogSr+K7rtb77+vK2Ojbt+ZWxVNPPVXj9MjsxRdfrGV213S770SS2a6JHaVbs6Zm0Q677FVQ7H229n5myy7Eev311wddJd9deumlAIDly5drWdeuXQt6D7s1aHqR2kWLFmnZW2+9VXT90ovR7rbbblr20UcfFf1+QeAZHRElHhs6Iko8dl2zsPfkPfPMMxqn9yqwEyK7d++usd0hftasWcFVMKLspNEgJ1Knu6x2Hwm7orEdAfzTn/6ksV29OG5GjBhR7ipkZGclpE2cOLEMNcmOZ3RElHg8ozPsXKNf/epXGh9++OEa2zO5NDvvaPbs2QHVLh6CnDtn5+ulz97OOeccLZs8uWZR37POOiuwelBudn5qFPCMjogSjw0dESVeRXZd7Rr8l19+ucZnnnmmxrvvvvs23+PHH3/U2F50T+JKGZnYVUpsfPrpp2s8ZMiQko9z1VVXafz73/9e4/T2iePGjdMyux8FkZXzjE5E9haRmSKyUEQWiMgQr7yZiEwXkSXe76bBV7fitWMeIqE+vxPxkk/XdQuAq51zHQB0AXCZiHQAMAzADOdcWwAzvL8pWNXMQ2TwOxEj+ewCthLASi9eLyKLAOwJoA9S2yACwFgAswBcF0gtS5Dugp53Xs1mZra72qpVq4LeL72Cg12xJMRVOr4DopEHu5KGjW2X//7779c4vQrGl19+qWVdunTR2C5uaheJtAtC2tuYpk2bBgB4+OGHi/sApfmfc+5dIBq5iAp7CaNdu3Yal3J7mV8KGowQkVYAOgGYC6DKawQBYBWAqiwvI58xD9HBXMRD3oMRItIYwEQAQ51z62zr7ZxzIpJxWywRGQxgcKkVpRTmITqYi/jIq6ETkfpIJXScc+45r3i1iLRwzq0UkRYA1mR6rXNuFIBR3vsEtkeg3Rm8Q4cOGj/44IMAgPbt2xf0ful18AHg3nvv1Tg9KbVMo6uCiOehXr16GqdX3QBqJvCuW7dOy+yqL9m8+eabGtstKG+++eaS6lmqOHwnwmYvYdj9JaIgn1FXAfAkgEXOufvMQ1MADPDiAQAm130t+a4lmIeo4HciRvI5ozsawPkA/i
si6XWobwAwHMAEERkEYDmAvlleT/5pDuAY5qHsGoPfiVjJZ9T1daS6TJlsvWxBwOw2b4899pjG9j7INm3a5P1+tmtkV7lIj+oBtbd2K7N3nHOHZSgPPQ9z5szR2O4ab+8LttKjsfYSg2VHY8ePH6+xH5OOA7DBOReZ70QUHXXUURqPGTOmfBXxRKsjTUQUADZ0RJR4kb3X9cgjj9TYLqh4xBFHaLznnnvm/X52Ozc7kTW9fSEAbNy4seB6Viq7sKW9R9jupWEXxcxk5MiRGj/yyCMaL1261I8qUsjs9Jqo4RkdESUeGzoiSrzIdl3POOOMjHE2dpXfF198UeP0buR2RNXuB0Gls8tU2S0Hs20/SMkydepUAMDZZ59d5ppkxzM6Iko8sbdtBH6wBN3uUibZ5tEVhHkomS95AJgLH+SVC57REVHisaEjosRjQ0dEiceGjogSjw0dESUeGzoiSjw2dESUeGzoiCjxwr4FbC2Ajd7vJNsVwXzGlj69z1qkVsANqp5REsRn9CsPAL8TpcorF6HeGQEAIjLPr1nlURWXzxiXepYiDp8xDnUsVbk/I7uuRJR4bOiIKPHK0dCNKsMxwxaXzxiXepYiDp8xDnUsVVk/Y+jX6IiIwsauKxElHhs6Ikq8UBs6ETlRRD4QkaUiMizMYwdFRPYWkZkislBEFojIEK+8mYhMF5El3u+m5a5rGvMQHcxFSJxzofwAqAdgGYA2ABoAeA9Ah7COH+DnagGgsxfvBOBDAB0A3ANgmFc+DMCIcteVeYhOHpiLcHMR5hndEQCWOuc+cs5tBjAeQJ8Qjx8I59xK59y7XrwewCIAeyL12cZ6TxsL4PTy1HArzEN0MBchCbOh2xPACvN3tVeWGCLSCkAnAHMBVDnn0ttjrQJQVaZq1cU8RAdzERIORvhERBoDmAhgqHNunX3Mpc7VOY8nBMxDdEQpF2E2dJ8B2Nv8vZdXFnsiUh+phI5zzj3nFa8WkRbe4y0ArClX/epgHqKDuQhJSQ1dgSNGbwNoKyKtRaQBgHMBTCnl+FEgIgLgSQCLnHP3mYemABjgxQMATA64HvnmgnkIth78TkQkF7WEOWIEoDdSIzDLANxY7tEhP34AdEPqFPz/AMz3fnoDaA5gBoAlAP4FoFmAdSgoF8xDNPLAXASXi63qVMKHOQrANPP39QCuz/Eax5+Sfr7wIxcR+Bxx//ElD8xFcLmo+1NK1zWvESMRGSwi80RkXgnHopTlWcpz5oJ58FXReQCYC59ly0Utga8w7JwbBW/lAhFxQR+PMmMeooO5CF8pZ3SJHTGKIeYiGpiHiCqloUvkiFFMMRfRwDxEVNFdV+fcFhG5HMA0pEabRjvnFvhWM8obcxENzEN0hbrwJq9HlOwd58MGI8xDyXzJA8Bc+CCvXPAWMCJKPDZ0RJR4bOiIKPHY0BFR4rGhI6LEY0NHRInHho6IEi/we12T6qabbtL4tttu03i77Wr+7ejZs6fGr776aij1IgrKTjvtpHHjxo01PvnkkwEAu+22m5bdd1/NMnSbNm0KoXbbxjM6Iko8NnRElHjsuhZo4MCBAIDrrrtOy3766aeMzw3z9joiv7Rq1Upj+//5UUcdpfGBBx64zfdo0aKFxldeeaV/lSsSz+iIKPHY0BFR4rHrWqCWLVsCALbffvsy1yT+jjzySI379++vcY8ePTQ+4IADMr72mmuuAQB8/vnnWtatWzeNn332WY3nzp1bemUTqH379hoPHTpU4379+mncqFEjjVObe6WsWFGzYvz69esBAPvvv7+W9e3bV+OHH35Y48WLF5da7aLwjI6IEo8NHRElHruueejVq5fGV1xxxVaP29PxU045RePVq1cHW7EYOuecczQeOXKkxrvuuqvGtos0a9Ysje2E1HvvvXer97avs88999xzi69wQuy8884ajxgxAkDtXNjJwNksWbJE4xNOOEHj+vXrA6j9PbD5tHG58IyOiBKPDR0RJR67rlnYEb
ynnnpKY9sFSLPdqOXL89pPtyL87Gc1/3sddlhqWf/HH39cy3bYYQeNZ8+erfHtt9+u8euvv65xw4YNNZ4wYQIA4Pjjj8947HnzuDe0dcYZZ2j829/+Nu/XLVu2TOPjjjtOYzvqut9++5VYu+DlPKMTkdEiskZE3jdlzURkuogs8X43Dbaa5GnFXEQC8xAz+XRdxwA4sU7ZMAAznHNtAczw/qbgrQVzEQXMQ8zk7Lo652aLSKs6xX0A9PTisQBmAbgOCTJgwACN99hjj60et6OBTz/9dBhVAoANAL6qUxbZXNhJwE888cRWj0+fPl1jOwK4bt26jO9nn5Opy1pdXa3x2LFjC6tsYWKVBwA4++yzt/n4J598ovHbb7+tsb3X1XZXLTtROKqKvUZX5Zxb6cWrAFRle6KIDAYwuMjjUG555YJ5CBy/ExFW8mCEc85taxNe59woAKOA6G/Wa+f7/OY3v9HYrk7yzTffAADuuOOO8CqWp23lIqw82IGEG264wR4fQO3bgezipdnO4qwbb7xxm4/bVTK++OKL3JUNSBS/ExdddJHGgwen2tiXX35Zy5YuXarxmjVrCnrvqqqsbXpkFDu9ZLWItAAA73dh/2XIT8xFNDAPEVZsQzcFQPoi1gAAk/2pDhWBuYgG5iHCcnZdReQvSF1k3VVEqgHcAmA4gAkiMgjAcgB9s79DtNlFBidOnJjz+Q888AAAYObMmUFVaVtaA5iDiOXi5ptv1th2Vzdv3qzxtGnTANS+uP39999nfD+7MowddNhnn300Tt/uZS8hTJ4cWtsSyTxsi13l5dZbb/X1ve2CnFGVz6jreVkeOtbnulBuHzvnDstQzlyEi3mIGd4CRkSJV/G3gJ14Ys28z44dO2Z8zowZMzS2K25Usl122UXjSy+9VGO7T0a6uwoAp59++jbfz95GNG7cOI0PPfTQjM//+9//DgC455578qwx5cuOXu+44445n3/QQQdtVfbmm29qPGfOHH8qVgKe0RFR4rGhI6LEq8iuq+1GDR8+PONz7KoZ9nawb7/9NriKxUiDBg00zrawou0C/fznPwcAXHjhhVp22mmnaWy3z7O7wNuusI3Te0Js3Lix4LpXuvSqMR06dNCyW265RePevXtnfN1229WcF2Xa4tOO7No8//jjj8VX1ic8oyOixGNDR0SJVzFd10InBn/00Ucac++HrdnJwPa+UrtXw8cff6yx7XZmYrs99r5Xu+P72rVrNX7hhRcKrHHlSe/lAACdOnXSOP3/v/1vaydv21zYEVM7Q8EumppmF1o988wzNbYzFez/N2HiGR0RJR4bOiJKvIrputp7LDONGNWVbTSWUtLLVQG1R7FffPFFjZs1a6Zxeu8Bez/qmDFjNP7qq5p1LMePH6+x7V7ZcsrMjobbruZzzz231XNvu+02jV955RWN33jjDY1tDu1z7Ch5mr1scffdd2v86aefajxp0iSNN23alOVT+I9ndESUeGzoiCjxEt91PeSQQwBk3xbPst2qDz74ILA6Jc3cuXM1tt2XQnTv3l3jHj16aGwvM9iRcKphR1dtd/Taa6/N+PypU6cCqFlyDKh9KcLm8KWXXtLY3tNqR0/T9xvb7myfPn00tvcu/+tf/9J4xIgRGn/99ddb1XP+/PkZ618MntERUeIl/owuvS5+06aZt9l86623NB44cGAYVaIMGjVqpLE9i7Pz7zgYUaNevXoa2306rrnmGo3t7XHDhtXsvpj+72jP4tIbjAPAgw8+qLGdf7dkyRKNL7nkEo3Ti9A2adJEy7p27apxv379NLa3/dld4Kz0bmOtW7fO+HgxeEZHRInHho6IEi/xXdfmzZsDyD53zm6/t2HDhlDqRFuzi3RSbuktC4Ha3dXvvvtO44svvlhju7Vhly5dANReYeSkk07S2F5G+MMf/qDxU089pXGmzaztrXv//Oc/M8bnnVezM8Ovf/3rrd4DAK666qqM5aXIeUYnInuLyEwRWSgiC0
RkiFfeTESmi8gS73fmi2Dkp3bMQyTU53ciXvLpum4BcLVzrgOALgAuE5EOAIYBmOGcawtghvc3BauaeYgMfidiRHKtKrHVC0QmA3jQ++npnFvpbdg7yzn3ixyvDWVXcnuKnR5JzdZ1bdOmjcbLly8PtF4+eCe9+1Qc8lCIE044QWM7d8v+/2lvB7MrppTBO3YXsHLkYuXKlRrbeW/2tqrFixdrbPd+sPtzZGK3Q7S3ckVhAc0M3smyI1stBV2jE5FWADoBmAugyjmX/q+9CkBVltcMBjA402NUHOYhOpiLeMh71FVEGgOYCGCoc26dfcyl/tnN+C+Tc26Uc+6wfFpdyo15iA7mIj7yOqMTkfpIJXSccy69DMJqEWlhTtPXBFXJfKRv9QKAXr16aZzustpbVh566CGNY7aopiDieSiWvYQQB+X+TqxatUpj23Vt2LChxgcffHDG16YvDcyePVvL7Koin3zyicYR7a4WLJ9RVwHwJIBFzrn7zENTAKR3jRkAYHLd15LvWoJ5iAp+J2IknzO6owGcD+C/IpK+y/YGAMMBTBCRQQCWA+gbTBXJaA7gGOah7BqD34lYydnQOedeR6rLlMmx/laneHbn+N13332rxz/77DON7QTLmMk2whSZPBTrtdde0zjXtnoRsME5V9bvhF3txS582rlzZ43XrKnpOY8ePVrj9Eoh5dq/oRx4CxgRJR4bOiJKvMTf60rx8P7772tslwOyo7H77ruvxmWeMFx269ev1/iZZ57JGFMNntERUeKxoSOixEtM19Xe1/fmm29q3K1bt3JUh0pw1113afzEE09ofOedd2p8xRVXAAAWLlwYXsUotnhGR0SJV/DqJSUdLIKrZsRMXis15BL1PNi9ByZMmKCxvbUvvSGzXTzS7pEQMF/yAEQ/FzGQVy54RkdEiceGjogSj13XeKmIrqtlu7F2MCK93V7Hjh21LMSBCXZdo4NdVyIigA0dEVUAdl3jpeK6rhHFrmt0sOtKRASwoSOiChD2LWBrAWz0fifZrgjmM7b06X3WIrUCblD1jJIgPqNfeQD4nShVXrkI9RodAIjIvKTvfhSXzxiXepYiDp8xDnUsVbk/I7uuRJR4bOiIKPHK0dCNKsMxwxaXzxiXepYiDp8xDnUsVVk/Y+jX6IiIwsauKxElHhs6Ikq8UBs6ETlRRD4QkaUiMizMYwdFRPYWkZkislBEFojIEK+8mYhMF5El3u+m5a5rGvMQHcxFSJxzofwAqAdgGYA2ABoAeA9Ah7COH+DnagGgsxfvBOBDAB0A3ANgmFc+DMCIcteVeYhOHpiLcHMR5hndEQCWOuc+cs5tBjAeQJ8Qjx8I59xK59y7XrwewCIAeyL12cZ6TxsL4PTy1HArzEN0MBchCbOh2xPACvN3tVeWGCLSCkAnAHMBVDnnVnoPrQJQVaZq1cU8RAdzERIORvhERBoDmAhgqHNunX3Mpc7VOY8nBMxDdEQpF2E2dJ8B2Nv8vZdXFnsiUh+phI5zzj3nFa8WkRbe4y0ArClX/epgHqKDuQhJSQ1dgSNGbwNoKyKtRaQBgHMBTCnl+FEgIgLgSQCLnHP3mYemABjgxQMATA64HvnmgnkIth78TkQkF7WEOWIEoDdSIzDLANxY7tEhP34AdEPqFPz/AMz3fnoDaA5gBoAlAP4FoFmAdSgoF8xDNPLAXASXi63qVMKHOQrANPP39QCuz/Eax5+Sfr7wIxcR+Bxx//ElD8xFcLmo+1NK1zWvESMRGSwi80RkXgnHopTlWcpz5oJ58FXReQCYC59ly0Utga8w7JwbBW/lAm4EUj7MQ3QwF+Er5YwusSNGMcRcRAPzEFGlNHSJHDGKKeYiGpiHiCq66+qc2yIilwOYhtRo02jn3ALfakZ5Yy6igXmILm5gHS/cwDoauIF1dHADayIigA0dEVUANnRElHhs6Igo8QKfMBxFI0eO1PjKK6/U+P3339f4lFNO0Xj58rwmXxNRRPGMjo
gSjw0dESVexXRdW7VqpXH//v01/umnnzTef//9NW7fvr3G7Lr6p127dhrXr19f4+7du2v88MMPa2zzU4jJk2uWOjv33HM13rx5c1Hvl3Q2F127dtX4rrvu0vjoo48OtU5+4hkdESUeGzoiSryK6bp+8cUXGs+ePVvj0047rRzVqQgHHHCAxgMHDgQAnH322Vq23XY1/87uscceGtvuarG3KNq8PvrooxoPHTpU43Xrau3XUtF23nlnjWfOnKnxqlWrNN59990zlscBz+iIKPHY0BFR4lVM13Xjxo0acxQ1HHfffbfGvXv3Lls9LrjgAo2ffPJJjd94441yVCdWbHeVXVcioghjQ0dEiVcxXddddtlF44MPPriMNakc06dP1zhT13XNmpqN2m2X0o7GZpswnJ7U2qNHj5LrSdml9qKOP57REVHisaEjosSrmK7rDjvsoPE+++yT8/mHH364xosXL9aYI7b5e+SRRzSeNGnSVo//73//07jQUbwmTZoAqL20lp10bNljz5vHPaMLYSdsb7/99mWsSWlyntGJyGgRWSMi75uyZiIyXUSWeL+bBltN8rRiLiKBeYiZfLquYwCcWKdsGIAZzrm2AGZ4f1Pw1oK5iALmIWZydl2dc7NFpFWd4j4AenrxWACzAFznY7189/nnn2s8ZswYjW+99daMz7fl33zzjcYPPvig31UrxAYAX9Upi2wutmzZovGKFSt8fe8TTjgBANC0ae4Tp+rqao03bdrkx+FjlQe/HHZYza6Cb731VhlrUrhir9FVOedWevEqAFXZnigigwEMLvI4lFteuWAeAsfvRISVPBjhnHPb2oTXOTcKwCggOpv13n777RpnO6OLo23lIop5KJRdQPOiiy4CADRq1Cjn626++ebA6pRJHL8T9uz722+/1diuarLvvvuGWic/FTu9ZLWItAAA7/eaHM+n4DAX0cA8RFixDd0UAAO8eACAydt4LgWLuYgG5iHCcnZdReQvSF1k3VVEqgHcAmA4gAkiMgjAcgB9g6xkkPK53ShCWgOYg4TmIq1fv34aDxtWM3i53377aWz3OMhk/vz5Gtv5ej5JXB7sgNtrr72msd32M87yGXU9L8tDx/pcF8rtY+fcYRnKmYtwMQ9c6VTeAAAG9UlEQVQxw1vAiCjxKuYWsGz82J+AMrNbTJ5//vkAgF69euV8Xbdu3TTOlRO774Pt5r700ksaf//99zmPScnGMzoiSjw2dESUeBXfdSV/HXjggRpPmTJF43xWjCmGHSEcNWpUIMeglObNm5e7CkXjGR0RJR4bOiJKPHZdKTB2v4FC9h4oZBK3ndB60kknaTx16tS8j0f5Oe2008pdhaLxjI6IEo8NHRElXsV3XfPpJnXv3l3jMi+8GXl2D4eePXtq3L9/fwDAtGnTtOyHH34o6L0HDRqk8RVXXFFkDSmXmTNnapyUe115RkdEiceGjogST8K8vzMqq6laP/74o8b5/Lfo2LEjAGDhwoWB1Wkb3smyakZBopiHfNjVbr/88sutHj/11FM1DnjU1Zc8ANHMxVlnnaXx3/72N43tPcMdOnTQuMxbgOaVC57REVHiVfxgxKOPPqrxxRdfnPP5gwen9jQZOnRoYHWizNI7f1Gw7P4Rlp0L2bBhw7Cq4wue0RFR4rGhI6LEq/iu6+LFi8tdhViyezYcf/zxGr/yyisa+7Hg5YUXXqjxyJEjS34/ym3y5Jp9fez3o3379hrbSzeXXnppOBUrQc4zOhHZW0RmishCEVkgIkO88mYiMl1Elni/c2+ZTqVqxzxEQn1+J+Iln67rFgBXO+c6AOgC4DIR6QBgGIAZzrm2AGZ4f1OwqpmHyOB3IkYKnkcnIpMBPOj99HTOrfQ27J3lnPtFjtdGbs6Q9eGHH2qcbVfy9C1jduu9ZcuWBVuxGjpnqBx5sHs53HjjjRofd9xxGrdu3VrjFStW5P3ezZo107h3794aP/DAAxrvtNNOW73Odo/t6hr2Nq
YA1Jq7leTvxJ///GeN7WWEqqoqjQu9lc9nec2jK+ganYi0AtAJwFwAVc65ld5DqwBUZXnNYACDCzkObRvzEB3MRTzkPeoqIo0BTAQw1Dm3zj7mUqeFGf9lcs6Ncs4d5tdM8krHPEQHcxEfeZ3RiUh9pBI6zjn3nFe8WkRamNP0NUFVMiwLFizQuE2bNhmfk2shyIAJypgHu3KL3RvC+t3vfqfx+vXr835v2/3t3LmzxtkurcyaNQsA8Mgjj2hZwN3VWirlO2HZXGzevLmMNSlcPqOuAuBJAIucc/eZh6YAGODFAwBMrvta8l1LMA9Rwe9EjORzRnc0gPMB/FdE5ntlNwAYDmCCiAwCsBxA32CqSEZzAMcwD2XXGPxOxErOhs459zpSXaZMjvW3OuVlt8uzK2FESLYRpsjk4ZJLLvH1/dasqen9vfDCCxoPGTIEQNlG/DY45yriO2E1adJE4z59+mj8/PPPl6M6BeEtYESUeGzoiCjxKv5eV8suprlo0SKN999//3JUJ3IGDhyosd2zYcCAARmenZudaP3dd99p/Nprr2lsLyfY/SgoHH371lxm3LRpk8b2+xEHPKMjosRjQ0dEiceuq2HXvj/ooIPKWJNomj9/vsZ2aZ5///vfGt9xxx0aN21as3jHpEmTAADTp0/XMrsc0KpVq/ytLPli9uzZGttLOH4swRUmntERUeJV/C5gMVPRu4BFSKJ3AYsZ7gJGRASwoSOiCsCGjogSjw0dESUeGzoiSjw2dESUeGzoiCjx2NARUeKFfQvYWgAbvd9JtiuC+YwtfXqftUitgBtUPaMkiM/oVx4AfidKlVcuQr0zAgBEZF7Sdz+Ky2eMSz1LEYfPGIc6lqrcn5FdVyJKPDZ0RJR45WjoRuV+SuzF5TPGpZ6liMNnjEMdS1XWzxj6NToiorCx60pEiceGjogSL9SGTkROFJEPRGSpiAwL89hBEZG9RWSmiCwUkQUiMsQrbyYi00Vkife7aa73CgvzEB3MRUicc6H8AKgHYBmANgAaAHgPQIewjh/g52oBoLMX7wTgQwAdANwDYJhXPgzAiHLXlXmITh6Yi3BzEeYZ3REAljrnPnLObQYwHkCfEI8fCOfcSufcu168HsAiAHsi9dnGek8bC+D08tRwK8xDdDAXIQmzodsTwArzd7VXlhgi0gpAJwBzAVQ551Z6D60CUFWmatXFPEQHcxESDkb4REQaA5gIYKhzbp19zKXO1TmPJwTMQ3REKRdhNnSfAdjb/L2XVxZ7IlIfqYSOc8495xWvFpEW3uMtAKwpV/3qYB6ig7kISZgN3dsA2opIaxFpAOBcAFNCPH4gREQAPAlgkXPuPvPQFAADvHgAgMl1X1smzEN0MBdh1ckbAQnnYCK9AfwZqdGm0c65O0M7eEBEpBuA1wD8F8BPXvENSF2TmABgH6SWROrrnPuqLJWsg3mIRh4A5iKsXPAWMCJKPA5GEFHisaEjosRjQ0dEiceGjogSjw0dESUeGzoiSjw2dESUeP8PgT8jcWwmkVIAAAAASUVORK5CYII=\n",
331 "text/plain": [
332 "<Figure size 360x360 with 9 Axes>"
333 ]
334 },
335 "metadata": {
336 "needs_background": "light"
337 },
338 "output_type": "display_data"
339 }
340 ],
341 "source": [
342     "# Look at the original images\n",
343 "mnist_data = MNIST()\n",
344 "mnist_data.plot_images()"
345 ]
346 },
347 {
348 "cell_type": "markdown",
349 "metadata": {},
350 "source": [
351     "Data augmentation generates additional, equivalent data from a limited dataset. We can translate, rotate, scale, or add noise to the original images to produce new ones.\n",
352     "For example, by rotating, adding noise to, and scaling a picture of a puppy, we can obtain 6 new pictures.\n",
353     "\n",
354     "Augmentation expands our dataset and improves the model's recognition performance and generalization ability.\n",
355     "\n",
356     "<img src='https://miro.medium.com/max/3262/1*7udw5GZHwVo6u-V0lzglvQ.png' width=800 />\n",
357     "\n",
358     "Below we apply noise and rotation to the mnist dataset."
359 ]
360 },
361 {
362 "cell_type": "code",
363 "execution_count": 4,
364 "metadata": {},
365 "outputs": [
366 {
367 "data": {
368 "image/png": "iVBORw0KGgoAAAANSUhEUgAAAToAAAEyCAYAAABqERwxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAIABJREFUeJztnXeYFFXWxt8rElRUBBSHjAgiIhIVBRWBNSuuOeO3rBhR/EBFRDF8u4thDatrAGVlWdYICqwgImJGVLKABEVgyEEBAwru/f6YvmfeHqqmc3V1zfk9Dw8vt7urbs+ZKu6pc+45xloLRVGUKLNHviegKIqSa/RGpyhK5NEbnaIokUdvdIqiRB690SmKEnn0RqcoSuTRG52iKJEnoxudMeZUY8xiY8wyY8zAbE1KSR21RThQO4QTk27CsDGmEoAlAH4HoBjA5wAusdYuzN70lGRQW4QDtUN42TODzx4NYJm19hsAMMa8BKAnAF+jGmN0G0ZmbLLWHugxnpItkrFD69atAQBbtmyRseLi4tRnHKN+/foAgAMPLJ3+7NmzRbds2VL0woW5uy80b95c9JIlS9I9TFbsEHuPXhOZ4WeLODK50dUDsIr+XQzgmLJvMsb0AdAng/MopazwGU9oi1TtMGnSJADAiy++KGMDBgxI9uO70b9/fwBA7969ZWy//fYT/e9//1t0mzZt0j5PIp555hnR3bp1S/cwadsB0Gsiy/jZIo5MbnRJYa0dBmAYoP975RO1Q3hQWwRPJje61QAa0L/rx8aU4Mm6LerVq5fRhMriVnK8imNyuYr7/PPPRXfs2FE0P582xmTjVHpNhJRMoq6fA2hmjGlijKkC4GIA47MzLSVF1BbhQO0QUtJe0VlrdxljbgQwGUAlACOstQuyNjMladQW4UDtEF7STi9J62T6PCJTZlprO2R6kGzb4eOPPxbduXNnz/e437MsuYhpM3ToUNFHHXWU6K1btwIAatSoIWOnnnqq32GyYgdAr4kskJQtdGeEoiiRR290iqJEnpynlyjho02bNpg2bRoA4IADDkjrGBs3bhTNScDMzz//LDqRy7pjxw7R1apVS2tOfjzxxBOi+/btm9Yx/vvf/4reYw9dHxQaajFFUSJP5Fd0lSpVAgD89ttveZ5JeKhUqRL2339/AMDatWtlvKioqNzPJZN39uijj4rea6+9kp6T3yquffv2om+88UbRV155JQDgueeek7Fnn31W9KxZs0S3a9cu6Xls3rxZdK1atUTrKq6wUespihJ59EanKErkiYzr6lxUAKhdu7Zot5WJq3AcfPDBotkl+eabb0Tzg/Rdu3btpvfdd18Z27RpU0ZzD5otW7bIZv3LLrss6c9NmDDBc3zbtm2i/bZ4pQJvB5syZYpodkedG81FAs4//3zR7Hb65fYxbrM/f07JDt27dxc9evRo0SeeeKLoxYsX53QOuqJTFCXy6I1OUZTIExnXlSN87MK4LT9r1qyRMY40MuzGsnv7xRdfiJ45cyYAoHHjxjLGOWXbt28XPWzYMNHsEq1cubKcb5J7qlatikMPPbTc9zg39ayzzpKxs88+W/RXX30lOhl31UVE2dX88MMPPXXDhg1Fu+gwAHTt2lW0+zlz/l3VqlVFd+rUSTS7vL/++qvn/K699tqE3yHfnHDCCaL59+n111/Px3SShivGcCWZINEVnaIokUdvdIqiRJ7IuK6//PKLaHZjnavCvQJYJ0OzZs1EX3LJJbu9zlFX7oPACbYXXHCB6KZNm4qeOHEiAOBvf/ubjM2ZMyel+aXKV199hWOO2a3CdxzssnrRokWLlM55zTXXAABuueUWGeNE3n/961+iXX+J8nCRXleiHQBeeukl0VxRZfDgwaLZjTr33HNFjx07FgAwb948GePqJewK5wt23fl3MoyuKz8GatKkiehGjRqJDrKSja7oFEWJPHqjUxQl8kTGdd
25c6doV5kDKN2/yQm+PXv2FH3QQQeJZleTW/uxW1y3bl0A8e7xp59+Knr+/PmieZ8mRxJ5ru44YUhU5STpVPappgLvR/V6DFCWcePGieaor0sE58ok/LPnpHHXuhGId1cZr/F77rkn4fyCxO3vBYDp06fncSaJ4X3TV199tWh+RMGR+1yjKzpFUSKP3ugURYk8kXFdGZfUCwDLly8HEO+6cq8ALhq5YEFpH5PbbrtNNCetur2hvH/yscce2+18QLz79PTTT4vmxFaX+Lps2bLyv1QWadKkCe677z4AwBVXXCHj2XBXOWGaf+ZunPc6+kXdXn31VdEzZswQzY8cXAL4mWeeKWOcxP3GG294nicHLQ4Do5BKRXH5LGbp0qUBz6SEhD85Y8wIY8wGY8yXNFbTGDPFGLM09nd6ZWqVVGmstggFaocCI5n/Il4AULYd0kAAU621zQBMjf1byT2boLYIA2qHAiOh62qt/cAY07jMcE8AXWN6JID3ANyexXllDVeeqWbNmjK2atUq0ZycyomlkydP9jye2/d60003ydjUqVM938tln5gqVaqIdmWfuCdBOfwAYEuZsZRtsXz58jiXNVn83FJm7733Fs3lltzP7Y9//KOMsRs5adIk0fweLuXD0W+XbMz7jOfOnSv6/fffF33GGWeI5ki4F8m0bkSW7JAsLmpcp06dbBwuEPhxD8Nlt4Ik3Wd0day1bmf8OgC+FjDG9AHQJ83zKIlJyhZqh5yj10SIyTgYYa215TXhtdYOAzAMyG+z3p9++kk0Vy/hldRhhx0mmvO3GBc04BVdqvhV0MiU8myRjB14VXXaaaft9jqv1vw4/PDDRQ8ZMkS0q7zB2+XYDt9++63oH374QfSbb76Z0vkdxx9/vGheOfKWMa+io8cdd1zS5/Aj29fE6aefDiB3uY3ZglecvO2LWb16dVDTiSPdMM56Y0wRAMT+3pC9KSkporYIB2qHEJPujW48gF4x3QuA9/JHCQK1RThQO4SYhK6rMeZFlDxkrW2MKQYwBMBQAK8YY3oDWAHgwlxOMhusW7dO9CeffCKaq3RwwIJz4ELUE6IJgOnIoi244ge7q++++y4AoFu3bjLGhUQZzgt8+OGHRTuXCygNZPA2Ji5o2q9fv1SmLYwaNUp0MgEW3orHfPDBBwCSzq3Luh3Kgx+pODjnMyyw7dmNXbJkiWgOaAVJMlFXvw2J3X3Gldyx3FrbwWNcbREsaocCo3BSrRVFUdIkklvAvODoKufyDBgwQPQ555wjmt3Vd955R/SXX5Ykw3MLxELm4osv9hxnl9Vx3XXXiV6xYoXoiy66SDS7q4zLu+Nj8M/7rrvuEs3bxBK1Y2R31fUHSQcXTeeKGlxcdPjw4aK5Gke+yEfvBe4NwkVJL7/8cgDAySef7Pm5+++/X/T333+fo9mVj67oFEWJPHqjUxQl8lQY15XhhFSuPPLkk0+Kdj0OgHgX5vbbS3b1bN68OZdTDAyu8sGupBeceMv6kUceEc1RS96G5eUKM1zVgvsh+PHbb78BiO+1wf0oOLn4hhtuSHi8P/zhD+W+HgZ3leEMgWTgij3ORj169JAx3grJWxT50QFXT+Eira7CDG/R23PP0lsLVxPKF7qiUxQl8uiNTlGUyFMhXVfu7u460gPxde779u0rmos7ukTNF154Qca+++67XEwzEE455ZSk38tuKf9MXnnlFdHjx48Xza6rqwriVxEkGXeVqVSpUrmv86MHdrNz3UoyFzg3kb/HM888I3rQoEEJj8F9M5wdOXOA94IvXLhQ9IgRI0Rzgjfbdv369QDi+6zwvtwge0P4oSs6RVEij97oFEWJPBXSdWU4gZFdMC4syRG5Xr1K9m1zbwgu0snRqEKgWrVqol988UXRiVoRsmvy0EMPiebo5NatW0WXU8SyXBL1eHjttddEc+SWo35cGJULeRYK119/PYD4JO1Uy0mtXLlStIu0L1q0SMa4ZWeq9OlTUlqP+6/4FZ
3NF7qiUxQl8uiNTlGUyFPhXVd2h3h5P3LkSNHcc6Bt27YAgN69e8sYR6m4JE1YadiwIe644w4A8d//2Wef3e29HFE+4ADvxlZ33nmnaO7HcdVVV+323vfee090165dRXPiNpds8iub5EpDsfvlvhMQHwH861//Kvqtt94S7UozAcDs2bMBADfffLPn+bgd5aGHHur5nlzzwAMP5OW8iejeffeiLWPGjMnDTPzRFZ2iKJEn8iu6ypUrA4h/UMqBBu5WxNtkOMeIxx1dunQRXaNGDdG8TSbJzl6Bs3LlyrgqIg7OPXP4reL84Dw6L3gV9+CDD4rmVRyvips3by6au4rdeuutAOIrp3Cfj/POO08054il27Q6X6u4QuX111/P9xTi0BWdoiiRR290iqJEntC6ruxicC4V9yfYZ599RFevXl00V2I45phjAACHHHKIjB155JGiW7ZsKZpz4NgFZe3coI8++kjG+MF3WN3VZOAG0F7uOsP2YX3ppZeK9nqw76qOAP7buNhd5Yok3O9h586dAOJzGNldZfzaJH722Weijz766N1e5zxAv4bMSmGQcEVnjGlgjJlmjFlojFlgjLk5Nl7TGDPFGLM09ndqD3OUdGiudggFlfWaKCyScV13AehvrW0JoBOAG4wxLQEMBDDVWtsMwNTYv5XcUqx2CA16TRQQyXQBWwtgbUxvN8YsAlAPQE+UtEEEgJEA3gNwezqTqFevnmgXHeWtWbwdid0Qjox26tRJNBfKdFuVODLK7q8fXJyTO8q74oKuHSAArFmzJuHxssRPQOZ2aNOmDaZNmwYgPqrK7qpzDV3Uuix+RTg5wslFMV0VjMaNG8vYBRdcIJp7P/A8+DEE5zm6LUvcj2Djxo2i+XvxdjD3vYD47+a2/114YWmXwnLc1Z3W2llA7q6JQoQfYfDjh0y2l2WLlIIRxpjGANoCmAGgTuwmCADrANTx+ZiSZdQO4UFtURgkfaMzxlQHMAZAP2vtNn7NlvyXbn0+18cY84Ux5guv15XUyIYdQtSQu6DRa6JwSCrqaoypjBKDjrbWjo0NrzfGFFlr1xpjigBs8PqstXYYgGGx44jhueY9uwsuAbRRo0YyxpUymFQiYdzjgbc1sYs8ffp00atXrxa9ePHi3d6zZcuWpM+dRQyyZAevRGCOMvq5rIngxxCu6gZQ2gqP3dxkim1+8sknop27DQBnn302gHi3iJPC/fD7Xvw76AUXa61WrVpOrolCh23LmQphIJmoqwHwPIBF1tpH6KXxAHrFdC8A48p+Vsk6jaB2CAt6TRQQyazoOgO4AsB8Y4yrQz0IwFAArxhjegNYAaD8/xKVbFALQDe1Q96pDr0mCopkoq4focRl8mL3sgVJ8uuvv4resKF0he+Sdtm14iRcV5++7OfYZeKIqatQMW/ePBnjZF8uoMnn4SocvCTPMzOttR08xtO2A8NRbBfB5Cg3J2izm89d4zt27Oh5bOem+v0s+RiuQgwAHH/88Z7vHzx4MABg3bp1MsbR73bt2onmPhG8X5Z5+OGHAQADBgzwfJ0LlAL4wVqb9WsiShx77LGiub9KvgiXI60oipID9EanKErkydteV3YvR48e7akdXB+fWxLut99+ojkK+vXXX4t2LjK7qJw06rentqLAic8cXXYpKC66CcT3AeC9w+eee65oLvXEpZe4NJbj8ccfF/3000+L5iKXfqQS1WN3ld1zTiR2LisnHXMUlx9lNGjQIOlzVyTSLYEVBLqiUxQl8uiNTlGUyBPaMk0MJ40mQyruaEV0Vxl2AbnUlVc7vbp164r2c+WGDBki2iuB9N5775Ux3t/av3//lOeeLPyogt1Vxu3R9SvppO6qP5MmTQIQv3c5bOiKTlGUyGOCXNFEabtLnvDLo0sJtoPLHwP8c8gSwY2/E22lSga/CiNBw4EJDmIUFRVlxQ6AXhNZIClb6IpOUZTIozc6RVEiT0EEI5Ts0rp1a+m1wHmJ6ZKuu8pNwn
nLVi7d1VGjRok+66yzRHNhVkcy1VCUwkBXdIqiRB690SmKEnk06lpYZD3q+uc//1nGu3XrJtptnTvhhBMyPV1e2L59u2iv7Wep8vHHH4vu3LmzRl3Dg0ZdFUVRAL3RKYpSAQjadd0I4EcAUe/OUhu5+Y6NrLUZhwJjdliB3M0zTOTiO2bFDoBeE1kgKVsEeqMDAGPMF9l6vhFWCuU7Fso8M6EQvmMhzDFT8v0d1XVVFCXy6I1OUZTIk48b3bA8nDNoCuU7Fso8M6EQvmMhzDFT8vodA39GpyiKEjTquiqKEnn0RqcoSuQJ9EZnjDnVGLPYGLPMGDMwyHPnCmNMA2PMNGPMQmPMAmPMzbHxmsaYKcaYpbG/D0h0rKBQO4QHtUVAWGsD+QOgEoCvARwCoAqAuQBaBnX+HH6vIgDtYnpfAEsAtATwIICBsfGBAB7I91zVDuGxg9oiWFsEuaI7GsAya+031tpfAbwEoGeA588J1tq11tpZMb0dwCIA9VDy3VzBtZEAzsnPDHdD7RAe1BYBEeSNrh6AVfTv4thYZDDGNAbQFsAMAHWstWtjL60DUCdP0yqL2iE8qC0CQoMRWcIYUx3AGAD9rLXb+DVbslbXPJ4AUDuEhzDZIsgb3WoA3Byzfmys4DHGVEaJQUdba8fGhtcbY4pirxcB2JCv+ZVB7RAe1BYBkdGNLsWI0ecAmhljmhhjqgC4GMD4TM4fBkxJt+znASyy1j5CL40H0CumewEYl+N5JGsLtUNu56HXREhsEUeQESMAp6MkAvM1gDvzHR3Kxh8AXVCyBJ8HYE7sz+kAagGYCmApgHcA1MzhHFKyhdohHHZQW+TOFmX/pL0FzBhzLIB7rLWnxP59R+zG+ZdyPuN5svbt24ueOXNmWvNp0KDUA1i1alU57wSaNm0qeu+99xbtyocDwF577SV6zpw55R6vYcOGoleuXCn6yCOPFD1//vxyj1G3bl3Ra9as8TvGJutReytVW2j57ozJih1i70naFi1bthS9cOHChO9P9PuXjesuHxx11FGi586d62mLsmTS7tArYnRM2TcZY/oA6FPegb744gt+f1qTue2220T37du33Pdyd3pus1dcXCy6VatWovfff/9yjzdo0CDR1157reiJEyeK5huxF/y5u+++2+8YK3w+ntAWydhBSZq07QCkb4uXXnpJdOvWrRO+P9Hv32effSa6UqVKqU4nb0ydOlV07dq1/WwRR877ulprhyFWuUBXEvlD7RAe1BbBk8mNLuWIUatWrTBuXMnzR3Yf013FMYlWccxpp50mulq1ahmfm1dj8+bNE71o0SLR/IjA6/veddddonlFl2glGCOy0bsCIyM78O8Iu5IdOpQU5t1nn308P7djxw7R/Puc6HfHbxW3du1a0dlocJ5tateunfJnMom6RjJiVKCoLcKB2iGkpL2is9buMsbcCGAySqJNI6y1C7I2MyVp1BbhQO0QXipMA+tsNzROhjfeeEP02WefLXqPPdJeSGe9gbWSFoE2sH7uuecAACeffLKMcaQ/VXbu3AkAqFy5ctrHSAXOZqhSpUq2D68NrBVFUQC90SmKUgEIrevKCY6c+JgMnCfXv39/d+6UjpFL3n77bdHsjiSBuq7hIFDXNRE///yz6HXr1olu0qRJWsfjzIFk8vUcb731luhTTz01rXOngbquiqIoQB5XdJzdXL9+fdGHHXZYYPNJlvvvv18057t58cADD4i+/fbbsz0VXdGFg0BXdO+++y4AoFu3btk4pSf//e9/RU+YMEF0z56ldUCXLFkiunnz5uUeL1HeaHls2bIFAFCzZs1k3q4rOkVRFEBvdIqiVABCG4wIK+7nxT+3jz/+WHSXLl1Ed+3aVfT777+fjdOr6xoOQhWMSJXvvvsOQHxVnpNOOinh5zj/9J///KfoN998EwBw4IGlRURatGghulevXqL5uuHzt23bNqm5e6Cuq6IoCqA3OkVRKgA5L9PEtG7dWnJtuN
Dktm2lfTP222+/rJ5z9eqS4hFTpkyRsauuusrzvQsWlG5LPOKIIzzf8z//8z8AgCeeeELGOnXqJJqjV7xMv+eeezy1ojg+//xz0R07dszZeQ44oPy+0Y0bNxbNmQPHHnusaC5+ydsbHX//+99F8zWRr3xWXdEpihJ59EanKErkCdR1tdbil19+Ee3IxnJ2xowZog8//HDRXq6w37nXr1+f8DyNGjUC4F+w89tvvxXNW72yUeAzahxzTGmV8csvv1z0iSeeKNrvEcKAAQMAxPfX4Ij3v/71L9H8uxFm/NxVF9Xv3LlzVs/HkdH//Oc/ouvUKe0rzT1V+FpZsaK0grmrDMQ9LS688ELR33zzTZZmnD66olMUJfLojU5RlMgTioThUaNGib7iiit2e33Xrl2iOYKTgyJ+nvTo0UO068TEncG++uor0WeeeaboyZMni2Y3IQMKPmH4oosuEv3444+L5j4A7CK99957ojkhld0kr8+9+uqroi+++OL0J+xNzhOGud8IP4px/Pbbb6L9ej9w3wneJ+v2Y7Mtksl2WLp0qehTTjlFtCvguXjxYs/P8aOIYcOGifayIcOPktidLoMmDCuKogB6o1MUpQIQaNS1bdu2+OCDDwDE75vzcleZPfcsneZDDz3k+R6O7BxyyCGivfamcs8GXm5zHf5XXnlFNM/Vq5k1z4mjUcm4q8uWLQMAHHrooTK2YcMG0QcddFDCY4QVtptr2Td8+HAZ44ie+70A4stiffTRR6KrVq0q2tnHr3ApN0UPM0cddZSUYapVq5aMe7mrTDINp9u3by/6ww8/FH3cccft9l6/niru9xMAfve734letaq0Tzf/7iYikbvK+Lmr6RTlTbiiM8aMMMZsMMZ8SWM1jTFTjDFLY3+Xn2qtZIvGaotQoHYoMJJxXV8AULYu8kAAU621zQBMjf1byT2boLYIA2qHAiOh62qt/cAY07jMcE8AXWN6JID3ACQspzt79mxZFrOrcsIJJyQx1RJuvfVWz3F2V7m9WqJkZL+KxuxicZVVB0cDuWSNH+w6s1vslv3cGrEcd/UHAFvKjKVliyDgJGDXso/h/cccAeS9zwy/x8tlLS4uFj1y5MjUJpsaWbPD3Llz41zWZHn66adFX3fddQnfzz/T2bNnAygt1wTE95eYNGmSaN7ryu4q4+VmJ7MJgI/XoEED0bNmzQIAtGvXzvNz/N5kSfcZXR1r7dqYXgfAN/ZrjOkDoE+a51ESk5Qt1A45R6+JEJNxMMJaa8vLy7LWDgMwDIjPGUq0iuMVExew5KohvEq66aabRHN+nXsP/w/jKpoAQL169URzjh7DOUvff/89AOD//u//yp1/WRL9D3fOOeekdDwvyrOFnx2yDQcSBg0axOcHADz11FMyNnjwYNF+qzjmzjvvLPd1/h3YuHFj4snmiHSviVSCULyKY4/iyiuv9Hz/1VdfLXrhwoUAgNNPP13GOOjw5Zfy6BErV64sdx5AadCAg3zcmcyPHTt2eI67yiicS1hUVCQ6nQpH6aaXrDfGFAFA7O8NCd6v5A61RThQO4SYdG904wG4+si9AIzLznSUNFBbhAO1Q4hJ6LoaY15EyUPW2saYYgBDAAwF8IoxpjeAFQAu9D9CerC7yhx99NGiufoFM3ToUNHOZeRcN3ZXucige0gLAK1atRLNS3LnOk+bNk3GnnnmGdHXXnut55yyRBMA0xGwLRJx9913i2Z3lYNCbjscP9z2c2+40gsHHTjP0dmVHyGMGxfYvSUndli+fHlak/FzV8eOHSv63HPPFV2jRo1yj+fnNrveEABwxhlniL7vvvsAxD/64ZxHP5o1a+Y5nkyOYKokE3W9xOel7lmei5KY5T77+tQWwaJ2KDB0C5iiKJEn0OolHTp0sK4IIm8PSoXp06eL5hr2iQp5/vWvfxXNeUu8TOYqF+yucr6Xi+ZxlJS3zOSY0FQvYfeHq7dwFRIu5pgoqszbiEaPHi2atzExY8aMAQ
D84Q9/kLEff/wx0bSzRc6rl3zyySeivbZs5RKOXu+zzz4J3+96UPTv31/GuAUob8f73//936Tnwb8Hl112megy17pWL1EURQH0RqcoSgUg0Oolu3btwpYtZXfOJHY7uea/X/UD/pyXe8vLanajXnjhBc/jcdUMrkhyzTXXeL4/V/A2nURt6oKEk7LZXWXYBXKRPNcuEohvk8dR7urVq4vm3w3WridEgO5qznn00UdFp+KucmWS448/PuH7XdUYvpaGDBkimntT8CMKTh7m7Vvu/ZyIz3b++uuvRfsVDJ03b57o1q1bA4h3V5l0eszoik5RlMijNzpFUSJPoK7rpk2b8Pzzz+82nmgpyonBnITqB0djHRx1bd68uWi/iCkX8uTKDtnGtX/kBEs+d5jcVYbtwPtKua8DJ8Amiu5z20Le98p7HDdt2iR6woQJKc44/Jx//vmib7nllnLfu3XrVtGuj0lZXC8HoKTorcNFrOvWrStjP/30k2i3nxsA3nnnHdGnnlpamYoriDjbciYFJyhzbxD3+14W567mCl3RKYoSefRGpyhK5AnUdV29ejXuuOOOjI6RbotD7k/AS3C/0ky8X5aX76nw6aefiu7UqZPne7z2BHIR0Tlz5ohu06ZNWvPIBezecBSbk4Rr1qwp2kXeeD8qR7w5Gs+uGLuufi5aIePXR4Vdfff7x8nT3LvEuaJA/PXBv+evv/76bue+5557RLu+FUB8si/bkEunHXHEEbsdj/fI8vXD0Vq/RzEvvvii6Esu8dt1mj66olMUJfLojU5RlMgTqOvqh1+rwlTw60Hh6s9ziRk/2K3iNogcYXLlmU466SQZ467w5513nmjeL8s89thjovv16wcgfm8jEyZ31Q9O6OaoayqwzbizOz9a4N+TqMB9VBivTATucs/wHm1XMgkAbrvtNs/3T5w4EUB8tW5+FME2dO8F4iOzHCUfMWIEgPikb36cwW4sJxKfcsopojkq7H7nufwZJ4Z37556kRhd0SmKEnkKrnpJMnC1Efcw1S9Phx+CXn/99aJ/+OGHpM+3YMEC0V4PaZPBLxetDKGpXpJt+H93XkXw7ycHJvLZEwJZrF7SoUMH66p7pLK1ibuAcU4dd8nj3+GBA0u7L7pOw1/DAAAaRUlEQVSgzubNm2WMPRg+XseOHUUvXbpUNPescF4O93LgLWwcoOKVGW/1Y1zwgovilvOz0eoliqIogN7oFEWpAATquibjMu3atQtAfP5Yhw7eK1PO/WG3hhvqumoJfvlyXG+f3dhswz0mOJCRIpF1XRmucBF119XPFvyw323Vcn03gPhHMZy/xq5hnz6lrWPffvtt0W5LJQcGTjvtNNGcc3rvvfeK/sc//iGaq5e4VolcPJXhY3CR1ksvvdTz/C1atIg7bln4Wt5jjz2y47oaYxoYY6YZYxYaYxYYY26Ojdc0xkwxxiyN/R3OTZnRornaIRRU1muisEjGdd0FoL+1tiWATgBuMMa0BDAQwFRrbTMAU2P/VnJLsdohNOg1UUCk7LoaY8YBeDL2p6u1dm2sYe971trDEnxWTsY5az179kxpDongJfZVV10FwN915bw93hLDy/pU8KtznyVkmZ4tO4SFQo66BmULzpG7+uqrRbPrumPHDtHsJnLvB682g2+88YZobvt51113iU7UwpAr2qS7VZPhSke9e/cWPX/+fNFHHnlkUq5rSjkexpjGANoCmAGgjrV2beyldQDq+HymD4A+Xq8p6aF2CA9qi8Ig6airMaY6gDEA+llrt/FrtuS/Xc//may1w6y1HbL18Laio3YID2qLwiEp19UYUxnAfwBMttY+EhtbjDy7TFzEj4tzclFGF7HiZfXf//530YMHDxbNy/6QMgvAJoTMDtmAE1CffPJJ0WF1XQEciyxcE61bt7YuobZRo0ZJT4Ddy1SLVrrtkLxtkl3Xb7/9VrTLgsiEf//736I50sok6hvDrF+/XnSdOnWyFnU1AJ4HsMgZNM
Z4AL1iuheAcWU/q2SdRlA7hAW9JgqIZJ7RdQZwBYD5xhiX3DYIwFAArxhjegNYAeDC3ExRIWoB6KZ2yDvVoddEQRG6hGEvtm/fLtqvx0PXrl1F815XV0GE+xf4JTZu2LBBNEeyEsFJmsl0Nk+F1157TfT5558f2YRhrnzBUTWOlh988MGi8+265jphOBF8HXClkHbt2ol2lXGA+J+da6HJj3O4MCcX2+RKMunC7UJTcc+TRPe6KoqiAHqjUxSlApC3wpvDhw8XfdFFF4l2pV44AuqXqPi3v/1N9NixYzOe01NPPZXW5zJxV12vBHYXGC6IGGW+/PJL0VwOiBO6mzZtKjrPrmugvPXWWwDie0A49xOIL3k2atQo0VwqiSOVXvz+97/PeJ5+5MBdTRld0SmKEnn0RqcoSuQJNOratm1b68oV+bU9c+VfuE9DMnBU6eWXXxbdpUsXAEBxcbGMBbWUZvfC7/umSGSjrozbnwwAzz33nOj3339fdN++fQEACxcuDGxeRN6jrmHC3UOSqZDMbS39HtekiEZdFUVRgICDEXPmzPFc2ezcuVN05cqVkz4eVynhaiOp5P4k04GMa+u7bTOcr+S3Qkx3FceBmGrVqqV1jEKGA0vc4apHjx6iXaUZtjvnMxYK7du3x+effw7Av2tcIlavXi26Xr16SX+OC9d269bN8z3sHXG+3qJFi0QnWslxLuRf/vKXpOeXTXRFpyhK5NEbnaIokSdvW8B+/vlnGd9rr73SOl6Oi1wGSpI/jwoRjGC4hd6f/vQn0a7aCVfuCDAwkZNgBFfVueGGG9I6XpZ6k2QM9/2oVKlSWsfgyimTJk0SzddHjx49NBihKIoC6I1OUZQKQCiql0TJBWXYjTjiiCNEp1IZpQwVznUNKVlzXdu0aWNd9LNWrVoyzteEK1aZTJ4as3btWtFctDQRw4YNE80tExORTJWhVBg5cqToXr16iZ41a5bodu3aqeuqKIoC6I1OUZQKQNCu60YAP6Kk70GUqY3cfMdG1toDMz1IzA4rkLt5holcfMes2AHQayILJGWLQG90AGCM+SLq3Y8K5TsWyjwzoRC+YyHMMVPy/R3VdVUUJfLojU5RlMiTjxvdsMRvKXgK5TsWyjwzoRC+YyHMMVPy+h0Df0anKIoSNOq6KooSefRGpyhK5An0RmeMOdUYs9gYs8wYMzDIc+cKY0wDY8w0Y8xCY8wCY8zNsfGaxpgpxpilsb+zUks9G6gdwoPaIiCstYH8AVAJwNcADgFQBcBcAC2DOn8Ov1cRgHYxvS+AJQBaAngQwMDY+EAAD+R7rmqH8NhBbRGsLYJc0R0NYJm19htr7a8AXgLQM8Dz5wRr7Vpr7ayY3g5gEYB6KPlublfySADneB8hcNQO4UFtERBB3ujqAVhF/y6OjUUGY0xjAG0BzABQx1rrykesA1AnT9Mqi9ohPKgtAkKDEVnCGFMdwBgA/ay12/g1W7JW1zyeAFA7hIcw2SLIG91qANystX5srOAxxlRGiUFHW2tdC6v1xpii2OtFADbka35lUDuEB7VFQGR0o0sxYvQ5gGbGmCbGmCoALgYwPpPzhwFTUg3xeQCLrLWP0EvjAbhqgb0AjMvxPJK1hdoht/PQayIktogjyIgRgNNREoH5GsCd+Y4OZeMPgC4oWYLPAzAn9ud0ALUATAWwFMA7AGrmcA4p2ULtEA47qC1yZ4uyf9LeAmaMORbAPdbaU2L/viN24/TtUKslvDNmk/WovZWqLfzsUL16ddE1a9YsOeGm0hJi3GC5Ro0aotesWSOamxUngrs51a5dW/SqVau83h4msmKH2Hv0msgMT1uUZc8MTuAVMTqm7JuMMX0AJF94XimPFT7jCW2RjB3atGkj+vLLLwcAPPfcczLGN8KePUuzIAYPHiz6xx9/LO8UcRx22GGie/fuLbpv375JHyNPpG0HQK+JLONnizgyudElhbV2GGKVC/R/r/yhdggPaovgyeRGF9mIUQ
GSFVt88803oufMmQMAaNq0qYy9/PLLoj/66CPR3Gg4Fdw5gHCu4vi7N2nSRPQ777zj9xG9JkJKJlHXSEaMChS1RThQO4SUtFd01tpdxpgbAUxGSbRphLV2QdZmpiSN2iIcqB3CSygaWCtJE3gDa268zdHV7t27i165cqXozz77LNPpxUV3OYr7xz/+UfRJJ50EAPj2229lbOzYsaJnzpwper/99hO9bVtcgn65HHNMaRyhW7duov/yl79krYG1XhMZow2sFUVRAL3RKYpSAch5ekkhUalSJdG//fZbHmcSHk477TTRxcXFovfcs/RXZ8uWLaIHDBggmt3ehx56CACwcOHChOe86aabRH/yySeip02bJvrXX38FAFxwwQUyxi5qyS6kEq6//nrR06dPFz18+PDdzn3JJZeIXrp0qec8lMJDV3SKokSeyAcj9tlnHwDALbfcImO9evUSzQ/SOTufH3KHiMCDEcyECRNE9+/fXzSvfD799FPRr776quh//vOfAIANG0oLVlSrVk00Bx3cag2IXzl65eu5oAQAnHNOaR1HPnbHjh1FN2zYUPSZZ54pet999wUA1K9fX8ZmzJghusxKVIMR4UGDEYqiKIDe6BRFqQBE3nWtW7cuAOC9996TsUMPPVQ0f//TTz9d9OTJk3M/udTJq+uaDCtWlO6xPvLII0VXrVoVQPxWqlq1aonesWOH6B49eog+77zzRK9bt0708uXLAcQ/YvDT8+bNEz1wYGmJuMcff1x0lSpVAMS71hzo4Pm9/PLLBe26OluwS3/ZZZeJvvHGG0UXQFBOXVdFURRAb3SKolQAIp9H99NPPwEAFiwo3XJYr15poyWOztWpU8dznN2WigznGbrCnAAwaNAg0WeccYZojp5ec801AOLd0qKiItHsxk6cOFE016xr3ry56LZt2wIotS8AHHzwwZ7zfvvtt0VfeOGFov/zn/+Ifu211wAATz31lIxxhLZly5aiuYpLIeLseNttt8mY+3kCwPPPPy967ty5onfu3BnA7HKDrugURYk8eqNTFCXyRN513b59O4D4RNZOnTqJZhf1oIMOEs39DKLsurZu3Vr00UcfLZpLqDt4S9ftt98u+oYbbhD9/fffi27QoLQGpSvTzlVA+OfKP29O3ObjcbHPrl27Aojf6uWXQXD88ceLHjJkiOe83VzYPbv//vtFu0hlFKhcuTKA0mR6oDQ7AYjPPnDRbQDYvHlzALPLDbqiUxQl8uiNTlGUyBN519VFmLiYox+8fOeoonN/gfT7I4QV7tq1evXu7Q24Igj3lOAEUz+4yKUrisk/P35swFE/dh/fffdd0fxZl+DLkVGO3HKEmM/jor9AqQsHlPas4McXp5xyiujFixeLzkZx0Xzifr4ff/yxjPEjhQ4dSvNv3R5gQF1XRVGUUKM3OkVRIk/kXVfn4nC5HzdWFnafTjjhBNGccOrl3hUyX3/9tad28H5VjoxyPwWOnq5aVdq/mV1Ql5zLe475cYJrmA0ALVq0EH3vvfeKrlGjhmiX3H3uuefKWJ8+pT2hOaLIJaDYpfX6Pdh7771Fc2I5f/couq4MR+J5b3JIS5clRcIVnTFmhDFmgzHmSxqraYyZYoxZGvv7gNxOU4nRWG0RCtQOBUYyrusLAE4tMzYQwFRrbTMAU2P/VnLPJqgtwoDaocBI6Lpaaz8wxjQuM9wTQNeYHgngPQC3I4Q4t4rdDW6Fx3sv2V1lF4eTJvPsuv4AYEuZsYxscd1114n+8MMPRX/5Zclihd2b2rVri+Zu9dWrVxfNrqY7BgD88ssvcX+X5dFHHxXt17tj06ZNol10lKOC/F5OJF6yZInoJ554QvS4ceN2mwe7Z0OHDvWcK3JghyBxrusXX3whY1u3bhXdqFEj0Zw8zFWW169fn8spZp10n9HVsdaujel1AOr4vdEY0wdAH7/XlYxJyhZqh5yj10SIyTgYYa215RUPtNYOAzAMyE+RQZd7xduHOFeKV3
TMAQeUPmLhXLIwU54t/OzAPTN4u5UXTz75pGiu8vHBBx+Idr0hAODiiy8W7baM8WqaO3JxPt+IESNEjx49WjTPtWfPngDit3dxXhxXTuEgC2//mz9/vmhuzu047rjjRPNKnouLehH2a8LBP6OpU6eKPuuss0S3a9dONFeSKbQVXbrpJeuNMUUAEPt7Q4L3K7lDbREO1A4hJt0b3XgArpVWLwC7P+xQgkJtEQ7UDiEmoetqjHkRJQ9ZaxtjigEMATAUwCvGmN4AVgC40P8I4YBdI3ZZ+KE1ww/e+YF3nmkCYDoytEXDhg1x5513AogPQHC/BAfntF166aWi/SqFcKFGdhPPP/98AED37t1l7MEHHxTN1WXeeOMN0Zxfx1v0XFtCrpDCuW4cVHj44YdFz549W3Sifgi8FYpz8VasWJEVO+Sbn3/+WbQrPArEFzjl9o9NmzYVzY8rCoFkoq6X+LzU3WdcyR3LfRqBqC2CRe1QYOgWMEVRIk/kt4B5wXlx7IKx5tywQom6JkuVKlXEDeSKJNzvYcKECQDi8+JchQ/A280F4nOwOErn3ETeVsWuK28jY3eJo9/smrrtY7zNjPuCcBSVo6Rciebpp58W7QpycjSRI/V87KjA1WA4As75ilyMlbcDFhq6olMUJfLojU5RlMhTIV1X3u7CSbL777+/aK5iwW6sG+c2e4VGcXExbr31VgDADz/8IONXXXWV6Ndffx0A0L9/fxnj5N3f/e53ojlRl6OgRx11lGhX2JGjl349Ixh2TTmp1SV6czLwlClTRD/wwAOi/ZJbhw0bJnrjxo27vc5RRo4mRxGOUnMFGo48d+nSRXSrVq1E81a/sKIrOkVRIo/e6BRFiTwV0nVl1q1bJ5pdV4aLD7r3FLLrumPHDnz11VcAgAMPPFDGx4wZI5qjnQ52VxlOvOWINif+usg1R/TYjeVWivPmzRN92mmniT7kkENEO/eKk15btmwpmvtEMOxacyKxc2NfeuklGXv11Vc9jxF1OKGeK/pw1JV/F9R1VRRFCQF6o1MUJfJUeNeVS/8cfvjhojl5mEsBubI1HOHjcjeFBkcbJ0+eLLp9+/YAgGbNmsmYa1kIxH9ndkH5UQD/bKdNmwYAeOaZZ2SMe0ZUrVpVNO8tfv/990W7/bI8L3aFeb/mgAEDRPfr1080f99BgwaJ5kTiio6zFRDfR+WCCy4Qze0RR44cCQDYsqVsLdLwoCs6RVEij97oFEWJPBXedeUooV/PAY7m/f73vwcQn6jqIpiFDlfUdXtZOQr55z//WbT7OQDA4MGDRXPklqvyOheU99b6wa0pOZG1qKhItHuEwMnAXOqJ+yFwS0SOpPL7XYu/rl27ytjSpUtFcxTXqy1klOAeG1zGiytG8+OcmjVrAgC+++47GfMr45UvdEWnKErkqfAruvHjx4vmxr1cNJJxD2c5cBGVFR1XG3EP+/kBs/ufGwDefPNN0YsWLRJ9zTXXiG7YsKFo3j6WCK6qwU2rufKJC2RwFRVe8XHXt5NPPtlzfs8++6xolxfJTa15CxgX/Yz6im7btm2ily1bJppXuFylxgV+HnvsMRkL2zWhKzpFUSKP3ugURYk8Fd515aKM7IJ17NhRNDdUdu4Rb43hqhq87C80uJqI2xrGD/L5gTwX6TzxxBNFc6tCrhLjBefRMccee6zoXr16ifZqMM7BiPPOO89zrrzti4NPDLtoSinvvvuu6H/84x+ie/fuLdoF8RK1gcwnCVd0xpgGxphpxpiFxpgFxpibY+M1jTFTjDFLY3/vvjlSyTbN1Q6hoLJeE4VFMq7rLgD9rbUtAXQCcIMxpiWAgQCmWmubAZga+7eSW4rVDqFBr4kCIpkuYGsBrI3p7caYRQDqAeiJkjaIADASwHsAbs/JLHMIR4e4ECNvcWncuLHogw8+GEB8vhX3lMix6/oTkLkdGjVqhLvvvhsAsHnzZhlnd8+9ztu+uPAi57pxtREulMnFKt22Lm6lx3
D1koEDS+8P/H7eGua2fnFe3ssvvyya3WYuHsrvyYCd1tpZQDSvCYZzS9966y3RXJDTVff55ZdfgptYiqT0jM4Y0xhAWwAzANSJ3QQBYB2AOj6f6QOgT/pTVMqSqR1q1aqV+0lWEPSaKAySjroaY6oDGAOgn7U2btliS9KgPVOhrbXDrLUdfPpgKimSDTtwaXglffSaKBySWtEZYyqjxKCjrbXOl1lvjCmy1q41xhQB8O5/F3K48gZvN5ozZ45oTo50cAIpR2VzjEEW7LB161ZMnDgRQHySLSfQusgnJ+/yz+rJJ58UzVUt2M3nZGPXm4Nde7eNC4hPVubzzJo1SzRHhZ1bzInBP/74I7zghOGhQ4eK5oKR3DskGaJ8TfjBidJcBebyyy8HEP9zfuedd0Tz71C+SCbqagA8D2CRtfYRemk8ABf77wVgXPanp5ShEdQOYUGviQIimRVdZwBXAJhvjHHLnEEAhgJ4xRjTG8AKABfmZooKUQtAN7VD3qkOvSYKimSirh+hxGXyont2p5NfeN/kU089JZrr47t2h9xrYZ999hHNVU9yUMFhps9znZTs8P3338dVGXEcdNBBojt37gygNPoKxBew5IRRblVYp07p83ev78/uJ7v8/N61a9eKnj17tmiOALvkYH7EwD/7+vXri+ZCkhwtT9VdJX6w1laIa4Lh4qTcM8W5pvwYhCP43FIzX+gWMEVRIo/e6BRFiTwVfq+rH9zy7YUXXhDt2h1yUclC3t/KcHKoK6vEXdv9eixceumlnsfgsj4uwdev7BOfZ9SoUZ7j7AJxlNbBbSk5KszJyIXc3yNM8GOe4uJiAPGPCzihPAzoik5RlMijNzpFUSKPCbK2uzEmXIXkCw+/qGtK+NmBXTwXgZ05c6aMsYvO28huvfVW0bw3dfjw4aKdq8+RTo5md+9eGqy88sorRaeSbMrRb8YvkTgDsmIHoHCviUMPPVR0mzZtAMRHwAMse5WULXRFpyhK5AnXE0MlEKpXry7/C3/00UcyvmDBgt0059Hxg3zeSsVeAXfWmjRpkmj3oPqss86SMV6tJbOKa9WqlWi3TWzhwoUydtJJJ4nmnC5eWSrZgVdshVC0VFd0iqJEHr3RKYoSedR1rYDs2LEjaXfjvvvuS/gebnLtt92nRYsWAIA1a9bIGBdyZHf1sssuE7148WLRvO3I5etxAIUbcHMepKLoik5RlMijNzpFUSKPuq4VkF27dmHdunVZOx73hvDDFWK89957ZYy3d3GBTW4fydVOvKqNcA5f06ZNRY8bl1opONcDgefErRSVwkZXdIqiRB690SmKEnmC3gK2EcCPADYFdtL8UBu5+Y6NrLUHJn5b+cTssAK5m2eYyMV3zIodAL0mskBStgj0RgcAxpgvot79qFC+Y6HMMxMK4TsWwhwzJd/fUV1XRVEij97oFEWJPPm40Q3LwzmDplC+Y6HMMxMK4TsWwhwzJa/fMfBndIqiKEGjrquiKJFHb3SKokSeQG90xphTjTGLjTHLjDEDgzx3rjDGNDDGTDPGLDTGLDDG3Bwbr2mMmWKMWRr7+4B8z9WhdggPaouAsNYG8gdAJQBfAzgEQBUAcwG0DOr8OfxeRQDaxfS+AJYAaAngQQADY+MDATyQ77mqHcJjB7VFsLYIckV3NIBl1tpvrLW/AngJQM8Az58TrLVrrbWzYno7gEUA6qHku42MvW0kgHPyM8PdUDuEB7VFQAR5o6sHYBX9uzg2FhmMMY0BtAUwA0Ada+3a2EvrANTJ07TKonYID2qLgNBgRJYwxlQHMAZAP2vtNn7NlqzVNY8nANQO4SFMtgjyRrcaQAP6d/3YWMFjjKmMEoOOttaOjQ2vN8YUxV4vArAhX/Mrg9ohPKgtAiLIG93nAJoZY5oYY6oAuBjA+ADPnxOMMQbA8wAWWWsfoZfGA+gV070ApFYJMneoHcKD2iKoOcUiIMGczJjTATyGkmjTCGvtnwI7eY4wxnQB8CGA+Q
D+GxsehJJnEq8AaIiSkkgXWmu35GWSZVA7hMMOgNoiKFvoFjBFUSKPBiMURYk8eqNTFCXy6I1OUZTIozc6RVEij97oFEWJPHqjUxQl8uiNTlGUyPP/MgabGlld78sAAAAASUVORK5CYII=\n",
369 "text/plain": [
370 "<Figure size 360x360 with 9 Axes>"
371 ]
372 },
373 "metadata": {
374 "needs_background": "light"
375 },
376 "output_type": "display_data"
377 }
378 ],
379 "source": [
380 "mnist_data.update_noise(0.2)\n",
381 "mnist_data.update_rotation(15)\n",
382 "mnist_data.update_aug_possibility(0.5)\n",
383 "mnist_data.plot_images()"
384 ]
385 },
386 {
387 "cell_type": "markdown",
388 "metadata": {},
389 "source": [
390 "## 定义模型\n",
391     "下面我们先定义一个 DrawCallback 类,用来在每个训练 epoch 结束后,绘制训练曲线。"
392 ]
393 },
394 {
395 "cell_type": "code",
396 "execution_count": null,
397 "metadata": {},
398 "outputs": [],
399 "source": [
400 "class DrawCallback(Callback):\n",
401 " \"\"\" 进行模型训练的 callback 类,用来绘制训练图\n",
402 "\n",
403 " Attributes:\n",
404 " init_loss: 初始的 loss 值\n",
405 " epoch_data: 已经完成的 epoch 数的列表\n",
406 " loss_data: 每个 epoch 的训练数据的 loss 值的列表\n",
407 " val_loss_data: 每个 epoch 的测试数据的 loss 值的列表\n",
408 " accuracy_data: 每个 epoch 的训练数据的 acc 值的列表\n",
409 " val_accuracy_data: 每个 epoch 的测试数据的 acc 值的列表\n",
410 " best_val_acc: 测试数据的最佳的 acc\n",
411 "\n",
412 " \"\"\"\n",
413 " def __init__(self):\n",
414 " \"\"\"初始化 DrawCallback 类\n",
415 " \"\"\"\n",
416 " super().__init__()\n",
417 " self.init_loss = None\n",
418 " self.epoch_data = []\n",
419 " self.loss_data = []\n",
420 " self.val_loss_data = []\n",
421 " self.accuracy_data = []\n",
422 " self.val_accuracy_data = []\n",
423 " self.best_val_acc = 0\n",
424 "\n",
425 " def runtime_plot(self, epoch=None):\n",
426 " \"\"\"绘制训练曲线\n",
427 " :param epoch: 当前进行到的 epoch 数\n",
428 " :return: None\n",
429 " \"\"\"\n",
430 " # 总的 epoch 数\n",
431 " epochs = self.params.get(\"epochs\")\n",
432 " # 定义图片尺寸\n",
433 " img_figure = plt.figure(1)\n",
434 " img_figure.set_figwidth(16)\n",
435 " img_figure.set_figheight(5)\n",
436 " # 绘制 loss 曲线\n",
437 " ax1 = plt.subplot(1, 2, 1)\n",
438 " ax1.set_ylim(0, int(self.init_loss * 3))\n",
439 " ax1.set_xlim(1, epochs)\n",
440 " ax1.plot(self.epoch_data, self.loss_data, 'b', label='loss_train')\n",
441 " ax1.plot(self.epoch_data, self.val_loss_data, 'r', label='loss_val')\n",
442 " ax1.set_xlabel('Epoch {}/{}'.format(epoch, epochs))\n",
443 " ax1.set_ylabel('Loss')\n",
444 " ax1.legend()\n",
445 " # 绘制 acc 曲线\n",
446 " ax2 = plt.subplot(1, 2, 2)\n",
447 " ax2.set_ylim(0, 1)\n",
448 " ax2.set_xlim(1, epochs)\n",
449 " ax2.plot(self.epoch_data, self.accuracy_data, 'b', label='acc_train')\n",
450 " ax2.plot(self.epoch_data, self.val_accuracy_data, 'r', label='acc_val')\n",
451 " ax2.set_xlabel('Epoch {}/{}'.format(epoch, epochs))\n",
452 " ax2.set_ylabel('ACC')\n",
453 " ax2.legend()\n",
454 " # 清除历史图片\n",
455 " display.clear_output(wait=True)\n",
456 " # 展示新图片\n",
457 " plt.show()\n",
458 "\n",
459 " def on_epoch_end(self, epoch, logs=None):\n",
460 " \"\"\"在每个 epoch 结束后,存储 loss 值和 acc 值,并更新训练曲线图\n",
461 " :param epoch: 当前进行到的 epoch 数\n",
462 " :param logs: 当前 epoch 返回的 log\n",
463 " :return: None\n",
464 " \"\"\"\n",
465 " # 从 logs 中获取 loss 和 acc 的值并存到对应的 list 里\n",
466 " epoch = epoch + 1\n",
467 " logs = logs or {}\n",
468 " loss = logs.get(\"loss\")\n",
469 " val_loss = logs.get(\"val_loss\")\n",
470 " accuracy = logs.get(\"accuracy\")\n",
471 " val_accuracy = logs.get(\"val_accuracy\")\n",
472 " if val_accuracy > self.best_val_acc:\n",
473 " self.best_val_acc = val_accuracy\n",
474 " if self.init_loss is None:\n",
475 " self.init_loss = loss\n",
476 " self.epoch_data.append(epoch)\n",
477 " self.loss_data.append(loss)\n",
478 " self.val_loss_data.append(val_loss)\n",
479 " self.accuracy_data.append(accuracy)\n",
480 " self.val_accuracy_data.append(val_accuracy)\n",
481 " # 绘制训练图\n",
482 " self.runtime_plot(epoch)\n",
483 " # 如果训练完成,打印最终的结果\n",
484 " if epoch == self.params.get(\"epochs\"):\n",
485 " print('训练完成')\n",
486 " print('训练集准确率: {:.2%}'.format(accuracy))\n",
487 " print('测试集准确率: {:.2%}'.format(val_accuracy))\n",
488 " print('最高测试集准确率:{:.2%}'.format(self.best_val_acc))"
489 ]
490 },
491 {
492 "cell_type": "markdown",
493 "metadata": {},
494 "source": [
495     "然后我们定义一个 NN 类,并基于它来构建和管理模型。"
496 ]
497 },
498 {
499 "cell_type": "code",
500 "execution_count": null,
501 "metadata": {},
502 "outputs": [],
503 "source": [
504 "class NN(object):\n",
505 " def __init__(self):\n",
506 " \"\"\"初始化 NN 类,构建一个多层全连接神经网络\n",
507 " Attributes:\n",
508 " layers: 隐藏层的数目\n",
509 " neurons: 每一层隐藏层的神经元的个数\n",
510 " epochs: 模型训练的轮数\n",
511 " model: 构建后的模型\n",
512 " \"\"\"\n",
513 " self.layers = 1\n",
514 " self.neurons = [8]\n",
515 " self.epochs = 5\n",
516 " self.model = None\n",
517 "\n",
518 " def build_model(self):\n",
519 " \"\"\"构建并编译模型\n",
520 " :return: None\n",
521 " \"\"\"\n",
522 " # 定义输入层和 flatten 层\n",
523 " inputs = Input(shape=(28, 28, 1,))\n",
524 " output = Flatten()(inputs)\n",
525 " # 定义隐藏层\n",
526 " for neuron in self.neurons:\n",
527 " output = Dense(neuron, activation='relu')(output)\n",
528 " # 定义输出层\n",
529 " predictions = Dense(10, activation='softmax')(output)\n",
530 " # 构建模型\n",
531 " self.model = Model(inputs=inputs, outputs=predictions)\n",
532 " # 编译模型\n",
533 " self.model.compile(loss='categorical_crossentropy',\n",
534 " optimizer='rmsprop',\n",
535 " metrics=['accuracy'])\n",
536 "\n",
537 " def start_train(self, data, validation_data):\n",
538 " \"\"\"开始训练模型\n",
539 " :param data: 训练数据\n",
540 " :param validation_data: 测试数据\n",
541 " :return: None\n",
542 " \"\"\"\n",
543 " plot_callback = DrawCallback()\n",
544 " self.model.fit_generator(data,\n",
545 " validation_data=validation_data,\n",
546 " validation_steps=300,\n",
547 " steps_per_epoch=2000, \n",
548 " epochs=self.epochs, \n",
549 " verbose=1,\n",
550 " callbacks=[plot_callback])\n",
551 "\n",
552 " def plot_model(self):\n",
553 " \"\"\"保存模型结构图\n",
554 " :return: None\n",
555 " \"\"\"\n",
556 " keras.utils.plot_model(\n",
557 " self.model,\n",
558 " to_file='model.jpg',\n",
559 " show_shapes=True,\n",
560 " show_layer_names=True)"
561 ]
562 },
563 {
564 "cell_type": "markdown",
565 "metadata": {},
566 "source": [
567 "## 操作面板\n",
568     "下面的代码将创建一个可交互的操作面板,你可以通过此面板进行数据增强、模型定义和训练等操作,并查看训练曲线图。"
569 ]
570 },
571 {
572 "cell_type": "code",
573 "execution_count": null,
574 "metadata": {},
575 "outputs": [],
576 "source": [
577 "mlp = NN()\n",
578 "mlp.build_model()\n",
579 "mnist_data = MNIST()\n",
580 "## 定义操作界面\n",
581 "# 隐藏层数目\n",
582 "hidden_layers_widget = IntSlider(min=0, \n",
583 " max=3,\n",
584 " value=1,\n",
585 " description='隐藏层数目:',\n",
586 " layout={'width': '400px'},\n",
587 " style={'description_width': 'initial'})\n",
588 "# 每层隐藏层包含的神经元数目\n",
589     "neuron_widgets = VBox([SelectionSlider(options=[8, 16, 32, 64, 128, 256, 512],\n",
590 " value=8,\n",
591 " description='隐藏层1的神经元数目:',\n",
592 " continuous_update=False,\n",
593 " orientation='horizontal',\n",
594 " readout=True,\n",
595 " layout={'width': '400px'},\n",
596 " style={'description_width': 'initial'})])\n",
597 "\n",
598 "# 开始训练的按钮\n",
599 "start_train_widget = Button(\n",
600 " description='开始训练',\n",
601 " button_style='info',\n",
602 " tooltip='点击开始训练',\n",
603 ")\n",
604 "\n",
605 "output_widget = Output()\n",
606 "\n",
607 "data_head_widget = HTML(\n",
608 " value=\"<h1>准备数据</h1>\",\n",
609 " placeholder='',\n",
610 ")\n",
611 "\n",
612 "data_aug_possibility_widget = FloatSlider(\n",
613 " value=0,\n",
614 " min=0,\n",
615 " max=1.0,\n",
616 " step=0.1,\n",
617 " description='每个 batch 内增强图片的比例:',\n",
618 " disabled=False,\n",
619 " continuous_update=False,\n",
620 " orientation='horizontal',\n",
621 " readout=True,\n",
622 " readout_format='.1f',\n",
623 " layout={'width': '400px'},\n",
624 " style={'description_width': 'initial'}\n",
625 ")\n",
626 "\n",
627 "\n",
628 "data_noise_widget = FloatSlider(\n",
629 " value=0,\n",
630 " min=0,\n",
631 " max=0.2,\n",
632 " step=0.05,\n",
633 " description='为图片增加白噪声点:',\n",
634 " disabled=False,\n",
635 " continuous_update=False,\n",
636 " orientation='horizontal',\n",
637 " readout=True,\n",
638 " readout_format='.2f',\n",
639 " layout={'width': '400px'},\n",
640 " style={'description_width': 'initial'}\n",
641 ")\n",
642 "\n",
643 "data_rotate_widget = SelectionSlider(options=[0, 10, 20, 30],\n",
644 " value=0,\n",
645 " description='图片左右旋转的最大角度:',\n",
646 " continuous_update=False,\n",
647 " orientation='horizontal',\n",
648 " readout=True,\n",
649 " layout={'width': '400px'},\n",
650 " style={'description_width': 'initial'})\n",
651 "\n",
652 "def plot_sample_data():\n",
653 " mnist_data.plot_images(show=False)\n",
654     "    with open(\"data_sample.jpg\", \"rb\") as file:\n",
655     "        image = file.read()\n",
656 " show_data_sample_widget.value = image\n",
657 "\n",
658 "def update_data_noise(*args):\n",
659 " mnist_data.update_noise(data_noise_widget.value)\n",
660 " plot_sample_data()\n",
661 "\n",
662 "def update_data_rotation(*args):\n",
663 " mnist_data.update_rotation(data_rotate_widget.value)\n",
664 " plot_sample_data() \n",
665 " \n",
666 "def update_data_aug_possibility(*args):\n",
667 " mnist_data.update_aug_possibility(data_aug_possibility_widget.value)\n",
668 " plot_sample_data() \n",
669 "\n",
670 "data_noise_widget.observe(update_data_noise, 'value') \n",
671 "data_rotate_widget.observe(update_data_rotation, 'value') \n",
672 "data_aug_possibility_widget.observe(update_data_aug_possibility, 'value') \n",
673 "\n",
674 "show_data_sample_widget = Image(\n",
675 " value=b'',\n",
676 " format='jpg',\n",
677 " width=200,\n",
678 ")\n",
679 "\n",
680 "plot_sample_data()\n",
681 "\n",
682 "model_head_widget = HTML(\n",
683 " value=\"<h1>构建模型</h1>\",\n",
684 " placeholder='',\n",
685 ")\n",
686 "\n",
687 "show_model_widget = Image(\n",
688 " value=b'',\n",
689 " format='jpg',\n",
690 " width=250,\n",
691 ")\n",
692 "\n",
693 "def build_plot_model(*args):\n",
694 " keras.backend.clear_session()\n",
695 " mlp.build_model()\n",
696 " mlp.plot_model()\n",
697     "    with open(\"model.jpg\", \"rb\") as file:\n",
698     "        image = file.read()\n",
699 " show_model_widget.value = image\n",
700 "\n",
701 "build_plot_model()\n",
702 "\n",
703 "\n",
704 "# 更改隐藏层神经元数量时,更新模型实例的参数\n",
705 "def update_neuron_number(*args):\n",
706 " mlp.neurons = [c.value for c in neuron_widgets.children]\n",
707 " \n",
708 "for e in neuron_widgets.children:\n",
709 " e.observe(update_neuron_number, 'value')\n",
710 " e.observe(build_plot_model, 'value')\n",
711 " \n",
712 "\n",
713     "# 更改隐藏层数目时,动态增加或减少控制每一层神经元数量的滑块,同时更新模型实例的参数\n",
714 "def update_layers(*args):\n",
715 " if hidden_layers_widget.value > mlp.layers:\n",
716     "        neuron_widgets.children += (SelectionSlider(options=[8, 16, 32, 64, 128, 256, 512],\n",
717 " value=8,\n",
718 " description='隐藏层{}的神经元数目:'.format(hidden_layers_widget.value),\n",
719 " continuous_update=False,\n",
720 " orientation='horizontal',\n",
721 " readout=True,\n",
722 " layout={'width': '400px'},\n",
723 " style={'description_width': 'initial'}),)\n",
724 " else:\n",
725 " neuron_widgets.children = neuron_widgets.children[:-1]\n",
726 " mlp.layers = hidden_layers_widget.value\n",
727 " mlp.neurons = [c.value for c in neuron_widgets.children]\n",
728 " for e in neuron_widgets.children:\n",
729 " e.observe(update_neuron_number, 'value')\n",
730 " e.observe(build_plot_model, 'value')\n",
731 " build_plot_model()\n",
732 "\n",
733 "hidden_layers_widget.observe(update_layers, 'value') \n",
734 "\n",
735 "\n",
736 "data_H_widget = HBox([VBox([data_aug_possibility_widget, \n",
737 " data_noise_widget, \n",
738 " data_rotate_widget]), \n",
739 " show_data_sample_widget])\n",
740 "\n",
741 "model_H_widget = HBox([VBox([hidden_layers_widget ,\n",
742 " neuron_widgets]), \n",
743 " show_model_widget])\n",
744 "\n",
745 "train_head_widget = HTML(\n",
746 " value=\"<h1>训练模型</h1>\",\n",
747 " placeholder='',\n",
748 ")\n",
749 "\n",
750 "epochs_widget = SelectionSlider(options=[5, 10, 20, 50, 100],\n",
751 " value=5,\n",
752 " description='训练轮数:',\n",
753 " continuous_update=False,\n",
754 " orientation='horizontal',\n",
755 " readout=True,\n",
756 " layout={'width': '400px'},\n",
757 " style={'description_width': 'initial'})\n",
758 "\n",
759 "def update_epochs(*args):\n",
760 " mlp.epochs = epochs_widget.value\n",
761 "\n",
762 "epochs_widget.observe(update_epochs, 'value') \n",
763 "\n",
764 "\n",
765 " \n",
766 "def on_button_clicked(b):\n",
767 " with output_widget:\n",
768 " display.clear_output(wait=True)\n",
769 " mlp.start_train(mnist_data.get_batch(32),\n",
770 " validation_data=mnist_data.get_batch_test(32))\n",
771 " \n",
772 "start_train_widget.on_click(on_button_clicked) \n",
773 "\n",
774 "display.display(data_head_widget, \n",
775 " data_H_widget, \n",
776 " model_head_widget, \n",
777 " model_H_widget , \n",
778 " train_head_widget, \n",
779 " epochs_widget, \n",
780 " start_train_widget, \n",
781 " output_widget)\n"
782 ]
783 },
784 {
785 "cell_type": "code",
786 "execution_count": null,
787 "metadata": {},
788 "outputs": [],
789 "source": []
790 }
791 ],
792 "metadata": {
793 "kernelspec": {
794 "display_name": "Python 3",
795 "language": "python",
796 "name": "python3"
797 },
798 "language_info": {
799 "codemirror_mode": {
800 "name": "ipython",
801 "version": 3
802 },
803 "file_extension": ".py",
804 "mimetype": "text/x-python",
805 "name": "python",
806 "nbconvert_exporter": "python",
807 "pygments_lexer": "ipython3",
808 "version": "3.5.2"
809 },
810 "pycharm": {
811 "stem_cell": {
812 "cell_type": "raw",
813 "metadata": {
814 "collapsed": false
815 },
816 "source": []
817 }
818 }
819 },
820 "nbformat": 4,
821 "nbformat_minor": 2
822 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.1 基于搜索的问题求解"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "现实世界中许多问题都可以通过搜索的方法来求解,例如设计最佳出行路线或是制订合理的课程表。当给定一个待求解问题后,搜索算法会按照事先设定的逻辑来自动寻找符合求解问题的答案,因此一般可将搜索算法称为问题求解智能体。"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "## 5.1.1 搜索算法基本概念"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27     "我们把书中的公交换乘问题,转化为无向图中的路径寻找问题。\n",
28 "\n",
29 "首先,我们画出如下的无向图。该无向图中有 A,B,C,D,E,F,G 七个节点,其中 A 是起点, G 是目标点。\n",
30 "\n",
31 "点与点之间的连线称为边,边可以有权重,可以代表点与点之间的距离或者从一个点转移到另一个点需要花费的代价。\n",
32 "\n",
33 "下面我们先创建一个图。"
34 ]
35 },
36 {
37 "cell_type": "code",
38 "execution_count": null,
39 "metadata": {},
40 "outputs": [],
41 "source": [
42 "# 首先导入必要的包\n",
43 "from search import SearchGraph\n",
44 "import collections\n",
45 "import matplotlib.pyplot as plt\n",
47 "from IPython import display\n",
48 "import networkx as nx\n",
49 "import numpy as np\n",
50 "import time\n",
51 "!cp ~/work/SimHei.ttf /home/jovyan/.virtualenvs/basenv/lib/python3.5/site-packages/matplotlib/mpl-data/fonts/ttf\n",
52 "plt.rcParams['font.sans-serif']=['SimHei']\n",
53 "plt.rcParams['axes.unicode_minus']=False"
54 ]
55 },
56 {
57 "cell_type": "code",
58 "execution_count": null,
59 "metadata": {},
60 "outputs": [],
61 "source": [
62 "# 定义节点列表\n",
63 "node_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G']\n",
64 "\n",
65 "# 定义边及权重列表\n",
66 "weighted_edges_list = [('A', 'B', 8), ('A', 'C', 20),\n",
67 " ('B', 'F', 40), ('B', 'E', 30),('B', 'D', 20),\n",
68 " ('C', 'D', 10), \n",
69 " ('D', 'G', 10), ('D', 'E', 10),\n",
70 " ('E', 'F', 30), \n",
71 " ('F', 'G', 30)]\n",
72 "# 定义绘图中各个节点的坐标\n",
73 "nodes_pos = {\"A\": (1, 1), \"B\": (3, 3), \"C\": (5, 0), \"D\": (9, 2),\n",
74 " \"E\": (7, 4), \"F\": (6,6),\"G\": (11,5)}\n",
75 "\n",
76 "# 绘制无向图\n",
77 "g = SearchGraph(node_list, weighted_edges_list, 'A', 'G',nodes_pos)\n",
78 "g.show_graph()"
79 ]
80 },
81 {
82 "cell_type": "markdown",
83 "metadata": {},
84 "source": [
85 "\n",
86     "观察上图,可以看到从起点 A 到目标点 G 距离最短的路径是 A -> B -> D -> G,其距离是 38。我们可以设计一个计算机程序,按照既定的规则,从起点 A 出发,不断尝试从一个节点移动到下一个节点,直到抵达目标点 G。\n",
87 "\n",
88 "在详细描述搜索算法之前,先看看下面四个重要的概念。\n",
89 "\n",
90     "+ **状态**。状态可以认为是搜索算法在某一时刻所处的位置,相应地,搜索算法在开始和结束时所处的位置称为**初始状态**和**终止状态**。\n",
91 "\n"
92 ]
93 },
94 {
95 "cell_type": "markdown",
96 "metadata": {},
97 "source": [
98     "+ **目标测试**。用于判断当前状态是不是目标状态。例如在此问题中目标点是 G,因此目标测试只需要判断当前状态是否为 G 即可。当然,即使到达了目标状态,找到的路径也未必是代价最小的。"
99 ]
100 },
101 {
102 "cell_type": "markdown",
103 "metadata": {},
104 "source": [
105     "+ **动作**。动作指的是搜索算法从一个状态转变到另外一个状态所采取的行为。一般假设在每个状态下所能够采取的行为数量都是有限的。例如:在起点 A,只有 B 和 C 两个节点与之相连,所以只有转移到 B 或者转移到 C 这两种选择。一般情况下,从一个状态转移到另外一个状态的过程叫做**状态转移**。"
106 ]
107 },
108 {
109 "cell_type": "markdown",
110 "metadata": {},
111 "source": [
112 "下图中,我们在初始状态采取了转移到 B 这个动作。"
113 ]
114 },
115 {
116 "cell_type": "code",
117 "execution_count": null,
118 "metadata": {},
119 "outputs": [],
120 "source": [
121 "g.show_graph(this_path=\"AB\")"
122 ]
123 },
124 {
125 "cell_type": "markdown",
126 "metadata": {},
127 "source": [
128     "+ **路径**。完成一系列连续的状态转移所得到的状态序列,就构成了从起点到终点的路径,如从状态 A 到状态 B,接着到状态 D,最后到状态 G,就形成了 A -> B -> D -> G 这样的一条路径。很显然,路径的总代价等于路径上相邻节点之间代价的总和。在路径搜索问题中,任何一条路径的代价都不会是负数。\n",
129 "\n"
130 ]
131 },
132 {
133 "cell_type": "code",
134 "execution_count": null,
135 "metadata": {},
136 "outputs": [],
137 "source": [
138 "g.show_graph(this_path=\"ABDG\")"
139 ]
140 },
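上文提到,路径的总代价等于路径上各条边权重的总和。下面是一段独立的示意代码(直接沿用前面定义的带权边列表,与 SearchGraph 的内部实现无关),用来验证路径 A -> B -> D -> G 的代价恰为 38:

```python
# 示意代码:用字典保存无向图中每条边的权重(无向图,两个方向都要记录)
weighted_edges_list = [('A', 'B', 8), ('A', 'C', 20),
                       ('B', 'F', 40), ('B', 'E', 30), ('B', 'D', 20),
                       ('C', 'D', 10),
                       ('D', 'G', 10), ('D', 'E', 10),
                       ('E', 'F', 30),
                       ('F', 'G', 30)]
weights = {}
for u, v, w in weighted_edges_list:
    weights[(u, v)] = w
    weights[(v, u)] = w

def path_cost(path):
    """路径总代价 = 路径上相邻两个节点之间边的权重之和"""
    return sum(weights[(a, b)] for a, b in zip(path, path[1:]))

print(path_cost('ABDG'))  # 8 + 20 + 10 = 38
```

可以顺便算一下其他路径,例如 `path_cost('ACDG')` 为 40,确实大于 38。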
141 {
142 "cell_type": "markdown",
143 "metadata": {},
144 "source": [
145 "## 5.1.2 搜索算法"
146 ]
147 },
148 {
149 "cell_type": "markdown",
150 "metadata": {},
151 "source": [
152 "搜索算法就是不断从某一状态转移到下一状态,直至到达终止状态为止。\n",
153 "\n",
154 "\n",
155     "在搜索算法中,从当前状态出发寻找后续节点,一般会面临多种选择。比如上例中,从 A 出发,可以选择 B 或者 C;从 B 出发,可以选择 A、D、E 或者 F;从 C 出发,可以选择 A 或 D。可见状态之间的这种转移构成了如下图所示的分层树状结构,该结构称为**搜索树**。"
156 ]
157 },
158 {
159 "cell_type": "code",
160 "execution_count": null,
161 "metadata": {},
162 "outputs": [],
163 "source": [
164 "g.show_search_tree()"
165 ]
166 },
167 {
168 "cell_type": "markdown",
169 "metadata": {},
170 "source": [
171     "在搜索树中,每个节点可用一个状态来标记,表示从根节点出发,经过怎样的路径到达该节点;两个节点之间的连线表示这两个节点之间存在状态转移。由于搜索树中的每个节点都用状态来标记,因此可能存在两个拥有相同标记的节点,但其含义不同。\n",
172 "\n",
173 "*注意:路径搜索不能出现回路。*\n",
174 "\n",
175 "搜索算法从初始节点出发,不断选择后续节点,完成了搜索树的构造。一开始,搜索树中只有根节点,在每一步中搜索算法将选择与搜索树中某个节点相邻的一个后续节点加入搜索树,这个操作叫做**扩展一个节点**。\n",
176 "\n",
177 "能够扩展的节点需满足条件:\n",
178 "+ 该节点不能已经在搜索树中,即该节点不能已经被扩展过;\n",
179     "+ 该节点能够从搜索树中某个节点出发,通过执行一个动作抵达,即被扩展的节点与搜索树中的某个节点相邻。\n",
180 "\n",
181 "这些能够被扩展的节点构成的集合称为未访问节点集合。\n",
182 "\n",
183 "于是,搜索算法的每一步操作可以做如下描述: \n",
184     "每次选择未访问节点集合中的一个节点加入当前搜索树,检查这个节点的所有后续相邻节点,将满足条件的节点加入未访问节点集合中,重复执行上述操作,直至被扩展的节点对应一条从初始节点到终止节点的路径。"
185 ]
186 },
187 {
188 "cell_type": "markdown",
189 "metadata": {},
190 "source": [
191 "## 5.1.3 深度优先搜索和广度优先搜索"
192 ]
193 },
194 {
195 "cell_type": "markdown",
196 "metadata": {},
197 "source": [
198     "**深度优先搜索**总是沿着某个分支进行搜索,直至不能再深入为止,即优先扩展未访问节点集合中深度最大的节点。深度优先搜索算法在搜索过程中总是倾向于沿着一条分支前进,直到该分支上所有的节点都被访问完,再返回上一层进行另一轮深度优先搜索。"
199 ]
200 },
201 {
202 "cell_type": "code",
203 "execution_count": null,
204 "metadata": {},
205 "outputs": [],
206 "source": [
207 "g.animation_search_tree('dfs')"
208 ]
209 },
210 {
211 "cell_type": "markdown",
212 "metadata": {},
213 "source": [
214     "**广度优先搜索**总是优先扩展未访问节点集合中深度最小(最浅)的节点,在执行中倾向于先把同一层的所有可能节点访问完,再考虑进行更深的探索。"
215 ]
216 },
217 {
218 "cell_type": "code",
219 "execution_count": null,
220 "metadata": {},
221 "outputs": [],
222 "source": [
223 "g.animation_search_tree('bfs')"
224 ]
225 },
226 {
227 "cell_type": "markdown",
228 "metadata": {},
229 "source": [
230 "需要强调的是,对于一个搜索问题,只要存在答案(即从初始节点到终止节点存在满足条件的一条路径),那么排除了回路的深度优先搜索和广度优先搜索均能找到一个答案,但是这个找到的答案不一定是最优的,例如距离最短。"
231 ]
232 },
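为了更直观地体会这两种策略的差异,下面给出一个极简的示意实现(并非 SearchGraph 的实际代码):两种搜索共用同一个扩展循环,深度优先用栈(后进先出)维护未访问节点集合,广度优先用队列(先进先出),并通过 visited 集合排除回路。

```python
from collections import deque

# 上文示例图的邻接表(无向图,按字母顺序列出邻居)
graph = {'A': ['B', 'C'], 'B': ['A', 'D', 'E', 'F'], 'C': ['A', 'D'],
         'D': ['B', 'C', 'E', 'G'], 'E': ['B', 'D', 'F'],
         'F': ['B', 'E', 'G'], 'G': ['D', 'F']}

def search(start, goal, use_stack):
    """use_stack=True 为深度优先,False 为广度优先;返回找到的一条路径"""
    frontier = deque([[start]])           # 未访问节点集合,每个元素是一条路径
    visited = {start}
    while frontier:
        path = frontier.pop() if use_stack else frontier.popleft()
        node = path[-1]
        if node == goal:
            return path
        for nxt in graph[node]:
            if nxt not in visited:        # 排除回路:已加入过的节点不再加入
                visited.add(nxt)
                frontier.append(path + [nxt])
    return None

print(search('A', 'G', use_stack=True))   # 深度优先:['A', 'C', 'D', 'G']
print(search('A', 'G', use_stack=False))  # 广度优先:['A', 'B', 'D', 'G']
```

两种搜索都找到了一条从 A 到 G 的路径,但路径不同,而且都没有考虑边的权重,因此找到的不一定是代价最小的路径。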
233 {
234 "cell_type": "markdown",
235 "metadata": {},
236 "source": [
237 "## 5.1.4 启发式搜索"
238 ]
239 },
240 {
241 "cell_type": "markdown",
242 "metadata": {},
243 "source": [
244     "在搜索过程中利用问题定义以外的**辅助信息**的搜索算法称为**启发式搜索算法**,或者叫**有信息的搜索算法**。"
245 ]
246 },
247 {
248 "cell_type": "markdown",
249 "metadata": {},
250 "source": [
251     "在路径搜索问题中,可引入任意一个节点与目标节点之间的直线距离作为辅助信息,来提升搜索算法的效率。根据这一想法,可以设计一个直观的最短路径搜索算法:算法从初始节点开始,每一步都将未访问节点集合中离目标节点直线距离最近的节点加入搜索树,直至到达目标节点,这个算法称为**贪婪最佳优先搜索**。\n",
252 "\n",
253 "**辅助信息:各个节点到目标节点G的直线距离**\n",
254 "\n",
255 "|站点|A|B|C|D|E|F|G|\n",
256 "|--|--|--|--|--|--|--|--|\n",
257 "|距离|30|20|19|10|5|25|0|\n",
258 "\n",
259 "贪婪最佳优先算法搜索过程如下:"
260 ]
261 },
262 {
263 "cell_type": "code",
264 "execution_count": null,
265 "metadata": {},
266 "outputs": [],
267 "source": [
268 "# 为搜索算法提供辅助信息\n",
269 "g.help_info = {'A': 30, 'B': 20, 'C': 19, 'D':10, 'E':5, 'F':25, 'G': 0}\n",
270 "# 动态演示贪婪搜索\n",
271 "g.animation_search_tree('greedy')"
272 ]
273 },
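上面的动态演示背后的思路,可以用如下独立的示意代码复现(并非 SearchGraph 的实际实现):用优先队列按各节点到目标节点的直线距离 h 从小到大扩展未访问节点。

```python
import heapq

# 上文示例图的带权邻接表与辅助信息(各节点到目标 G 的直线距离)
graph = {'A': [('B', 8), ('C', 20)],
         'B': [('A', 8), ('D', 20), ('E', 30), ('F', 40)],
         'C': [('A', 20), ('D', 10)],
         'D': [('B', 20), ('C', 10), ('E', 10), ('G', 10)],
         'E': [('B', 30), ('D', 10), ('F', 30)],
         'F': [('B', 40), ('E', 30), ('G', 30)],
         'G': [('D', 10), ('F', 30)]}
h = {'A': 30, 'B': 20, 'C': 19, 'D': 10, 'E': 5, 'F': 25, 'G': 0}

def greedy_best_first(start, goal):
    """每一步扩展未访问节点集合中 h 值(直线距离)最小的节点"""
    frontier = [(h[start], [start])]      # 元素为 (h 值, 路径)
    visited = {start}
    while frontier:
        _, path = heapq.heappop(frontier)
        node = path[-1]
        if node == goal:
            return path
        for nxt, _w in graph[node]:       # 贪婪策略不使用边的权重
            if nxt not in visited:
                visited.add(nxt)
                heapq.heappush(frontier, (h[nxt], path + [nxt]))
    return None

print(greedy_best_first('A', 'G'))  # ['A', 'C', 'D', 'G']
```

注意这段代码完全没有用到边的权重,只按直线距离排序,因此得到的路径 A -> C -> D -> G 与演示一致,但并非最短路径。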
274 {
275 "cell_type": "markdown",
276 "metadata": {},
277 "source": [
278     "但是在“贪婪”机制下找到的路径 A -> C -> D -> G 并非最短路径。产生这样的搜索结果,其原因是:贪婪最佳优先算法在每一步都贪婪地从候选节点中选择**与目标节点直线距离最近的节点**作为后续节点。这样就会造成它**过于重视当前的最优,而忽视了全局最优**。\n"
279 ]
280 },
281 {
282 "cell_type": "markdown",
283 "metadata": {},
284 "source": [
285 "另一种启发式搜索算法—— A\\* 算法克服了这一不足。\n",
286 "\n",
287 "其算法思路是:将初始节点到目标节点的距离分成两部分,\n",
288 "- (1)初始节点到当前节点的路径代价;\n",
289     "- (2)当前节点到目标节点之间的直线距离;并将两者之和作为评价函数的取值。\n",
290 "\n",
291     "具体而言,对于未访问节点集合中某个节点 n,A\\* 算法的评价函数为 f(n) = g(n) + h(n),其中:\n",
292 "+ 函数 g(n): 表示从初始节点到当前节点 n 的实际距离,\n",
293 "+ 函数 h(n): 表示当前节点 n 到目标节点的直线距离。函数 h(n) 也称为**启发函数**。\n",
294 "\n",
295 "\n",
296     "A\\* 算法的搜索过程如下:\n"
297 ]
298 },
299 {
300 "cell_type": "code",
301 "execution_count": null,
302 "metadata": {},
303 "outputs": [],
304 "source": [
305 "# 为搜索算法提供辅助信息\n",
306 "g.help_info = {'A': 30, 'B': 20, 'C': 19, 'D':10, 'E':5, 'F':25, 'G': 0}\n",
307 "# 动态演示 A* 算法\n",
308 "g.animation_search_tree('a_star')"
309 ]
310 },
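同样,A\* 的搜索过程也可以写成如下独立的示意代码(并非 SearchGraph 的实际实现):按 f(n) = g(n) + h(n) 从小到大扩展节点,其中 g(n) 是已走过的实际距离,h(n) 是直线距离。

```python
import heapq

# 上文示例图的带权邻接表与辅助信息(各节点到目标 G 的直线距离)
graph = {'A': [('B', 8), ('C', 20)],
         'B': [('A', 8), ('D', 20), ('E', 30), ('F', 40)],
         'C': [('A', 20), ('D', 10)],
         'D': [('B', 20), ('C', 10), ('E', 10), ('G', 10)],
         'E': [('B', 30), ('D', 10), ('F', 30)],
         'F': [('B', 40), ('E', 30), ('G', 30)],
         'G': [('D', 10), ('F', 30)]}
h = {'A': 30, 'B': 20, 'C': 19, 'D': 10, 'E': 5, 'F': 25, 'G': 0}

def a_star(start, goal):
    """按评价函数 f(n) = g(n) + h(n) 从小到大扩展节点,返回 (路径, 总代价)"""
    frontier = [(h[start], 0, [start])]   # 元素为 (f 值, g 值, 路径)
    best_g = {start: 0}                   # 记录到达每个节点已知的最小实际代价
    while frontier:
        f, g, path = heapq.heappop(frontier)
        node = path[-1]
        if node == goal:
            return path, g
        for nxt, w in graph[node]:
            ng = g + w
            if ng < best_g.get(nxt, float('inf')):
                best_g[nxt] = ng
                heapq.heappush(frontier, (ng + h[nxt], ng, path + [nxt]))
    return None, float('inf')

print(a_star('A', 'G'))  # (['A', 'B', 'D', 'G'], 38)
```

与贪婪策略不同,这里每一步都把已经付出的实际代价 g(n) 考虑在内,因此在本例中找到的是最短路径 A -> B -> D -> G,总代价 38。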
311 {
312 "cell_type": "code",
313 "execution_count": null,
314 "metadata": {},
315 "outputs": [],
316 "source": [
317 "# 可以调整辅助信息的比重\n",
318 "# 当只考虑额外信息时,即 origin_info_weight 设置为 0 的时候,A* 算法退化为贪婪算法。\n",
319 "g.animation_search_tree('a_star',help_info_weight=1, origin_info_weight=0)"
320 ]
321 },
322 {
323 "cell_type": "markdown",
324 "metadata": {},
325 "source": [
326     "与不一定能够找到最短路径的贪婪最佳优先算法不同,只要启发函数不高估节点到目标节点的实际距离(本例中的直线距离满足这一条件),A\\* 算法找到的路径就一定是最短路径;另一方面,由于 A\\* 算法能够利用辅助信息,它通常能用比无信息搜索更少的步骤找到答案。\n",
327 "\n",
328 "在实际中,A\\* 算法的性能表现取决于启发函数的设计,只要定义一个合适的启发函数,A\\* 算法就能够大幅缩减搜索所需的时间。"
329 ]
330 },
331 {
332 "cell_type": "markdown",
333 "metadata": {},
334 "source": [
335 "#### 思考与练习"
336 ]
337 },
338 {
339 "cell_type": "markdown",
340 "metadata": {},
341 "source": [
342 "下图是一张线路示意图。"
343 ]
344 },
345 {
346 "cell_type": "code",
347 "execution_count": null,
348 "metadata": {},
349 "outputs": [],
350 "source": [
351 "node_list = [\"0\",\"1\",\"2\",\"3\",\"4\"]\n",
352 "weighted_edges_list = [(\"0\",\"1\",10), (\"0\",\"2\",10),\n",
353 " (\"1\", \"3\", 10), \n",
354 " (\"2\", \"3\", 5), (\"2\", \"4\", 20),\n",
355     "                       (\"3\", \"4\", 14)]\n",
356 "nodes_pos = {\"0\":(1,7),\"1\":(5,1),\"2\":(5,13),\"3\":(9,7),\"4\":(11,13)}\n",
357 "h_graph = SearchGraph(node_list, weighted_edges_list, \"0\", \"4\",nodes_pos)\n",
358 "h_graph.show_graph()\n"
359 ]
360 },
361 {
362 "cell_type": "markdown",
363 "metadata": {},
364 "source": [
365     "如果使用深度优先搜索求状态 0 到状态 4 的一条路径,我们可以用下表来模拟搜索过程。注意:在下表中,节点的深度定义为它对应路径中状态转移的次数,如果多个未访问节点的深度相同,那么在这个例子里算法优先选择状态编号大的节点。\n",
366 "\n",
367 "|步骤|当前状态|当前未访问节点集合(用上划线标出了下一个扩展的节点)|\n",
368 "|:--:|:--:|:--|\n",
369 "|1|0|深度1:${0 -> 1,\\overline{0 -> 2}}$|\n",
370 "|2|2|深度1:${0 -> 1}$ 深度2:$\\underline{(1)}$|\n",
371 "|3|$\\underline{(2)}$|找到路径:0 -> 2 -> 4|\n",
372 "\n",
373     "请仔细观察上表中各项内容的含义,根据深度优先搜索的思路,在横线(1)和(2)处填写内容。问:找到的路径 0 -> 2 -> 4 是代价最小的吗?"
374 ]
375 },
376 {
377 "cell_type": "markdown",
378 "metadata": {},
379 "source": [
380 "请在下方作答:"
381 ]
382 },
383 {
384 "cell_type": "markdown",
385 "metadata": {},
386 "source": [
387 "答案 1:\n",
388 "\n",
389 "答案 2:\n"
390 ]
391 }
392 ],
393 "metadata": {
394 "kernelspec": {
395 "display_name": "Python 3",
396 "language": "python",
397 "name": "python3"
398 },
399 "language_info": {
400 "codemirror_mode": {
401 "name": "ipython",
402 "version": 3
403 },
404 "file_extension": ".py",
405 "mimetype": "text/x-python",
406 "name": "python",
407 "nbconvert_exporter": "python",
408 "pygments_lexer": "ipython3",
409 "version": "3.5.2"
410 }
411 },
412 "nbformat": 4,
413 "nbformat_minor": 2
414 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.1 基于搜索的问题求解"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "现实世界中许多问题都可以通过搜索的方法来求解,例如设计最佳出行路线或是制订合理的课程表。当给定一个待求解问题后,搜索算法会按照事先设定的逻辑来自动寻找符合求解问题的答案,因此一般可将搜索算法称为问题求解智能体。"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "## 5.1.1 搜索算法基本概念"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27     "我们把书中公交换乘的问题,转化为无向图中的路径寻找问题。\n",
28 "\n",
29 "首先,我们画出如下的无向图。该无向图中有 A,B,C,D,E,F,G 七个节点,其中 A 是起点, G 是目标点。\n",
30 "\n",
31 "点与点之间的连线称为边,边可以有权重,可以代表点与点之间的距离或者从一个点转移到另一个点需要花费的代价。\n",
32 "\n",
33 "下面我们先创建一个图。"
34 ]
35 },
36 {
37 "cell_type": "code",
38 "execution_count": null,
39 "metadata": {},
40 "outputs": [],
41 "source": [
42 "# 首先导入必要的包\n",
43 "from search import SearchGraph\n",
44 "import collections\n",
45 "import matplotlib.pyplot as plt\n",
47 "from IPython import display\n",
48 "import networkx as nx\n",
49 "import numpy as np\n",
50 "import time\n",
51 "!cp ~/work/SimHei.ttf /home/jovyan/.virtualenvs/basenv/lib/python3.5/site-packages/matplotlib/mpl-data/fonts/ttf\n",
52 "plt.rcParams['font.sans-serif']=['SimHei']\n",
53 "plt.rcParams['axes.unicode_minus']=False"
54 ]
55 },
56 {
57 "cell_type": "code",
58 "execution_count": null,
59 "metadata": {},
60 "outputs": [],
61 "source": [
62 "# 定义节点列表\n",
63 "node_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G']\n",
64 "\n",
65 "# 定义边及权重列表\n",
66 "weighted_edges_list = [('A', 'B', 8), ('A', 'C', 20),\n",
67 " ('B', 'F', 40), ('B', 'E', 30),('B', 'D', 20),\n",
68 " ('C', 'D', 10), \n",
69 " ('D', 'G', 10), ('D', 'E', 10),\n",
70 " ('E', 'F', 30), \n",
71 " ('F', 'G', 30)]\n",
72 "# 定义绘图中各个节点的坐标\n",
73 "nodes_pos = {\"A\": (1, 1), \"B\": (3, 3), \"C\": (5, 0), \"D\": (9, 2),\n",
74 " \"E\": (7, 4), \"F\": (6,6),\"G\": (11,5)}\n",
75 "\n",
76 "# 绘制无向图\n",
77 "g = SearchGraph(node_list, weighted_edges_list, 'A', 'G',nodes_pos)\n",
78 "g.show_graph()"
79 ]
80 },
81 {
82 "cell_type": "markdown",
83 "metadata": {},
84 "source": [
85 "\n",
86 "观察上图,可以看到从起点 A 到目标点 G 距离最短的路径是 A -> B -> D -> G,其距离是 38,我们可以设计一个计算机程序,按照既定的规则,从起点 A 出发,不断尝试从一个节点移动到下一个节点,直到抵达目标点 G。\n",
87 "\n",
88 "在详细描述搜索算法之前,先看看下面四个重要的概念。\n",
89 "\n",
90     "+ **状态**。状态可以认为是搜索算法在某一时刻所处的位置,相应地,搜索算法在开始和结束时所处的位置称为**初始状态**和**终止状态**。\n",
91 "\n"
92 ]
93 },
94 {
95 "cell_type": "markdown",
96 "metadata": {},
97 "source": [
98     "+ **目标测试**。用于判断当前状态是不是目标状态。例如在此问题中目标点是 G,因此目标测试只需要判断当前状态是否为 G 即可。当然,即使到达了目标状态,找到的路径也未必是代价最小的。"
99 ]
100 },
101 {
102 "cell_type": "markdown",
103 "metadata": {},
104 "source": [
105     "+ **动作**。动作指的是搜索算法从一个状态转变到另外一个状态所采取的行为。一般假设在每个状态下所能够采取的行为数量都是有限的。例如:在起点 A,只有 B 和 C 两个节点与之相连,所以只有转移到 B 或者转移到 C 这两种选择。一般地,把从一个状态到另外一个状态的过程叫做**状态转移**。"
106 ]
107 },
108 {
109 "cell_type": "markdown",
110 "metadata": {},
111 "source": [
112 "下图中,我们在初始状态采取了转移到 B 这个动作。"
113 ]
114 },
115 {
116 "cell_type": "code",
117 "execution_count": null,
118 "metadata": {},
119 "outputs": [],
120 "source": [
121 "g.show_graph(this_path=\"AB\")"
122 ]
123 },
124 {
125 "cell_type": "markdown",
126 "metadata": {},
127 "source": [
128 "+ **路径**。完成一系列连续的状态转移所得到的状态序列,就构成了从起点到终点的路径,如从状态 A 到状态 B,接着到状态 D,最后到状态 G,就形成了A -> B -> D -> G 这样的一条路径。很显然,路径的总代价,等于路径上各个节点之间代价的总和。在路径搜索问题中,任何一条路径的代价都不会是负数。\n",
129 "\n"
130 ]
131 },
132 {
133 "cell_type": "code",
134 "execution_count": null,
135 "metadata": {},
136 "outputs": [],
137 "source": [
138 "g.show_graph(this_path=\"ABDG\")"
139 ]
140 },
141 {
142 "cell_type": "markdown",
143 "metadata": {},
144 "source": [
145 "## 5.1.2 搜索算法"
146 ]
147 },
148 {
149 "cell_type": "markdown",
150 "metadata": {},
151 "source": [
152 "搜索算法就是不断从某一状态转移到下一状态,直至到达终止状态为止。\n",
153 "\n",
154 "\n",
155     "在搜索算法中,从当前状态出发寻找后续节点,一般会面临多种选择。比如上例中,从 A 出发,可以选择 B 或者 C;从 B 出发,可以选择 A、D、E 或者 F;从 C 出发,可以选择 A 或 D。可见状态之间的这种转移构成了如下图所示的分层树状结构,该结构称为**搜索树**。"
156 ]
157 },
158 {
159 "cell_type": "code",
160 "execution_count": null,
161 "metadata": {},
162 "outputs": [],
163 "source": [
164 "g.show_search_tree()"
165 ]
166 },
167 {
168 "cell_type": "markdown",
169 "metadata": {},
170 "source": [
171     "在搜索树中,每个结点可用一个状态来标记,表示从根节点出发,经过怎样的路径到达该节点,两个节点之间的连线表示这两个节点之间存在状态转移。由于搜索树中每个结点是用状态来标记的,因此可能存在两个拥有相同标记的结点,但它们对应的路径不同。\n",
172 "\n",
173 "*注意:路径搜索不能出现回路。*\n",
174 "\n",
175 "搜索算法从初始节点出发,不断选择后续节点,完成了搜索树的构造。一开始,搜索树中只有根节点,在每一步中搜索算法将选择与搜索树中某个节点相邻的一个后续节点加入搜索树,这个操作叫做**扩展一个节点**。\n",
176 "\n",
177 "能够扩展的节点需满足条件:\n",
178 "+ 该节点不能已经在搜索树中,即该节点不能已经被扩展过;\n",
179     "+ 该节点能够从搜索树中某个节点出发,通过执行一个动作抵达,即被扩展节点与搜索树中的某个节点相邻。 \n",
180 "\n",
181 "这些能够被扩展的节点构成的集合称为未访问节点集合。\n",
182 "\n",
183 "于是,搜索算法的每一步操作可以做如下描述: \n",
184 "每次选择未访问节点集合中的一个节点加入当前搜索树,检查这个节点所有后续相邻节点,将满足条件的节点加入未访问节点集合中,重复执行上述操作,直至被扩展的节点对应一条从初始节点到终止结点的路径。"
185 ]
186 },
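  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上述操作流程可以用下面的代码草图来模拟。注意:这只是一个演示性的简化实现,其中的邻接表 graph 与函数 generic_search 均为本例自拟,与前文的 SearchGraph 类无关;参数 pop_index 决定每次从未访问节点集合中取出哪条路径进行扩展。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 用邻接表表示上文的无向图(仅为演示自拟)\n",
    "graph = {'A': ['B', 'C'], 'B': ['A', 'D', 'E', 'F'], 'C': ['A', 'D'],\n",
    "         'D': ['B', 'C', 'E', 'G'], 'E': ['B', 'D', 'F'],\n",
    "         'F': ['B', 'E', 'G'], 'G': ['D', 'F']}\n",
    "\n",
    "def generic_search(graph, start, goal, pop_index):\n",
    "    frontier = [[start]]   # 未访问节点集合,每个元素是一条从起点出发的路径\n",
    "    visited = {start}      # 已加入搜索树的节点\n",
    "    while frontier:\n",
    "        path = frontier.pop(pop_index)   # 按既定策略选择一个节点加入搜索树\n",
    "        if path[-1] == goal:             # 目标测试\n",
    "            return path\n",
    "        for nxt in graph[path[-1]]:      # 检查所有后续相邻节点\n",
    "            if nxt not in visited:       # 只保留满足条件的节点\n",
    "                visited.add(nxt)\n",
    "                frontier.append(path + [nxt])\n",
    "    return None\n",
    "\n",
    "print(generic_search(graph, 'A', 'G', 0))    # 先进先出地取路径\n",
    "print(generic_search(graph, 'A', 'G', -1))   # 后进先出地取路径"
   ]
  },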
187 {
188 "cell_type": "markdown",
189 "metadata": {},
190 "source": [
191 "## 5.1.3 深度优先搜索和广度优先搜索"
192 ]
193 },
194 {
195 "cell_type": "markdown",
196 "metadata": {},
197 "source": [
198 "**深度优先搜索**总是沿着某个分支进行搜索、直至不能再深入为止,即优先扩展搜索树当前未访问节点集合中最深的节点。深度优先搜索算法在搜索过程中总是倾向于沿着一条分支前进,直到该分支上所有的节点都被访问完,再返回上一层进行另一轮深度优先搜索。"
199 ]
200 },
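  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "深度优先的“沿分支深入、走不通再回溯”的过程,也可以用一个简短的递归函数来体会。下面是一个演示性的简化实现,邻接表与函数 dfs 均为本例自拟,与上面的动画演示无关:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = {'A': ['B', 'C'], 'B': ['A', 'D', 'E', 'F'], 'C': ['A', 'D'],\n",
    "         'D': ['B', 'C', 'E', 'G'], 'E': ['B', 'D', 'F'],\n",
    "         'F': ['B', 'E', 'G'], 'G': ['D', 'F']}\n",
    "\n",
    "def dfs(graph, node, goal, path=None):\n",
    "    path = path or [node]\n",
    "    if node == goal:\n",
    "        return path\n",
    "    for nxt in graph[node]:\n",
    "        if nxt not in path:               # 排除回路\n",
    "            found = dfs(graph, nxt, goal, path + [nxt])\n",
    "            if found:                     # 本分支走通,直接返回\n",
    "                return found\n",
    "    return None                           # 本分支走不通,回溯\n",
    "\n",
    "print(dfs(graph, 'A', 'G'))"
   ]
  },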
201 {
202 "cell_type": "code",
203 "execution_count": null,
204 "metadata": {},
205 "outputs": [],
206 "source": [
207 "g.animation_search_tree('dfs')"
208 ]
209 },
210 {
211 "cell_type": "markdown",
212 "metadata": {},
213 "source": [
214     "**广度优先搜索**总是优先扩展未访问节点集合中最浅的节点,在执行中倾向于优先把同一层的所有可能节点访问完后再考虑进行更深的探索。"
215 ]
216 },
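  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "广度优先的逐层扩展可以借助先进先出队列来实现。下面是一个演示性的简化版本,邻接表与函数 bfs 为本例自拟,与动画演示无关:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import deque\n",
    "\n",
    "graph = {'A': ['B', 'C'], 'B': ['A', 'D', 'E', 'F'], 'C': ['A', 'D'],\n",
    "         'D': ['B', 'C', 'E', 'G'], 'E': ['B', 'D', 'F'],\n",
    "         'F': ['B', 'E', 'G'], 'G': ['D', 'F']}\n",
    "\n",
    "def bfs(graph, start, goal):\n",
    "    queue = deque([[start]])   # 先进先出队列保证先扩展较浅的节点\n",
    "    visited = {start}\n",
    "    while queue:\n",
    "        path = queue.popleft()\n",
    "        if path[-1] == goal:\n",
    "            return path\n",
    "        for nxt in graph[path[-1]]:\n",
    "            if nxt not in visited:\n",
    "                visited.add(nxt)\n",
    "                queue.append(path + [nxt])\n",
    "    return None\n",
    "\n",
    "print(bfs(graph, 'A', 'G'))"
   ]
  },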
217 {
218 "cell_type": "code",
219 "execution_count": null,
220 "metadata": {},
221 "outputs": [],
222 "source": [
223 "g.animation_search_tree('bfs')"
224 ]
225 },
226 {
227 "cell_type": "markdown",
228 "metadata": {},
229 "source": [
230 "需要强调的是,对于一个搜索问题,只要存在答案(即从初始节点到终止节点存在满足条件的一条路径),那么排除了回路的深度优先搜索和广度优先搜索均能找到一个答案,但是这个找到的答案不一定是最优的,例如距离最短。"
231 ]
232 },
233 {
234 "cell_type": "markdown",
235 "metadata": {},
236 "source": [
237 "## 5.1.4 启发式搜索"
238 ]
239 },
240 {
241 "cell_type": "markdown",
242 "metadata": {},
243 "source": [
244     "在搜索过程中利用问题定义以外的**辅助信息**的搜索算法称为**启发式搜索算法**,或者叫**有信息的搜索算法**。"
245 ]
246 },
247 {
248 "cell_type": "markdown",
249 "metadata": {},
250 "source": [
251     "在路径搜索问题中,可引入任意一个节点与目标节点之间的直线距离,作为辅助信息,来提升搜索算法的效率。根据这一想法,可以设计一个直观的最短路径搜索算法:算法从初始节点开始,每一步都将未访问节点集合中离目标节点直线距离最近的节点加入搜索树,直至到达目标节点,这个算法称为**贪婪最佳优先搜索**。\n",
252 "\n",
253 "**辅助信息:各个节点到目标节点G的直线距离**\n",
254 "\n",
255 "|站点|A|B|C|D|E|F|G|\n",
256 "|--|--|--|--|--|--|--|--|\n",
257 "|距离|30|20|19|10|5|25|0|\n",
258 "\n",
259 "贪婪最佳优先算法搜索过程如下:"
260 ]
261 },
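  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "贪婪最佳优先的选择规则可以用几行代码勾勒出来:每次从未访问节点集合中取出启发值(到 G 的直线距离)最小的路径进行扩展。以下为演示性的自拟实现,辅助信息取自上表:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "graph = {'A': ['B', 'C'], 'B': ['A', 'D', 'E', 'F'], 'C': ['A', 'D'],\n",
    "         'D': ['B', 'C', 'E', 'G'], 'E': ['B', 'D', 'F'],\n",
    "         'F': ['B', 'E', 'G'], 'G': ['D', 'F']}\n",
    "h = {'A': 30, 'B': 20, 'C': 19, 'D': 10, 'E': 5, 'F': 25, 'G': 0}\n",
    "\n",
    "def greedy_best_first(graph, start, goal, h):\n",
    "    frontier = [[start]]\n",
    "    visited = {start}\n",
    "    while frontier:\n",
    "        frontier.sort(key=lambda p: h[p[-1]])   # 启发值小的路径排在前面\n",
    "        path = frontier.pop(0)                  # 扩展直线距离最近的节点\n",
    "        if path[-1] == goal:\n",
    "            return path\n",
    "        for nxt in graph[path[-1]]:\n",
    "            if nxt not in visited:\n",
    "                visited.add(nxt)\n",
    "                frontier.append(path + [nxt])\n",
    "    return None\n",
    "\n",
    "print(greedy_best_first(graph, 'A', 'G', h))   # 与上文一致:A -> C -> D -> G"
   ]
  },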
262 {
263 "cell_type": "code",
264 "execution_count": null,
265 "metadata": {},
266 "outputs": [],
267 "source": [
268 "# 为搜索算法提供辅助信息\n",
269 "g.help_info = {'A': 30, 'B': 20, 'C': 19, 'D':10, 'E':5, 'F':25, 'G': 0}\n",
270 "# 动态演示贪婪搜索\n",
271 "g.animation_search_tree('greedy')"
272 ]
273 },
274 {
275 "cell_type": "markdown",
276 "metadata": {},
277 "source": [
278     "但是在“贪婪”机制下找到的路径 A -> C -> D -> G 并非最短路径。产生这样的搜索结果,其原因是:贪婪最佳优先算法在扩展节点时,每次均贪婪地从与当前节点相邻的节点中选择**与目标节点直线距离最近的节点**作为后续节点。这样就会造成贪婪最佳优先算法**过于重视当前的最优,而忽视了全局最优**。\n"
279 ]
280 },
281 {
282 "cell_type": "markdown",
283 "metadata": {},
284 "source": [
285 "另一种启发式搜索算法—— A\\* 算法克服了这一不足。\n",
286 "\n",
287 "其算法思路是:将初始节点到目标节点的距离分成两部分,\n",
288 "- (1)初始节点到当前节点的路径代价;\n",
289 "- (2)当前节点到目标节点之间的直线距离。将两者之和作为评价函数的取值大小。\n",
290 "\n",
291     "具体而言,对于未访问节点集合中某个节点 n,A\\* 算法的评价函数 $f(n)=g(n)+h(n)$ 由两部分构成:\n",
292     "+ 函数 g(n): 表示从初始节点到当前节点 n 的实际距离;\n",
293     "+ 函数 h(n): 表示当前节点 n 到目标节点的直线距离。函数 h(n) 也称为**启发函数**。\n",
294 "\n",
295 "\n",
296 " A\\*算法搜索过程:\n"
297 ]
298 },
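  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "按照 $f(n)=g(n)+h(n)$ 的思路,可以用优先队列写出一个简化的 A\\* 算法草图。带权邻接表与直线距离取自前文,实现细节为演示而简化,与 animation_search_tree 的内部实现无关:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import heapq\n",
    "\n",
    "# 带权邻接表与各节点到 G 的直线距离(数据取自前文)\n",
    "graph = {'A': {'B': 8, 'C': 20}, 'B': {'A': 8, 'D': 20, 'E': 30, 'F': 40},\n",
    "         'C': {'A': 20, 'D': 10}, 'D': {'B': 20, 'C': 10, 'E': 10, 'G': 10},\n",
    "         'E': {'B': 30, 'D': 10, 'F': 30}, 'F': {'B': 40, 'E': 30, 'G': 30},\n",
    "         'G': {'D': 10, 'F': 30}}\n",
    "h = {'A': 30, 'B': 20, 'C': 19, 'D': 10, 'E': 5, 'F': 25, 'G': 0}\n",
    "\n",
    "def a_star(graph, start, goal, h):\n",
    "    # 优先队列按 f(n) = g(n) + h(n) 从小到大弹出\n",
    "    frontier = [(h[start], 0, [start])]\n",
    "    best_g = {start: 0}\n",
    "    while frontier:\n",
    "        f, g, path = heapq.heappop(frontier)\n",
    "        if path[-1] == goal:\n",
    "            return path, g\n",
    "        for nxt, w in graph[path[-1]].items():\n",
    "            ng = g + w\n",
    "            if ng < best_g.get(nxt, float('inf')):\n",
    "                best_g[nxt] = ng\n",
    "                heapq.heappush(frontier, (ng + h[nxt], ng, path + [nxt]))\n",
    "    return None, float('inf')\n",
    "\n",
    "path, cost = a_star(graph, 'A', 'G', h)\n",
    "print(path, cost)   # 最短路径 A -> B -> D -> G,代价 38"
   ]
  },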
299 {
300 "cell_type": "code",
301 "execution_count": null,
302 "metadata": {},
303 "outputs": [],
304 "source": [
305 "# 为搜索算法提供辅助信息\n",
306 "g.help_info = {'A': 30, 'B': 20, 'C': 19, 'D':10, 'E':5, 'F':25, 'G': 0}\n",
307 "# 动态演示 A* 算法\n",
308 "g.animation_search_tree('a_star')"
309 ]
310 },
311 {
312 "cell_type": "code",
313 "execution_count": null,
314 "metadata": {},
315 "outputs": [],
316 "source": [
317 "# 可以调整辅助信息的比重\n",
318 "# 当只考虑额外信息时,即 origin_info_weight 设置为 0 的时候,A* 算法退化为贪婪算法。\n",
319 "g.animation_search_tree('a_star',help_info_weight=1, origin_info_weight=0)"
320 ]
321 },
322 {
323 "cell_type": "markdown",
324 "metadata": {},
325 "source": [
326     "与贪婪最佳优先算法不一定能够找到最短路径不同,A\\* 算法在启发函数不高估实际距离的前提下,找到的路径一定是最短路径;另一方面,由于 A\\* 算法能够利用辅助信息,因此它通常比不使用辅助信息的算法需要更少的搜索步骤。\n",
327 "\n",
328 "在实际中,A\\* 算法的性能表现取决于启发函数的设计,只要定义一个合适的启发函数,A\\* 算法就能够大幅缩减搜索所需的时间。"
329 ]
330 },
331 {
332 "cell_type": "markdown",
333 "metadata": {},
334 "source": [
335 "#### 思考与练习"
336 ]
337 },
338 {
339 "cell_type": "markdown",
340 "metadata": {},
341 "source": [
342 "下图是一张线路示意图。"
343 ]
344 },
345 {
346 "cell_type": "code",
347 "execution_count": null,
348 "metadata": {},
349 "outputs": [],
350 "source": [
351 "node_list = [\"0\",\"1\",\"2\",\"3\",\"4\"]\n",
352 "weighted_edges_list = [(\"0\",\"1\",10), (\"0\",\"2\",10),\n",
353 " (\"1\", \"3\", 10), \n",
354 " (\"2\", \"3\", 5), (\"2\", \"4\", 20),\n",
355 " (\"3\", \"4\", 14),(\"3\", \"2\", 5)]\n",
356 "nodes_pos = {\"0\":(1,7),\"1\":(5,1),\"2\":(5,13),\"3\":(9,7),\"4\":(11,13)}\n",
357 "h_graph = SearchGraph(node_list, weighted_edges_list, \"0\", \"4\",nodes_pos)\n",
358 "h_graph.show_graph()\n"
359 ]
360 },
361 {
362 "cell_type": "markdown",
363 "metadata": {},
364 "source": [
365 "如果使用深度优先搜索求状态 0 到状态 4 的一条路径,我们可以用下表来模拟搜索过程。注意:在下表中,结点的深度定义为它对应路径中状态转移的次数,如果多个未访问结点的深度相同,那么在这个例子里算法优先选择状态编号大的节点。\n",
366 "\n",
367 "|步骤|当前状态|当前未访问节点集合(用上划线标出了下一个扩展的节点)|\n",
368 "|:--:|:--:|:--|\n",
369 "|1|0|深度1:${0 -> 1,\\overline{0 -> 2}}$|\n",
370 "|2|2|深度1:${0 -> 1}$ 深度2:$\\underline{(1)}$|\n",
371 "|3|$\\underline{(2)}$|找到路径:0 -> 2 -> 4|\n",
372 "\n",
373     "请仔细观察上表中各项内容的含义,根据深度优先搜索的思路,在横线(1)和(2)处填写内容。请问找到的路径 0 -> 2 -> 4 是代价最小的吗?"
374 ]
375 },
376 {
377 "cell_type": "markdown",
378 "metadata": {},
379 "source": [
380 "答案 1:\n",
381 "\n",
382 "(1)$2->3,\\overline{2->4}$ \n",
383 "(2)4 \n",
384 "\n",
385 "答案 2:\n",
386 "\n",
387 "不是,代价最小为29,其路径为:0->2->3->4。 "
388 ]
389 },
390 {
391 "cell_type": "code",
392 "execution_count": null,
393 "metadata": {},
394 "outputs": [],
395 "source": [
396 "# 查看 dfs 的搜索过程\n",
397 "h_graph.animation_search_tree('dfs')"
398 ]
399 },
400 {
401 "cell_type": "code",
402 "execution_count": null,
403 "metadata": {},
404 "outputs": [],
405 "source": []
406 }
407 ],
408 "metadata": {
409 "kernelspec": {
410 "display_name": "Python 3",
411 "language": "python",
412 "name": "python3"
413 },
414 "language_info": {
415 "codemirror_mode": {
416 "name": "ipython",
417 "version": 3
418 },
419 "file_extension": ".py",
420 "mimetype": "text/x-python",
421 "name": "python",
422 "nbconvert_exporter": "python",
423 "pygments_lexer": "ipython3",
424 "version": "3.5.2"
425 }
426 },
427 "nbformat": 4,
428 "nbformat_minor": 2
429 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.2 决策树"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13     "决策树是一种通过**树形结构**进行分类的方法。在决策树中,每个内部节点表示对分类目标某个属性的一次判断,每个分支代表该判断的一个结果,而每个叶子结点代表一种分类结果。"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "## 5.2.1 决策树分类概念"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "**数据**: 游乐场经营者提供**天气情况**(如晴、雨、多云)、**温度高低**、**湿度大小**、**风力强弱**等气象特点以及游客当天是否前往游乐场。\n",
28 "\n",
29 "**目标**: 预测游客是否来游乐场游玩。\n",
30 "\n",
31 "\n",
32 "|序号|天气|温度(℃)|湿度|是否有风|是(1)否(0)前往游乐场|\n",
33 "|:--:|:--:|:--:|:--:|:--:|:--:|\n",
34 "|1|晴|29|85|否|0|\n",
35 "|2|晴|26|88|是|0|\n",
36 "|3|多云|28|78|否|1\n",
37 "|4|雨|21|96|否|1|\n",
38 "|5|雨|20|80|否|1|\n",
39 "|6|雨|18|70|是|0|\n",
40 "|7|多云|18|65|是|1|\n",
41 "|8|晴|22|90|否|0|\n",
42 "|9|晴|21|68|否|1|\n",
43 "|10|雨|24|80|否|1|\n",
44 "|11|晴|24|63|是|1|\n",
45 "|12|多云|22|90|是|1|\n",
46 "|13|多云|27|75|否|1|\n",
47 "|14|雨|21|80|是|0|\n",
48 "\n",
49 "根据上表,绘制如图所示的决策树:\n",
50 "\n",
51 "<img src=\"http://imgbed.momodel.cn/决策树11.png\" width=500px>"
52 ]
53 },
54 {
55 "cell_type": "markdown",
56 "metadata": {},
57 "source": [
58 "第一层是天气状况,具有雨、多云和晴三种属性取值。\n",
59 "+ 多云: 样本子集是 { 3, 7, 12, 13 } ,仅有“前往游乐场游玩”一个类别,即肯定去游乐场。 \n",
60 " \n",
61 " \n",
62 "+ 晴: 样本子集是 { 1, 2, 8, 9, 11 }\n",
63 " + 湿度大于 75:样本子集为 { 1, 2, 8 },不前往游乐场。\n",
64 " + 湿度不大于 75:样本子集 { 9, 11 },前往游乐场。\n",
65 " \n",
66 " \n",
67 "+ 雨:样本子集为 { 4, 5, 6, 10, 14 }\n",
68 " + 有风:样本子集 { 6, 14 },不去游乐场。\n",
69 " + 无风:样本子集 { 4, 5, 10 },前往游乐场。\n",
70 " "
71 ]
72 },
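  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上面归纳出的决策树可以直接写成几个条件判断。下面的 predict 函数是为演示自拟的,用它对照表中 14 条样本,可验证全部分类正确:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (天气, 湿度, 是否有风, 是否前往游乐场),数据取自上表\n",
    "samples = [\n",
    "    ('晴', 85, '否', 0), ('晴', 88, '是', 0), ('多云', 78, '否', 1),\n",
    "    ('雨', 96, '否', 1), ('雨', 80, '否', 1), ('雨', 70, '是', 0),\n",
    "    ('多云', 65, '是', 1), ('晴', 90, '否', 0), ('晴', 68, '否', 1),\n",
    "    ('雨', 80, '否', 1), ('晴', 63, '是', 1), ('多云', 90, '是', 1),\n",
    "    ('多云', 75, '否', 1), ('雨', 80, '是', 0)]\n",
    "\n",
    "def predict(weather, humidity, windy):\n",
    "    if weather == '多云':                   # 多云:一定前往\n",
    "        return 1\n",
    "    if weather == '晴':                     # 晴:看湿度是否大于 75\n",
    "        return 0 if humidity > 75 else 1\n",
    "    return 0 if windy == '是' else 1        # 雨:看是否有风\n",
    "\n",
    "correct = sum(predict(w, hum, wind) == y for w, hum, wind, y in samples)\n",
    "print('%d / %d 条样本分类正确' % (correct, len(samples)))"
   ]
  },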
73 {
74 "cell_type": "markdown",
75 "metadata": {},
76 "source": [
77 "由上面的例子可以看到,构建决策树的过程就是:\n",
78     "1. 选择一个属性;\n",
79     "2. 基于该属性的取值对样本集进行划分;\n",
80     "3. 重复步骤 1 和 2,直到每个划分子集中的样本都属于同一类别。"
81 ]
82 },
83 {
84 "cell_type": "code",
85 "execution_count": null,
86 "metadata": {},
87 "outputs": [],
88 "source": [
89 "import numpy as np\n",
90 "import pandas as pd\n",
91 "import matplotlib.pyplot as plt\n",
92 "%matplotlib inline\n",
93 "\n",
94 "import math\n",
95 "from math import log\n",
96 "import warnings\n",
97 "warnings.filterwarnings(\"ignore\")"
98 ]
99 },
100 {
101 "cell_type": "code",
102 "execution_count": null,
103 "metadata": {},
104 "outputs": [],
105 "source": [
106 "# 原始数据\n",
107 "datasets = [\n",
108 " ['晴',29,85,'否','0'],\n",
109 " ['晴',26,88,'是','0'],\n",
110 " ['多云',28,78,'否','1'],\n",
111 " ['雨',21,96,'否','1'],\n",
112 " ['雨',20,80,'否','1'],\n",
113 " ['雨',18,70,'是','0'],\n",
114 " ['多云',18,65,'是','1'],\n",
115 " ['晴',22,90,'否','0'],\n",
116 " ['晴',21,68,'否','1'],\n",
117 " ['雨',24,80,'否','1'],\n",
118 " ['晴',24,63,'是','1'],\n",
119 " ['多云',22,90,'是','1'],\n",
120 " ['多云',27,75,'否','1'],\n",
121 " ['雨',21,80,'是','0']\n",
122 "]\n",
123 "\n",
124 "# 数据的列名\n",
125 "labels = ['天气','温度','湿度','是否有风','是否前往游乐场']\n",
126 "\n",
127 "# 将湿度大小分为大于 75 和小于等于 75 这两个属性值,\n",
128 "# 将温度大小分为大于 26 和小于等于 26 这两个属性值\n",
129 "for i in range(len(datasets)):\n",
130 " if datasets[i][2] > 75:\n",
131 " datasets[i][2] = '>75'\n",
132 " else:\n",
133 " datasets[i][2] = '<=75'\n",
134 " if datasets[i][1] > 26:\n",
135 " datasets[i][1] = '>26'\n",
136 " else:\n",
137 " datasets[i][1] = '<=26'\n",
138 "\n",
139 "# 构建 dataframe 并查看数据\n",
140 "df = pd.DataFrame(datasets, columns=labels)\n",
141 "df\n"
142 ]
143 },
144 {
145 "cell_type": "markdown",
146 "metadata": {},
147 "source": [
148 "## 5.2.2 构建决策树 \n",
149 "\n",
150 "**信息增益**用来衡量样本集合复杂度(不确定性)所减少的程度。 \n",
151 "\n",
152 "**信息熵**用来度量信息量的大小。从信息论的角度来看,对信息的度量等于计算信息不确定性的多少。 "
153 ]
154 },
155 {
156 "cell_type": "markdown",
157 "metadata": {},
158 "source": [
159     "假设有 $K$ 个信息,其组成了样本集合 $D$ ,记第 $k$ 个信息发生的概率为 $p_k(1≤k≤K)$。 \n",
160 "这 $K$ 个信息的信息熵: \n",
161 "$$E(D)=-\\sum_{k=1}^{K}p_k log_{2} p_k$$\n",
162 "\n",
163 "需要指出:**所有 $p_k$ 累加起来的和为1**。"
164 ]
165 },
166 {
167 "cell_type": "code",
168 "execution_count": null,
169 "metadata": {},
170 "outputs": [],
171 "source": [
172 "def calc_entropy(total_num, count_dict):\n",
173 " \"\"\"\n",
174 " 计算信息熵\n",
175 " :param total_num: 总样本数\n",
176 " :param count_dict: 每类样本及其对应数目的字典\n",
177 " :return: 信息熵\n",
178 " \"\"\"\n",
179 " # 使用信息熵公式计算\n",
180 " ent = -sum([(p / total_num) * log(p / total_num, 2) for p in count_dict.values() if p != 0])\n",
181     "    # 熵为 -0.0 时将其规范为 0,避免 print 显示异常\n",
182 " if ent == 0:\n",
183 " ent = 0\n",
184 " # 返回信息熵,精确到小数点后 4 位\n",
185 " return round(ent, 4)"
186 ]
187 },
188 {
189 "cell_type": "markdown",
190 "metadata": {},
191 "source": [
192 "\n",
193 "\n",
194 "现在用**熵**来构建决策树。数据中 14 个样本分为 “游客来游乐场( 9 个样本)” 和 “游客不来游乐场( 5 个样本)” 两个类别,即 K = 2。\n",
195 "\n",
196     "记 “游客来游乐场” 和 “游客不来游乐场” 的概率分别为 $p_1$ 和 $p_2$ ,显然 $p_1=\\frac{9}{14}$,$p_2=\\frac{5}{14}$,则这 14 个样本所蕴含的信息熵:\n",
197 "\n",
198 "$$E(D)=-\\sum_{k=1}^{2}p_{k}log_{2}{p_k}=-(\\frac{9}{14}×log_{2}{\\frac{9}{14}}+\\frac{5}{14}×log_{2}{\\frac{5}{14}})=0.940$$"
199 ]
200 },
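  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上式的数值可以直接用几行代码核对(仅为核对计算,math.log2 为标准库函数):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import log2\n",
    "\n",
    "p1, p2 = 9 / 14, 5 / 14\n",
    "E = -(p1 * log2(p1) + p2 * log2(p2))\n",
    "print(round(E, 3))    # 约 0.940"
   ]
  },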
201 {
202 "cell_type": "markdown",
203 "metadata": {},
204 "source": [
205     "我们可以用下面这种方式对 dataframe 的数据按条件进行筛选。"
206 ]
207 },
208 {
209 "cell_type": "code",
210 "execution_count": null,
211 "metadata": {},
212 "outputs": [],
213 "source": [
214 "# 例如:按 是否前往游乐场==0 进行筛选\n",
215 "df[df['是否前往游乐场']=='0']"
216 ]
217 },
218 {
219 "cell_type": "markdown",
220 "metadata": {},
221 "source": [
222 "使用上面的方法,可以得到计算信息熵所需的总样本数,以及每类样本及其对应数目的字典,然后计算信息熵。"
223 ]
224 },
225 {
226 "cell_type": "code",
227 "execution_count": null,
228 "metadata": {},
229 "outputs": [],
230 "source": [
231 "# 总样本数\n",
232 "total_num = df.shape[0]\n",
233 "# 每类样本及其对应数目的字典\n",
234     "count_dict = {'前往':df[df['是否前往游乐场']=='1'].shape[0], '不前往':df[df['是否前往游乐场']=='0'].shape[0]}\n",
235 "# 计算信息熵\n",
236 "entropy = calc_entropy(total_num, count_dict)\n",
237 "entropy"
238 ]
239 },
240 {
241 "cell_type": "markdown",
242 "metadata": {},
243 "source": [
244 "**计算天气状况所对应的信息熵**: \n",
245 "天气状况的三个属性记为 $a_0=“晴”$ ,$a_1=“多云”$ ,$a_2=“雨”$ , \n",
246 "属性取值为 $a_i$ 对应分支节点所包含子样本集记为 $D_i$ ,该子样本集包含样本数量记为 $|D_i|$ 。"
247 ]
248 },
249 {
250 "cell_type": "markdown",
251 "metadata": {},
252 "source": [
253 "|天气属性取值$a_i$|“晴”|“多云”|“雨”|\n",
254 "|:--:|:--:|:--:|:--:|\n",
255 "|对应样本数$|D_i|$|5|4|5|\n",
256 "|正负样本数量|(2+,3-)|(4+,0-)|(3+,2-)|\n",
257 "\n",
258 "计算天气状况每个属性值的信息熵:\n",
259 "\n",
260 "$“晴”:E(D_0)=-(\\frac{2}{5}×log_{2}{\\frac{2}{5}}+\\frac{3}{5}×log_{2}{\\frac{3}{5}})=0.971$\n",
261 "\n",
262 "$“多云”:E(D_1)=-(\\frac{4}{4}×log_{2}{\\frac{4}{4}})=0$\n",
263 "\n",
264 "$“雨”:E(D_2)=-(\\frac{3}{5}×log_{2}{\\frac{3}{5}}+\\frac{2}{5}×log_{2}{\\frac{2}{5}})=0.971$"
265 ]
266 },
267 {
268 "cell_type": "markdown",
269 "metadata": {},
270 "source": [
271 "我们可以使用下面的写法,对 dataframe 进行多个条件的筛选。"
272 ]
273 },
274 {
275 "cell_type": "code",
276 "execution_count": null,
277 "metadata": {},
278 "outputs": [],
279 "source": [
280 "# 筛选出 天气为晴并且去游乐场的样本数据\n",
281 "df[(df['天气']=='晴') & (df['是否前往游乐场']=='1')]"
282 ]
283 },
284 {
285 "cell_type": "code",
286 "execution_count": null,
287 "metadata": {},
288 "outputs": [],
289 "source": [
290 "# 天气为晴的总天数\n",
291 "total_num_sun = df[df['天气']=='晴'].shape[0]\n",
292 "# 天气为晴时,去游乐场和不去游乐场的人数\n",
293 "count_dict_sun = {'前往':df[(df['天气']=='晴') & (df['是否前往游乐场']=='1')].shape[0], \n",
294 " '不前往':df[(df['天气']=='晴') & (df['是否前往游乐场']=='0')].shape[0]}\n",
295 "print(count_dict_sun)\n",
296 "# 计算天气-晴 的信息熵\n",
297 "ent_sun = calc_entropy(total_num_sun, count_dict_sun)\n",
298 "print('天气-晴 的信息熵为:%s' % ent_sun)\n"
299 ]
300 },
301 {
302 "cell_type": "code",
303 "execution_count": null,
304 "metadata": {},
305 "outputs": [],
306 "source": [
307 "# 天气为多云的总天数\n",
308 "total_num_cloud = df[df['天气']=='多云'].shape[0]\n",
309 "# 天气为多云时,去游乐场和不去游乐场的人数\n",
310 "count_dict_cloud = {'前往':df[(df['天气']=='多云') & (df['是否前往游乐场']=='1')].shape[0], \n",
311 " '不前往':df[(df['天气']=='多云') & (df['是否前往游乐场']=='0')].shape[0]}\n",
312 "print(count_dict_cloud)\n",
313 "# 计算天气-多云 的信息熵\n",
314 "ent_cloud = calc_entropy(total_num_cloud, count_dict_cloud)\n",
315 "print('天气-多云 的信息熵为:%s' % ent_cloud)"
316 ]
317 },
318 {
319 "cell_type": "code",
320 "execution_count": null,
321 "metadata": {},
322 "outputs": [],
323 "source": [
324 "# 天气为雨的总天数\n",
325 "total_num_rain = df[df['天气']=='雨'].shape[0]\n",
326 "# 天气为雨时,去游乐场和不去游乐场的人数\n",
327 "count_dict_rain = {'前往':df[(df['天气']=='雨') & (df['是否前往游乐场']=='1')].shape[0], \n",
328 " '不前往':df[(df['天气']=='雨') & (df['是否前往游乐场']=='0')].shape[0]}\n",
329 "print(count_dict_rain)\n",
330 "# 计算天气-雨 的信息熵\n",
331 "ent_rain = calc_entropy(total_num_rain, count_dict_rain)\n",
332 "print('天气-雨 的信息熵为:%s' % ent_rain)"
333 ]
334 },
335 {
336 "cell_type": "markdown",
337 "metadata": {},
338 "source": [
339 "计算天气状况的信息增益: \n",
340     "$$Gain(D,A)=E(D)-\\sum_{i=1}^{n}\\frac{|D_i|}{|D|}E(D_i)$$"
341 ]
342 },
343 {
344 "cell_type": "markdown",
345 "metadata": {},
346 "source": [
347 "其中,$A=“天气状况”$。于是天气状况这一气象特点的信息增益为:\n",
348 "$$Gain(D,天气)=0.940-(\\frac{5}{14}×0.971+\\frac{4}{14}×0+\\frac{5}{14}×0.971)=0.246$$\n",
349 "\n",
350 "同理可以计算温度高低、湿度大小、风力强弱三个气象特点的信息增益。 \n",
351 "通常情况下,某个分支的信息增益越大,则该分支对样本集划分所获得的“纯度”越大,信息不确定性减少的程度越大。"
352 ]
353 },
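  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "天气状况的信息增益同样可以直接核对。上文因中间结果保留三位小数得到 0.246,不对中间结果取舍时约为 0.247;下面的 ent 是为核对自拟的辅助函数:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import log2\n",
    "\n",
    "def ent(pos, neg):\n",
    "    # 由正负样本数计算信息熵\n",
    "    e = 0.0\n",
    "    for c in (pos, neg):\n",
    "        if c:\n",
    "            p = c / (pos + neg)\n",
    "            e -= p * log2(p)\n",
    "    return e\n",
    "\n",
    "E_D = ent(9, 5)                         # 整体信息熵,约 0.940\n",
    "gain = E_D - (5/14 * ent(2, 3) + 4/14 * ent(4, 0) + 5/14 * ent(3, 2))\n",
    "print(round(gain, 3))"
   ]
  },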
365 {
366 "cell_type": "markdown",
367 "metadata": {},
368 "source": [
369 "使用上面的公式计算信息增益。"
370 ]
371 },
372 {
373 "cell_type": "code",
374 "execution_count": null,
375 "metadata": {},
376 "outputs": [],
377 "source": [
378 "# 信息增益\n",
379 "gain = entropy - (total_num_sun/total_num*ent_sun + \n",
380 " total_num_cloud/total_num*ent_cloud + \n",
381 " total_num_rain/total_num*ent_rain)\n",
382 "gain"
383 ]
384 },
385 {
386 "cell_type": "markdown",
387 "metadata": {},
388 "source": [
389 "### 思考与练习 "
390 ]
391 },
392 {
393 "cell_type": "markdown",
394 "metadata": {},
395 "source": [
396 "1. 分别将天气状况、温度高低、湿度大小、风力强弱作为分支点来构建决策树,查看信息增益\n"
397 ]
398 },
399 {
400 "cell_type": "code",
401 "execution_count": 4,
402 "metadata": {},
403 "outputs": [],
404 "source": [
405 "# todo 以温度 26 为温度高低的分界线,计算信息增益\n"
406 ]
407 },
408 {
409 "cell_type": "code",
410 "execution_count": 5,
411 "metadata": {},
412 "outputs": [],
413 "source": [
414 "# todo 以湿度 75 为湿度大小的分界线,计算信息增益\n"
415 ]
416 },
417 {
418 "cell_type": "code",
419 "execution_count": 6,
420 "metadata": {},
421 "outputs": [],
422 "source": [
423 "# todo 以有风无风划分,计算信息增益\n"
424 ]
425 },
426 {
427 "cell_type": "markdown",
428 "metadata": {},
429 "source": [
430     "2. 每朵鸢尾花有萼片长度、萼片宽度、花瓣长度、花瓣宽度四个特征。现在需要根据这四个特征将鸢尾花分为杂色鸢尾、维吉尼亚鸢尾和山鸢尾三类,试构造决策树进行分类。\n",
431 "\n",
432 "|序号|萼片长度|萼片宽度|花瓣长度|花瓣宽度|种类|\n",
433 "|:--:|:--:|:--:|:--:|:--:|:--:|\n",
434 "|1|5.0|2.0|3.5|1.0|杂色鸢尾|\n",
435 "|2|6.0|2.2|5.0|1.5|维吉尼亚鸢尾|\n",
436 "|3|6.0|2.2|4.0|1.0|杂色鸢尾|\n",
437 "|4|6.2|2.2|4.5|1.5|杂色鸢尾|\n",
438 "|5|4.5|2.3|1.3|0.3|山鸢尾|"
439 ]
440 },
441 {
442 "cell_type": "markdown",
443 "metadata": {},
444 "source": [
445     "观察上表中的五笔数据,我们可以看到 杂色鸢尾 和 维吉尼亚鸢尾 的花瓣宽度明显大于山鸢尾,所以可以通过判断花瓣宽度是否大于 0.7,来将山鸢尾从其他两种鸢尾中区分出来。\n",
446 "\n",
447     "同时,杂色鸢尾 和 维吉尼亚鸢尾 的花瓣长度明显大于山鸢尾,所以也可以通过判断花瓣长度是否大于 2.4,来将山鸢尾从其他两种鸢尾中区分出来。\n",
448 "\n",
449 "然后我们观察到 维吉尼亚鸢尾的花瓣长度明显大于杂色鸢尾,所以可以通过判断花瓣长度是否大于 4.75,来将杂色鸢尾和维吉尼亚鸢尾区分出来。"
450 ]
451 },
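  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以先用表中这五条数据检验上面归纳的阈值规则。函数 classify 是为演示自拟的,阈值即上文给出的 0.7 与 4.75:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (萼片长度, 萼片宽度, 花瓣长度, 花瓣宽度, 种类),数据取自上表\n",
    "samples = [\n",
    "    (5.0, 2.0, 3.5, 1.0, '杂色鸢尾'),\n",
    "    (6.0, 2.2, 5.0, 1.5, '维吉尼亚鸢尾'),\n",
    "    (6.0, 2.2, 4.0, 1.0, '杂色鸢尾'),\n",
    "    (6.2, 2.2, 4.5, 1.5, '杂色鸢尾'),\n",
    "    (4.5, 2.3, 1.3, 0.3, '山鸢尾')]\n",
    "\n",
    "def classify(sl, sw, pl, pw):\n",
    "    # 上文归纳的两条阈值规则\n",
    "    if pw <= 0.7:\n",
    "        return '山鸢尾'\n",
    "    return '维吉尼亚鸢尾' if pl > 4.75 else '杂色鸢尾'\n",
    "\n",
    "correct = sum(classify(*s[:4]) == s[4] for s in samples)\n",
    "print('%d / %d 条样本分类正确' % (correct, len(samples)))"
   ]
  },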
452 {
453 "cell_type": "markdown",
454 "metadata": {},
455 "source": [
456 "实际上是否如此呢?"
457 ]
458 },
459 {
460 "cell_type": "markdown",
461 "metadata": {},
462 "source": [
463     "上面的表格只是 Iris 数据集的一小部分,完整的数据集包含 150 个数据样本,分为 3 类,每类 50 个数据,每个数据包含花萼长度、花萼宽度、花瓣长度、花瓣宽度 4 个属性。\n",
464 "\n",
465 "我们使用 sklearn 工具包来构建决策树模型,先导入数据集。"
466 ]
467 },
468 {
469 "cell_type": "code",
470 "execution_count": null,
471 "metadata": {},
472 "outputs": [],
473 "source": [
474 "from sklearn.datasets import load_iris\n",
475 "# 加载数据集\n",
476 "iris = load_iris()\n",
477 "# 查看 label\n",
478 "print(list(iris.target_names))\n",
479 "# 查看 feature\n",
480 "print(iris.feature_names)"
481 ]
482 },
483 {
484 "cell_type": "markdown",
485 "metadata": {},
486 "source": [
487 "setosa 是山鸢尾,versicolor是杂色鸢尾,virginica是维吉尼亚鸢尾。\n",
488 "\n",
489 "sepal length, sepal width,petal length,petal width 分别是萼片长度,萼片宽度,花瓣长度,花瓣宽度。"
490 ]
491 },
492 {
493 "cell_type": "markdown",
494 "metadata": {},
495 "source": [
496 "然后进行训练集和测试集的切分。"
497 ]
498 },
499 {
500 "cell_type": "code",
501 "execution_count": null,
502 "metadata": {},
503 "outputs": [],
504 "source": [
505 "from sklearn.model_selection import train_test_split\n",
506 "# 载入数据\n",
507 "X, y = load_iris(return_X_y=True)\n",
508 "# 切分训练集合测试集\n",
509 "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
510 ]
511 },
512 {
513 "cell_type": "markdown",
514 "metadata": {},
515 "source": [
516 "接下来,我们在训练集数据上训练决策树模型。"
517 ]
518 },
519 {
520 "cell_type": "code",
521 "execution_count": null,
522 "metadata": {},
523 "outputs": [],
524 "source": [
525 "from sklearn import tree\n",
526 "from sklearn.tree import DecisionTreeClassifier\n",
527 "# 初始化模型,可以调整 max_depth 来观察模型的表现\n",
528 "clf = tree.DecisionTreeClassifier(random_state=42, max_depth=2)\n",
529 "# 训练模型\n",
530 "clf = clf.fit(X_train, y_train)"
531 ]
532 },
533 {
534 "cell_type": "markdown",
535 "metadata": {},
536 "source": [
537 "我们可以使用 graphviz 包来展示构建好的决策树。"
538 ]
539 },
540 {
541 "cell_type": "code",
542 "execution_count": null,
543 "metadata": {},
544 "outputs": [],
545 "source": [
546     "!pip install graphviz"
547 ]
548 },
549 {
550 "cell_type": "code",
551 "execution_count": null,
552 "metadata": {},
553 "outputs": [],
554 "source": [
555 "import graphviz\n",
556 "feature_names = ['萼片长度','萼片宽度','花瓣长度','花瓣宽度']\n",
557 "target_names = ['山鸢尾', '杂色鸢尾', '维吉尼亚鸢尾']\n",
558 "# 可视化生成的决策树\n",
559 "dot_data = tree.export_graphviz(clf, out_file=None, \n",
560 " feature_names=feature_names, \n",
561 " class_names=target_names, \n",
562 " filled=True, rounded=True, \n",
563 " special_characters=True) \n",
564 "graph = graphviz.Source(dot_data) \n",
565 "graph "
566 ]
567 },
568 {
569 "cell_type": "markdown",
570 "metadata": {},
571 "source": [
572     "我们看看模型在测试集上的表现。"
573 ]
574 },
575 {
576 "cell_type": "code",
577 "execution_count": null,
578 "metadata": {},
579 "outputs": [],
580 "source": [
581 "from sklearn.metrics import accuracy_score\n",
582 "y_test_predict = clf.predict(X_test)\n",
583 "accuracy_score(y_test,y_test_predict)"
584 ]
585 },
586 {
587 "cell_type": "markdown",
588 "metadata": {},
589 "source": [
590 "### 实践与体验\n",
591 "\n",
592 "**计算文章的信息熵**\n",
593 "\n",
594     "收集中英文对照的短文,在分别计算中文短文内各单词和英文短文内各单词出现概率的基础上,计算这两篇短文的信息熵,并比较两者的大小。"
595 ]
596 },
597 {
598 "cell_type": "markdown",
599 "metadata": {},
600 "source": [
601 "首先定义一个方法来辅助读取文件的内容。"
602 ]
603 },
604 {
605 "cell_type": "code",
606 "execution_count": null,
607 "metadata": {},
608 "outputs": [],
609 "source": [
610 "def read_file(path):\n",
611 " \"\"\"\n",
612 " 读取某文件的内容\n",
613 " :param path: 文件的路径\n",
614 " :return: 文件的内容\n",
615 " \"\"\"\n",
616 " contents = \"\"\n",
617 " with open(path) as f:\n",
618 " # 读取每一行的内容\n",
619 " for line in f.readlines():\n",
620 " contents += line\n",
621 " return contents"
622 ]
623 },
624 {
625 "cell_type": "markdown",
626 "metadata": {},
627 "source": [
628 "使用上面定义的方法读取英文短文及其对应的中文短文。"
629 ]
630 },
631 {
632 "cell_type": "code",
633 "execution_count": null,
634 "metadata": {},
635 "outputs": [],
636 "source": [
637 "# 读取英文短文\n",
638 "en_essay = read_file('essay3_en.txt')\n",
639 "# 读取中文短文\n",
640 "ch_essay = read_file('essay3_ch.txt')\n"
641 ]
642 },
643 {
644 "cell_type": "markdown",
645 "metadata": {},
646 "source": [
647 "处理文本,统计单词出现的概率,并计算信息熵。"
648 ]
649 },
650 {
651 "cell_type": "code",
652 "execution_count": null,
653 "metadata": {},
654 "outputs": [],
655 "source": [
656 "from collections import Counter\n",
657 "import re\n",
658 "\n",
659 "\n",
660 "def cal_essay_entropy(essay, split_by=None):\n",
661 " \"\"\"\n",
662 " 计算文章的信息熵\n",
663 " :param essay: 文章内容\n",
664 " :param split_by: 切分方式,对于中文文章,不需传入,按字符切分,\n",
665 " 对于英文文章,需传入空格字符来进行切分\n",
666 " :return: 文章的信息熵\n",
667 " \"\"\"\n",
668 " # 把英文全部转为小写\n",
669 " essay = essay.lower()\n",
670 " # 去除标点符号\n",
671 " essay = re.sub(\n",
672 " \"[\\f+\\n+\\r+\\t+\\v+\\?\\.\\!\\/_,$%^*(+\\\"\\']+|[+——!,。?、~@#《》¥%……&*()]\", \"\",\n",
673 " essay)\n",
674 " # print(essay)\n",
675 " # 把文本分割为词\n",
676 " if split_by:\n",
677 " word_list = essay.split(split_by)\n",
678 " else:\n",
679 " word_list = list(essay)\n",
680 " # 统计总的单词数\n",
681 " word_number = len(word_list)\n",
682 " print('此文章共有 %s 个单词' % word_number)\n",
683 " # 得到每个单词出现的次数\n",
684 " word_counter = Counter(word_list)\n",
685 " # print('每个单词出现的次数为:%s' % word_counter)\n",
686 " # 使用信息熵公式计算信息熵\n",
687 " ent = -sum([(p / word_number) * log(p / word_number, 2) for p in\n",
688 " word_counter.values()])\n",
689 " print('信息熵为:%.2f' % ent)\n",
690 " return ent"
691 ]
692 },
693 {
694 "cell_type": "code",
695 "execution_count": null,
696 "metadata": {},
697 "outputs": [],
698 "source": [
699 "ent = cal_essay_entropy(ch_essay)"
700 ]
701 },
702 {
703 "cell_type": "code",
704 "execution_count": null,
705 "metadata": {},
706 "outputs": [],
707 "source": [
708 "ent = cal_essay_entropy(en_essay, split_by = ' ')"
709 ]
710 },
711 {
712 "cell_type": "markdown",
713 "metadata": {},
714 "source": [
715     "todo: 你在上面的实验中观察到了什么?请在下面写下你观察到的现象,并尝试分析其原因。"
716 ]
717 },
718 {
719 "cell_type": "markdown",
720 "metadata": {},
721 "source": [
722 "答案:(在此处填写你的答案。)"
723 ]
724 }
725 ],
726 "metadata": {
727 "kernelspec": {
728 "display_name": "Python 3",
729 "language": "python",
730 "name": "python3"
731 },
732 "language_info": {
733 "codemirror_mode": {
734 "name": "ipython",
735 "version": 3
736 },
737 "file_extension": ".py",
738 "mimetype": "text/x-python",
739 "name": "python",
740 "nbconvert_exporter": "python",
741 "pygments_lexer": "ipython3",
742 "version": "3.5.2"
743 }
744 },
745 "nbformat": 4,
746 "nbformat_minor": 2
747 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.2 决策树"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13     "决策树是一种通过**树形结构**进行分类的方法。在决策树中,每个内部节点表示对分类目标某个属性的一次判断,每个分支代表该判断的一个结果,而每个叶子结点代表一种分类结果。"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "## 5.2.1 决策树分类概念"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "**数据**: 游乐场经营者提供**天气情况**(如晴、雨、多云)、**温度高低**、**湿度大小**、**风力强弱**等气象特点以及游客当天是否前往游乐场。\n",
28 "\n",
29 "**目标**: 预测游客是否来游乐场游玩。\n",
30 "\n",
31 "\n",
32 "|序号|天气|温度(℃)|湿度|是否有风|是(1)否(0)前往游乐场|\n",
33 "|:--:|:--:|:--:|:--:|:--:|:--:|\n",
34 "|1|晴|29|85|否|0|\n",
35 "|2|晴|26|88|是|0|\n",
36 "|3|多云|28|78|否|1\n",
37 "|4|雨|21|96|否|1|\n",
38 "|5|雨|20|80|否|1|\n",
39 "|6|雨|18|70|是|0|\n",
40 "|7|多云|18|65|是|1|\n",
41 "|8|晴|22|90|否|0|\n",
42 "|9|晴|21|68|否|1|\n",
43 "|10|雨|24|80|否|1|\n",
44 "|11|晴|24|63|是|1|\n",
45 "|12|多云|22|90|是|1|\n",
46 "|13|多云|27|75|否|1|\n",
47 "|14|雨|21|80|是|0|\n",
48 "\n",
49 "根据上表,绘制如图所示的决策树:\n",
50 "\n",
51 "<img src=\"http://imgbed.momodel.cn/决策树11.png\" width=500px>"
52 ]
53 },
54 {
55 "cell_type": "markdown",
56 "metadata": {},
57 "source": [
58 "第一层是天气状况,具有雨、多云和晴三种属性取值。\n",
59 "+ 多云: 样本子集是 { 3, 7, 12, 13 } ,仅有“前往游乐场游玩”一个类别,即肯定去游乐场。 \n",
60 " \n",
61 " \n",
62 "+ 晴: 样本子集是 { 1, 2, 8, 9, 11 }\n",
63 " + 湿度大于 75:样本子集为 { 1, 2, 8 },不前往游乐场。\n",
64 " + 湿度不大于 75:样本子集 { 9, 11 },前往游乐场。\n",
65 " \n",
66 " \n",
67 "+ 雨:样本子集为 { 4, 5, 6, 10, 14 }\n",
68 " + 有风:样本子集 { 6, 14 },不去游乐场。\n",
69 " + 无风:样本子集 { 4, 5, 10 },前往游乐场。\n",
70 " "
71 ]
72 },
73 {
74 "cell_type": "markdown",
75 "metadata": {},
76 "source": [
77 "由上面的例子可以看到,构建决策树的过程就是:\n",
78     "1. 选择一个属性;\n",
79     "2. 基于该属性的取值对样本集进行划分;\n",
80     "3. 重复步骤 1 和 2,直到每个划分子集中的样本都属于同一类别。"
81 ]
82 },
83 {
84 "cell_type": "code",
85 "execution_count": 1,
86 "metadata": {},
87 "outputs": [],
88 "source": [
89 "import numpy as np\n",
90 "import pandas as pd\n",
91 "import matplotlib.pyplot as plt\n",
92 "%matplotlib inline\n",
93 "\n",
94 "import math\n",
95 "from math import log\n",
96 "import warnings\n",
97 "warnings.filterwarnings(\"ignore\")"
98 ]
99 },
100 {
101 "cell_type": "code",
102 "execution_count": 2,
103 "metadata": {},
104 "outputs": [
105 {
106 "data": {
107 "text/html": [
108 "<div>\n",
109 "<style scoped>\n",
110 " .dataframe tbody tr th:only-of-type {\n",
111 " vertical-align: middle;\n",
112 " }\n",
113 "\n",
114 " .dataframe tbody tr th {\n",
115 " vertical-align: top;\n",
116 " }\n",
117 "\n",
118 " .dataframe thead th {\n",
119 " text-align: right;\n",
120 " }\n",
121 "</style>\n",
122 "<table border=\"1\" class=\"dataframe\">\n",
123 " <thead>\n",
124 " <tr style=\"text-align: right;\">\n",
125 " <th></th>\n",
126 " <th>天气</th>\n",
127 " <th>温度</th>\n",
128 " <th>湿度</th>\n",
129 " <th>是否有风</th>\n",
130 " <th>是否前往游乐场</th>\n",
131 " </tr>\n",
132 " </thead>\n",
133 " <tbody>\n",
134 " <tr>\n",
135 " <th>0</th>\n",
136 " <td>晴</td>\n",
137 " <td>&gt;26</td>\n",
138 " <td>&gt;75</td>\n",
139 " <td>否</td>\n",
140 " <td>0</td>\n",
141 " </tr>\n",
142 " <tr>\n",
143 " <th>1</th>\n",
144 " <td>晴</td>\n",
145 " <td>&lt;=26</td>\n",
146 " <td>&gt;75</td>\n",
147 " <td>是</td>\n",
148 " <td>0</td>\n",
149 " </tr>\n",
150 " <tr>\n",
151 " <th>2</th>\n",
152 " <td>多云</td>\n",
153 " <td>&gt;26</td>\n",
154 " <td>&gt;75</td>\n",
155 " <td>否</td>\n",
156 " <td>1</td>\n",
157 " </tr>\n",
158 " <tr>\n",
159 " <th>3</th>\n",
160 " <td>雨</td>\n",
161 " <td>&lt;=26</td>\n",
162 " <td>&gt;75</td>\n",
163 " <td>否</td>\n",
164 " <td>1</td>\n",
165 " </tr>\n",
166 " <tr>\n",
167 " <th>4</th>\n",
168 " <td>雨</td>\n",
169 " <td>&lt;=26</td>\n",
170 " <td>&gt;75</td>\n",
171 " <td>否</td>\n",
172 " <td>1</td>\n",
173 " </tr>\n",
174 " <tr>\n",
175 " <th>5</th>\n",
176 " <td>雨</td>\n",
177 " <td>&lt;=26</td>\n",
178 " <td>&lt;=75</td>\n",
179 " <td>是</td>\n",
180 " <td>0</td>\n",
181 " </tr>\n",
182 " <tr>\n",
183 " <th>6</th>\n",
184 " <td>多云</td>\n",
185 " <td>&lt;=26</td>\n",
186 " <td>&lt;=75</td>\n",
187 " <td>是</td>\n",
188 " <td>1</td>\n",
189 " </tr>\n",
190 " <tr>\n",
191 " <th>7</th>\n",
192 " <td>晴</td>\n",
193 " <td>&lt;=26</td>\n",
194 " <td>&gt;75</td>\n",
195 " <td>否</td>\n",
196 " <td>0</td>\n",
197 " </tr>\n",
198 " <tr>\n",
199 " <th>8</th>\n",
200 " <td>晴</td>\n",
201 " <td>&lt;=26</td>\n",
202 " <td>&lt;=75</td>\n",
203 " <td>否</td>\n",
204 " <td>1</td>\n",
205 " </tr>\n",
206 " <tr>\n",
207 " <th>9</th>\n",
208 " <td>雨</td>\n",
209 " <td>&lt;=26</td>\n",
210 " <td>&gt;75</td>\n",
211 " <td>否</td>\n",
212 " <td>1</td>\n",
213 " </tr>\n",
214 " <tr>\n",
215 " <th>10</th>\n",
216 " <td>晴</td>\n",
217 " <td>&lt;=26</td>\n",
218 " <td>&lt;=75</td>\n",
219 " <td>是</td>\n",
220 " <td>1</td>\n",
221 " </tr>\n",
222 " <tr>\n",
223 " <th>11</th>\n",
224 " <td>多云</td>\n",
225 " <td>&lt;=26</td>\n",
226 " <td>&gt;75</td>\n",
227 " <td>是</td>\n",
228 " <td>1</td>\n",
229 " </tr>\n",
230 " <tr>\n",
231 " <th>12</th>\n",
232 " <td>多云</td>\n",
233 " <td>&gt;26</td>\n",
234 " <td>&lt;=75</td>\n",
235 " <td>否</td>\n",
236 " <td>1</td>\n",
237 " </tr>\n",
238 " <tr>\n",
239 " <th>13</th>\n",
240 " <td>雨</td>\n",
241 " <td>&lt;=26</td>\n",
242 " <td>&gt;75</td>\n",
243 " <td>是</td>\n",
244 " <td>0</td>\n",
245 " </tr>\n",
246 " </tbody>\n",
247 "</table>\n",
248 "</div>"
249 ],
250 "text/plain": [
251 " 天气 温度 湿度 是否有风 是否前往游乐场\n",
252 "0 晴 >26 >75 否 0\n",
253 "1 晴 <=26 >75 是 0\n",
254 "2 多云 >26 >75 否 1\n",
255 "3 雨 <=26 >75 否 1\n",
256 "4 雨 <=26 >75 否 1\n",
257 "5 雨 <=26 <=75 是 0\n",
258 "6 多云 <=26 <=75 是 1\n",
259 "7 晴 <=26 >75 否 0\n",
260 "8 晴 <=26 <=75 否 1\n",
261 "9 雨 <=26 >75 否 1\n",
262 "10 晴 <=26 <=75 是 1\n",
263 "11 多云 <=26 >75 是 1\n",
264 "12 多云 >26 <=75 否 1\n",
265 "13 雨 <=26 >75 是 0"
266 ]
267 },
268 "execution_count": 2,
269 "metadata": {},
270 "output_type": "execute_result"
271 }
272 ],
273 "source": [
274 "# 原始数据\n",
275 "datasets = [\n",
276 " ['晴',29,85,'否','0'],\n",
277 " ['晴',26,88,'是','0'],\n",
278 " ['多云',28,78,'否','1'],\n",
279 " ['雨',21,96,'否','1'],\n",
280 " ['雨',20,80,'否','1'],\n",
281 " ['雨',18,70,'是','0'],\n",
282 " ['多云',18,65,'是','1'],\n",
283 " ['晴',22,90,'否','0'],\n",
284 " ['晴',21,68,'否','1'],\n",
285 " ['雨',24,80,'否','1'],\n",
286 " ['晴',24,63,'是','1'],\n",
287 " ['多云',22,90,'是','1'],\n",
288 " ['多云',27,75,'否','1'],\n",
289 " ['雨',21,80,'是','0']\n",
290 "]\n",
291 "\n",
292 "# 数据的列名\n",
293 "labels = ['天气','温度','湿度','是否有风','是否前往游乐场']\n",
294 "\n",
295 "# 将湿度大小分为大于 75 和小于等于 75 这两个属性值,\n",
296 "# 将温度大小分为大于 26 和小于等于 26 这两个属性值\n",
297 "for i in range(len(datasets)):\n",
298 " if datasets[i][2] > 75:\n",
299 " datasets[i][2] = '>75'\n",
300 " else:\n",
301 " datasets[i][2] = '<=75'\n",
302 " if datasets[i][1] > 26:\n",
303 " datasets[i][1] = '>26'\n",
304 " else:\n",
305 " datasets[i][1] = '<=26'\n",
306 "\n",
307 "# 构建 dataframe 并查看数据\n",
308 "df = pd.DataFrame(datasets, columns=labels)\n",
309 "df\n"
310 ]
311 },
312 {
313 "cell_type": "markdown",
314 "metadata": {},
315 "source": [
316 "## 5.2.2 构建决策树 \n",
317 "\n",
318 "**信息增益**用来衡量样本集合复杂度(不确定性)所减少的程度。 \n",
319 "\n",
320     "**信息熵**用来度量信息量的大小。从信息论的角度来看,对信息的度量就是对信息不确定性的度量。 "
321 ]
322 },
323 {
324 "cell_type": "markdown",
325 "metadata": {},
326 "source": [
327     "假设有 $K$ 个信息,其组成了样本集合 $D$ ,记第 $k$ 个信息发生的概率为 $p_k(1≤k≤K)$。 \n",
328     "这 $K$ 个信息的信息熵为: \n",
329     "$$E(D)=-\\sum_{k=1}^{K}p_k \\log_{2} p_k$$\n",
330 "\n",
331 "需要指出:**所有 $p_k$ 累加起来的和为1**。"
332 ]
333 },
334 {
335 "cell_type": "code",
336 "execution_count": 3,
337 "metadata": {},
338 "outputs": [],
339 "source": [
340 "def calc_entropy(total_num, count_dict):\n",
341 " \"\"\"\n",
342 " 计算信息熵\n",
343 " :param total_num: 总样本数\n",
344 " :param count_dict: 每类样本及其对应数目的字典\n",
345 " :return: 信息熵\n",
346 " \"\"\"\n",
347 " # 使用信息熵公式计算\n",
348 " ent = -sum([(p / total_num) * log(p / total_num, 2) for p in count_dict.values() if p != 0])\n",
349     "    # 将 -0.0 规范化为 0,避免 print 显示 -0.0\n",
350     "    if ent == 0:\n",
351     "        ent = 0\n",
352 " # 返回信息熵,精确到小数点后 4 位\n",
353 " return round(ent, 4)"
354 ]
355 },
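  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面用两个假设的小例子验证 `calc_entropy` 的行为:正负样本各占一半时信息熵最大,为 1;样本全部属于同一类时信息熵为 0。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 正负样本各占一半时,信息熵最大(示例数据为假设)\n",
    "print(calc_entropy(4, {'前往': 2, '不前往': 2}))\n",
    "# 样本全部属于同一类时,信息熵为 0\n",
    "print(calc_entropy(4, {'前往': 4, '不前往': 0}))"
   ]
  },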
356 {
357 "cell_type": "markdown",
358 "metadata": {},
359 "source": [
360 "\n",
361 "\n",
362 "现在用**熵**来构建决策树。数据中 14 个样本分为 “游客来游乐场( 9 个样本)” 和 “游客不来游乐场( 5 个样本)” 两个类别,即 K = 2。\n",
363 "\n",
364     "记 “游客来游乐场” 和 “游客不来游乐场” 的概率分别为 $p_1$ 和 $p_2$ ,显然 $p_1=\\frac{9}{14}$,$p_2=\\frac{5}{14}$,则这 14 个样本所蕴含的信息熵:\n",
365 "\n",
366 "$$E(D)=-\\sum_{k=1}^{2}p_{k}log_{2}{p_k}=-(\\frac{9}{14}×log_{2}{\\frac{9}{14}}+\\frac{5}{14}×log_{2}{\\frac{5}{14}})=0.940$$"
367 ]
368 },
369 {
370 "cell_type": "markdown",
371 "metadata": {},
372 "source": [
373     "我们可以用下面这种方式对 dataframe 的数据按条件进行筛选。"
374 ]
375 },
376 {
377 "cell_type": "code",
378 "execution_count": 4,
379 "metadata": {},
380 "outputs": [
381 {
382 "data": {
383 "text/html": [
384 "<div>\n",
385 "<style scoped>\n",
386 " .dataframe tbody tr th:only-of-type {\n",
387 " vertical-align: middle;\n",
388 " }\n",
389 "\n",
390 " .dataframe tbody tr th {\n",
391 " vertical-align: top;\n",
392 " }\n",
393 "\n",
394 " .dataframe thead th {\n",
395 " text-align: right;\n",
396 " }\n",
397 "</style>\n",
398 "<table border=\"1\" class=\"dataframe\">\n",
399 " <thead>\n",
400 " <tr style=\"text-align: right;\">\n",
401 " <th></th>\n",
402 " <th>天气</th>\n",
403 " <th>温度</th>\n",
404 " <th>湿度</th>\n",
405 " <th>是否有风</th>\n",
406 " <th>是否前往游乐场</th>\n",
407 " </tr>\n",
408 " </thead>\n",
409 " <tbody>\n",
410 " <tr>\n",
411 " <th>0</th>\n",
412 " <td>晴</td>\n",
413 " <td>&gt;26</td>\n",
414 " <td>&gt;75</td>\n",
415 " <td>否</td>\n",
416 " <td>0</td>\n",
417 " </tr>\n",
418 " <tr>\n",
419 " <th>1</th>\n",
420 " <td>晴</td>\n",
421 " <td>&lt;=26</td>\n",
422 " <td>&gt;75</td>\n",
423 " <td>是</td>\n",
424 " <td>0</td>\n",
425 " </tr>\n",
426 " <tr>\n",
427 " <th>5</th>\n",
428 " <td>雨</td>\n",
429 " <td>&lt;=26</td>\n",
430 " <td>&lt;=75</td>\n",
431 " <td>是</td>\n",
432 " <td>0</td>\n",
433 " </tr>\n",
434 " <tr>\n",
435 " <th>7</th>\n",
436 " <td>晴</td>\n",
437 " <td>&lt;=26</td>\n",
438 " <td>&gt;75</td>\n",
439 " <td>否</td>\n",
440 " <td>0</td>\n",
441 " </tr>\n",
442 " <tr>\n",
443 " <th>13</th>\n",
444 " <td>雨</td>\n",
445 " <td>&lt;=26</td>\n",
446 " <td>&gt;75</td>\n",
447 " <td>是</td>\n",
448 " <td>0</td>\n",
449 " </tr>\n",
450 " </tbody>\n",
451 "</table>\n",
452 "</div>"
453 ],
454 "text/plain": [
455 " 天气 温度 湿度 是否有风 是否前往游乐场\n",
456 "0 晴 >26 >75 否 0\n",
457 "1 晴 <=26 >75 是 0\n",
458 "5 雨 <=26 <=75 是 0\n",
459 "7 晴 <=26 >75 否 0\n",
460 "13 雨 <=26 >75 是 0"
461 ]
462 },
463 "execution_count": 4,
464 "metadata": {},
465 "output_type": "execute_result"
466 }
467 ],
468 "source": [
469 "# 例如:按 是否前往游乐场==0 进行筛选\n",
470 "df[df['是否前往游乐场']=='0']"
471 ]
472 },
473 {
474 "cell_type": "markdown",
475 "metadata": {},
476 "source": [
477 "使用上面的方法,可以得到计算信息熵所需的总样本数,以及每类样本及其对应数目的字典,然后计算信息熵。"
478 ]
479 },
480 {
481 "cell_type": "code",
482 "execution_count": 5,
483 "metadata": {},
484 "outputs": [
485 {
486 "data": {
487 "text/plain": [
488 "0.9403"
489 ]
490 },
491 "execution_count": 5,
492 "metadata": {},
493 "output_type": "execute_result"
494 }
495 ],
496 "source": [
497 "# 总样本数\n",
498 "total_num = df.shape[0]\n",
499 "# 每类样本及其对应数目的字典\n",
500     "count_dict = {'前往':df[df['是否前往游乐场']=='1'].shape[0], '不前往':df[df['是否前往游乐场']=='0'].shape[0]}\n",
501 "# 计算信息熵\n",
502 "entropy = calc_entropy(total_num, count_dict)\n",
503 "entropy"
504 ]
505 },
506 {
507 "cell_type": "markdown",
508 "metadata": {},
509 "source": [
510 "**计算天气状况所对应的信息熵**: \n",
511     "天气状况的三个属性取值记为 $a_0=“晴”$ ,$a_1=“多云”$ ,$a_2=“雨”$ , \n",
512     "属性取值为 $a_i$ 对应分支节点所包含的子样本集记为 $D_i$ ,该子样本集所包含的样本数量记为 $|D_i|$ 。"
513 ]
514 },
515 {
516 "cell_type": "markdown",
517 "metadata": {},
518 "source": [
519 "|天气属性取值$a_i$|“晴”|“多云”|“雨”|\n",
520 "|:--:|:--:|:--:|:--:|\n",
521 "|对应样本数$|D_i|$|5|4|5|\n",
522 "|正负样本数量|(2+,3-)|(4+,0-)|(3+,2-)|\n",
523 "\n",
524 "计算天气状况每个属性值的信息熵:\n",
525 "\n",
526 "$“晴”:E(D_0)=-(\\frac{2}{5}×log_{2}{\\frac{2}{5}}+\\frac{3}{5}×log_{2}{\\frac{3}{5}})=0.971$\n",
527 "\n",
528 "$“多云”:E(D_1)=-(\\frac{4}{4}×log_{2}{\\frac{4}{4}})=0$\n",
529 "\n",
530 "$“雨”:E(D_2)=-(\\frac{3}{5}×log_{2}{\\frac{3}{5}}+\\frac{2}{5}×log_{2}{\\frac{2}{5}})=0.971$"
531 ]
532 },
533 {
534 "cell_type": "markdown",
535 "metadata": {},
536 "source": [
537 "我们可以使用下面的写法,对 dataframe 进行多个条件的筛选。"
538 ]
539 },
540 {
541 "cell_type": "code",
542 "execution_count": 6,
543 "metadata": {},
544 "outputs": [
545 {
546 "data": {
547 "text/html": [
548 "<div>\n",
549 "<style scoped>\n",
550 " .dataframe tbody tr th:only-of-type {\n",
551 " vertical-align: middle;\n",
552 " }\n",
553 "\n",
554 " .dataframe tbody tr th {\n",
555 " vertical-align: top;\n",
556 " }\n",
557 "\n",
558 " .dataframe thead th {\n",
559 " text-align: right;\n",
560 " }\n",
561 "</style>\n",
562 "<table border=\"1\" class=\"dataframe\">\n",
563 " <thead>\n",
564 " <tr style=\"text-align: right;\">\n",
565 " <th></th>\n",
566 " <th>天气</th>\n",
567 " <th>温度</th>\n",
568 " <th>湿度</th>\n",
569 " <th>是否有风</th>\n",
570 " <th>是否前往游乐场</th>\n",
571 " </tr>\n",
572 " </thead>\n",
573 " <tbody>\n",
574 " <tr>\n",
575 " <th>8</th>\n",
576 " <td>晴</td>\n",
577 " <td>&lt;=26</td>\n",
578 " <td>&lt;=75</td>\n",
579 " <td>否</td>\n",
580 " <td>1</td>\n",
581 " </tr>\n",
582 " <tr>\n",
583 " <th>10</th>\n",
584 " <td>晴</td>\n",
585 " <td>&lt;=26</td>\n",
586 " <td>&lt;=75</td>\n",
587 " <td>是</td>\n",
588 " <td>1</td>\n",
589 " </tr>\n",
590 " </tbody>\n",
591 "</table>\n",
592 "</div>"
593 ],
594 "text/plain": [
595 " 天气 温度 湿度 是否有风 是否前往游乐场\n",
596 "8 晴 <=26 <=75 否 1\n",
597 "10 晴 <=26 <=75 是 1"
598 ]
599 },
600 "execution_count": 6,
601 "metadata": {},
602 "output_type": "execute_result"
603 }
604 ],
605 "source": [
606 "# 筛选出 天气为晴并且去游乐场的样本数据\n",
607 "df[(df['天气']=='晴') & (df['是否前往游乐场']=='1')]"
608 ]
609 },
610 {
611 "cell_type": "code",
612 "execution_count": 7,
613 "metadata": {},
614 "outputs": [
615 {
616 "name": "stdout",
617 "output_type": "stream",
618 "text": [
619 "{'不前往': 3, '前往': 2}\n",
620 "天气-晴 的信息熵为:0.971\n"
621 ]
622 }
623 ],
624 "source": [
625 "# 天气为晴的总天数\n",
626 "total_num_sun = df[df['天气']=='晴'].shape[0]\n",
627 "# 天气为晴时,去游乐场和不去游乐场的人数\n",
628 "count_dict_sun = {'前往':df[(df['天气']=='晴') & (df['是否前往游乐场']=='1')].shape[0], \n",
629 " '不前往':df[(df['天气']=='晴') & (df['是否前往游乐场']=='0')].shape[0]}\n",
630 "print(count_dict_sun)\n",
631 "# 计算天气-晴 的信息熵\n",
632 "ent_sun = calc_entropy(total_num_sun, count_dict_sun)\n",
633 "print('天气-晴 的信息熵为:%s' % ent_sun)\n"
634 ]
635 },
636 {
637 "cell_type": "code",
638 "execution_count": 8,
639 "metadata": {},
640 "outputs": [
641 {
642 "name": "stdout",
643 "output_type": "stream",
644 "text": [
645 "{'不前往': 0, '前往': 4}\n",
646 "天气-多云 的信息熵为:0\n"
647 ]
648 }
649 ],
650 "source": [
651 "# 天气为多云的总天数\n",
652 "total_num_cloud = df[df['天气']=='多云'].shape[0]\n",
653 "# 天气为多云时,去游乐场和不去游乐场的人数\n",
654 "count_dict_cloud = {'前往':df[(df['天气']=='多云') & (df['是否前往游乐场']=='1')].shape[0], \n",
655 " '不前往':df[(df['天气']=='多云') & (df['是否前往游乐场']=='0')].shape[0]}\n",
656 "print(count_dict_cloud)\n",
657 "# 计算天气-多云 的信息熵\n",
658 "ent_cloud = calc_entropy(total_num_cloud, count_dict_cloud)\n",
659 "print('天气-多云 的信息熵为:%s' % ent_cloud)"
660 ]
661 },
662 {
663 "cell_type": "code",
664 "execution_count": 9,
665 "metadata": {},
666 "outputs": [
667 {
668 "name": "stdout",
669 "output_type": "stream",
670 "text": [
671 "{'不前往': 2, '前往': 3}\n",
672 "天气-雨 的信息熵为:0.971\n"
673 ]
674 }
675 ],
676 "source": [
677 "# 天气为雨的总天数\n",
678 "total_num_rain = df[df['天气']=='雨'].shape[0]\n",
679 "# 天气为雨时,去游乐场和不去游乐场的人数\n",
680 "count_dict_rain = {'前往':df[(df['天气']=='雨') & (df['是否前往游乐场']=='1')].shape[0], \n",
681 " '不前往':df[(df['天气']=='雨') & (df['是否前往游乐场']=='0')].shape[0]}\n",
682 "print(count_dict_rain)\n",
683 "# 计算天气-雨 的信息熵\n",
684 "ent_rain = calc_entropy(total_num_rain, count_dict_rain)\n",
685 "print('天气-雨 的信息熵为:%s' % ent_rain)"
686 ]
687 },
688 {
689 "cell_type": "markdown",
690 "metadata": {},
691 "source": [
692 "计算天气状况的信息增益: \n",
693     "$$Gain(D,A)=E(D)-\\sum_{i=1}^{n}\\frac{|D_i|}{|D|}E(D_i)$$"
694 ]
695 },
696 {
697 "cell_type": "markdown",
698 "metadata": {},
699 "source": [
700 "其中,$A=“天气状况”$。于是天气状况这一气象特点的信息增益为:\n",
701 "$$Gain(D,天气)=0.940-(\\frac{5}{14}×0.971+\\frac{4}{14}×0+\\frac{5}{14}×0.971)=0.246$$\n",
702 "\n",
703 "同理可以计算温度高低、湿度大小、风力强弱三个气象特点的信息增益。 \n",
704     "通常情况下,某个特征的信息增益越大,则用该特征对样本集进行划分所获得的“纯度”提升越大,信息不确定性减少的程度也越大。"
705 ]
706 },
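  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以先用前文手算得到的各信息熵值,按上面的公式粗略验证一下 $Gain(D,天气)$:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 用前文手算的信息熵值验证 Gain(D, 天气) ≈ 0.246\n",
    "gain_weather = 0.940 - (5/14*0.971 + 4/14*0 + 5/14*0.971)\n",
    "print(round(gain_weather, 3))"
   ]
  },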
718 {
719 "cell_type": "markdown",
720 "metadata": {},
721 "source": [
722 "使用上面的公式计算信息增益。"
723 ]
724 },
725 {
726 "cell_type": "code",
727 "execution_count": 10,
728 "metadata": {},
729 "outputs": [
730 {
731 "data": {
732 "text/plain": [
733 "0.2467285714285714"
734 ]
735 },
736 "execution_count": 10,
737 "metadata": {},
738 "output_type": "execute_result"
739 }
740 ],
741 "source": [
742 "# 信息增益\n",
743 "gain = entropy - (total_num_sun/total_num*ent_sun + \n",
744 " total_num_cloud/total_num*ent_cloud + \n",
745 " total_num_rain/total_num*ent_rain)\n",
746 "gain"
747 ]
748 },
749 {
750 "cell_type": "markdown",
751 "metadata": {},
752 "source": [
753 "### 思考与练习 "
754 ]
755 },
756 {
757 "cell_type": "markdown",
758 "metadata": {},
759 "source": [
760 "1. 分别将天气状况、温度高低、湿度大小、风力强弱作为分支点来构建决策树,查看信息增益\n"
761 ]
762 },
763 {
764 "cell_type": "code",
765 "execution_count": 11,
766 "metadata": {},
767 "outputs": [
768 {
769 "name": "stdout",
770 "output_type": "stream",
771 "text": [
772 "{'不前往': 1, '前往': 2}\n",
773 "温度 >26 的信息熵为:0.9183\n",
774 "{'不前往': 4, '前往': 7}\n",
775 "温度 <=26 的信息熵为:0.9457\n",
776 "按照温度高低进行划分的信息增益为:0.0004714285714285671\n"
777 ]
778 }
779 ],
780 "source": [
781 "# 温度 >26 的信息熵\n",
782 "total_num_temp_high = df[df['温度']=='>26'].shape[0]\n",
783 "count_dict_temp_high = {'前往':df[(df['温度']=='>26') & (df['是否前往游乐场']=='1')].shape[0], \n",
784 " '不前往':df[(df['温度']=='>26') & (df['是否前往游乐场']=='0')].shape[0]}\n",
785 "print(count_dict_temp_high)\n",
786 "ent_temp_high = calc_entropy(total_num_temp_high, count_dict_temp_high)\n",
787 "print('温度 >26 的信息熵为:%s' % ent_temp_high)\n",
788 "\n",
789 "# 温度 <=26 的信息熵\n",
790 "total_num_temp_low = df[df['温度']=='<=26'].shape[0]\n",
791 "count_dict_temp_low = {'前往':df[(df['温度']=='<=26') & (df['是否前往游乐场']=='1')].shape[0], \n",
792 " '不前往':df[(df['温度']=='<=26') & (df['是否前往游乐场']=='0')].shape[0]}\n",
793 "print(count_dict_temp_low)\n",
794 "ent_temp_low = calc_entropy(total_num_temp_low, count_dict_temp_low)\n",
795 "print('温度 <=26 的信息熵为:%s' % ent_temp_low)\n",
796 "\n",
797 "# 如果按照温度高低进行划分,则对应的信息增益为\n",
798 "gain = entropy - (total_num_temp_high/total_num*ent_temp_high + \n",
799 " total_num_temp_low/total_num*ent_temp_low)\n",
800 "print('按照温度高低进行划分的信息增益为:%s' % gain)"
801 ]
802 },
803 {
804 "cell_type": "code",
805 "execution_count": 12,
806 "metadata": {},
807 "outputs": [
808 {
809 "name": "stdout",
810 "output_type": "stream",
811 "text": [
812 "{'不前往': 4, '前往': 5}\n",
813 "湿度 >75 的信息熵为:0.9911\n",
814 "{'不前往': 1, '前往': 4}\n",
815 "湿度 <=75 的信息熵为:0.7219\n",
816 "按照湿度高低进行划分的信息增益为:0.04534285714285702\n"
817 ]
818 }
819 ],
820 "source": [
821 "# 湿度 >75 的信息熵\n",
822 "total_num_hum_high = df[df['湿度']=='>75'].shape[0]\n",
823 "count_dict_hum_high = {'前往':df[(df['湿度']=='>75') & (df['是否前往游乐场']=='1')].shape[0], \n",
824 " '不前往':df[(df['湿度']=='>75') & (df['是否前往游乐场']=='0')].shape[0]}\n",
825 "print(count_dict_hum_high)\n",
826 "ent_hum_high = calc_entropy(total_num_hum_high, count_dict_hum_high)\n",
827 "print('湿度 >75 的信息熵为:%s' % ent_hum_high)\n",
828 "\n",
829 "# 湿度 <=75 的信息熵\n",
830 "total_num_hum_low = df[df['湿度']=='<=75'].shape[0]\n",
831 "count_dict_hum_low = {'前往':df[(df['湿度']=='<=75') & (df['是否前往游乐场']=='1')].shape[0], \n",
832 " '不前往':df[(df['湿度']=='<=75') & (df['是否前往游乐场']=='0')].shape[0]}\n",
833 "print(count_dict_hum_low)\n",
834 "ent_hum_low = calc_entropy(total_num_hum_low, count_dict_hum_low)\n",
835 "print('湿度 <=75 的信息熵为:%s' % ent_hum_low)\n",
836 "\n",
837 "# 如果按照湿度高低进行划分,则对应的信息增益为\n",
838 "gain = entropy - (total_num_hum_high/total_num*ent_hum_high + \n",
839 " total_num_hum_low/total_num*ent_hum_low)\n",
840 "print('按照湿度高低进行划分的信息增益为:%s' % gain)"
841 ]
842 },
843 {
844 "cell_type": "code",
845 "execution_count": 13,
846 "metadata": {},
847 "outputs": [
848 {
849 "name": "stdout",
850 "output_type": "stream",
851 "text": [
852 "{'不前往': 3, '前往': 3}\n",
853 "有风 的信息熵为:1.0\n",
854 "{'不前往': 2, '前往': 6}\n",
855 "无风 的信息熵为:0.8113\n",
856 "按照是否有风进行划分的信息增益为:0.04812857142857141\n"
857 ]
858 }
859 ],
860 "source": [
861 "# 有风 的信息熵\n",
862 "total_num_wind = df[df['是否有风']=='是'].shape[0]\n",
863 "count_dict_wind = {'前往':df[(df['是否有风']=='是') & (df['是否前往游乐场']=='1')].shape[0], \n",
864 " '不前往':df[(df['是否有风']=='是') & (df['是否前往游乐场']=='0')].shape[0]}\n",
865 "print(count_dict_wind)\n",
866 "ent_wind = calc_entropy(total_num_wind, count_dict_wind)\n",
867 "print('有风 的信息熵为:%s' % ent_wind)\n",
868 "\n",
869 "# 无风 的信息熵\n",
870 "total_num_nowind = df[df['是否有风']=='否'].shape[0]\n",
871 "count_dict_nowind = {'前往':df[(df['是否有风']=='否') & (df['是否前往游乐场']=='1')].shape[0], \n",
872 " '不前往':df[(df['是否有风']=='否') & (df['是否前往游乐场']=='0')].shape[0]}\n",
873 "print(count_dict_nowind)\n",
874 "ent_nowind = calc_entropy(total_num_nowind, count_dict_nowind)\n",
875 "print('无风 的信息熵为:%s' % ent_nowind)\n",
876 "\n",
877 "# 如果按照是否有风进行划分,则对应的信息增益为\n",
878 "gain = entropy - (total_num_wind/total_num*ent_wind + \n",
879 " total_num_nowind/total_num*ent_nowind)\n",
880 "print('按照是否有风进行划分的信息增益为:%s' % gain)"
881 ]
882 },
883 {
884 "cell_type": "markdown",
885 "metadata": {},
886 "source": [
887 "2. 每朵鸢尾花有萼片长度、萼片宽度、花瓣长度、花瓣宽度四个特征。现在需要根据这四个特征将鸢尾花分为杂色鸢尾花、维吉尼亚鸢尾和山鸢尾三类,试构造决策树进行分类。\n",
888 "\n",
889 "|序号|萼片长度|萼片宽度|花瓣长度|花瓣宽度|种类|\n",
890 "|:--:|:--:|:--:|:--:|:--:|:--:|\n",
891 "|1|5.0|2.0|3.5|1.0|杂色鸢尾|\n",
892 "|2|6.0|2.2|5.0|1.5|维吉尼亚鸢尾|\n",
893 "|3|6.0|2.2|4.0|1.0|杂色鸢尾|\n",
894 "|4|6.2|2.2|4.5|1.5|杂色鸢尾|\n",
895 "|5|4.5|2.3|1.3|0.3|山鸢尾|"
896 ]
897 },
898 {
899 "cell_type": "markdown",
900 "metadata": {},
901 "source": [
902     "观察上表中的五条数据,我们可以看到 杂色鸢尾 和 维吉尼亚鸢尾 的花瓣宽度明显大于山鸢尾,所以可以通过判断花瓣宽度是否大于 0.7,来将山鸢尾从其他两种鸢尾中区分出来。\n",
903     "\n",
904     "同时,杂色鸢尾 和 维吉尼亚鸢尾 的花瓣长度明显大于山鸢尾,所以也可以通过判断花瓣长度是否大于 2.4,来将山鸢尾从其他两种鸢尾中区分出来。\n",
905     "\n",
906     "然后我们观察到 维吉尼亚鸢尾 的花瓣长度明显大于杂色鸢尾,所以可以通过判断花瓣长度是否大于 4.75,来将杂色鸢尾和维吉尼亚鸢尾区分开来。"
907 ]
908 },
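  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "按照上面提出的阈值,可以写一小段规则代码在这五条样本上做个检验(阈值和数据均取自上文,仅作示意):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 用上面提出的阈值规则对表中五条样本进行分类(花瓣长度、花瓣宽度、种类均取自上表)\n",
    "samples = [\n",
    "    (3.5, 1.0, '杂色鸢尾'),\n",
    "    (5.0, 1.5, '维吉尼亚鸢尾'),\n",
    "    (4.0, 1.0, '杂色鸢尾'),\n",
    "    (4.5, 1.5, '杂色鸢尾'),\n",
    "    (1.3, 0.3, '山鸢尾')\n",
    "]\n",
    "\n",
    "def classify(petal_length, petal_width):\n",
    "    # 花瓣宽度不大于 0.7,判为山鸢尾\n",
    "    if petal_width <= 0.7:\n",
    "        return '山鸢尾'\n",
    "    # 花瓣长度不大于 4.75,判为杂色鸢尾,否则为维吉尼亚鸢尾\n",
    "    if petal_length <= 4.75:\n",
    "        return '杂色鸢尾'\n",
    "    return '维吉尼亚鸢尾'\n",
    "\n",
    "for pl, pw, label in samples:\n",
    "    print(classify(pl, pw) == label)"
   ]
  },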
909 {
910 "cell_type": "markdown",
911 "metadata": {},
912 "source": [
913 "实际上是否如此呢?"
914 ]
915 },
916 {
917 "cell_type": "markdown",
918 "metadata": {},
919 "source": [
920     "上面的表格只是 Iris 数据集的一小部分,完整的数据集包含 150 个数据样本,分为 3 类,每类 50 个数据,每个数据包含 4 个属性,即萼片长度、萼片宽度、花瓣长度、花瓣宽度。\n",
921 "\n",
922 "我们使用 sklearn 工具包来构建决策树模型,先导入数据集。"
923 ]
924 },
925 {
926 "cell_type": "code",
927 "execution_count": 14,
928 "metadata": {},
929 "outputs": [
930 {
931 "name": "stdout",
932 "output_type": "stream",
933 "text": [
934 "['setosa', 'versicolor', 'virginica']\n",
935 "['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']\n"
936 ]
937 }
938 ],
939 "source": [
940 "from sklearn.datasets import load_iris\n",
941 "# 加载数据集\n",
942 "iris = load_iris()\n",
943 "# 查看 label\n",
944 "print(list(iris.target_names))\n",
945 "# 查看 feature\n",
946 "print(iris.feature_names)"
947 ]
948 },
949 {
950 "cell_type": "markdown",
951 "metadata": {},
952 "source": [
953     "setosa 是山鸢尾,versicolor 是杂色鸢尾,virginica 是维吉尼亚鸢尾。\n",
954     "\n",
955     "sepal length、sepal width、petal length、petal width 分别是萼片长度、萼片宽度、花瓣长度、花瓣宽度。"
956 ]
957 },
958 {
959 "cell_type": "markdown",
960 "metadata": {},
961 "source": [
962 "然后进行训练集和测试集的切分。"
963 ]
964 },
965 {
966 "cell_type": "code",
967 "execution_count": 15,
968 "metadata": {},
969 "outputs": [],
970 "source": [
971 "from sklearn.model_selection import train_test_split\n",
972 "# 载入数据\n",
973 "X, y = load_iris(return_X_y=True)\n",
974     "# 切分训练集和测试集\n",
975 "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)"
976 ]
977 },
978 {
979 "cell_type": "markdown",
980 "metadata": {},
981 "source": [
982 "接下来,我们在训练集数据上训练决策树模型。"
983 ]
984 },
985 {
986 "cell_type": "code",
987 "execution_count": 16,
988 "metadata": {},
989 "outputs": [],
990 "source": [
991 "from sklearn import tree\n",
992 "from sklearn.tree import DecisionTreeClassifier\n",
993 "# 初始化模型,可以调整 max_depth 来观察模型的表现\n",
994 "clf = tree.DecisionTreeClassifier(random_state=42, max_depth=2)\n",
995 "# 训练模型\n",
996 "clf = clf.fit(X_train, y_train)"
997 ]
998 },
999 {
1000 "cell_type": "markdown",
1001 "metadata": {},
1002 "source": [
1003 "我们可以使用 graphviz 包来展示构建好的决策树。"
1004 ]
1005 },
1006 {
1007 "cell_type": "code",
1008 "execution_count": 17,
1009 "metadata": {},
1010 "outputs": [
1011 {
1012 "name": "stdout",
1013 "output_type": "stream",
1014 "text": [
1015 "\u001b[33mWARNING: The directory '/home/jovyan/.cache/pip/http' or its parent directory is not owned by the current user and the cache has been disabled. Please check the permissions and owner of that directory. If executing pip with sudo, you may want sudo's -H flag.\u001b[0m\n",
1016 "\u001b[33mWARNING: The directory '/home/jovyan/.cache/pip' or its parent directory is not owned by the current user and caching wheels has been disabled. check the permissions and owner of that directory. If executing pip with sudo, you may want sudo's -H flag.\u001b[0m\n",
1017 "Looking in indexes: https://mirrors.ustc.edu.cn/pypi/web/simple\n",
1018 "Collecting graphviz\n",
1019 " Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f5/74/dbed754c0abd63768d3a7a7b472da35b08ac442cf87d73d5850a6f32391e/graphviz-0.13.2-py2.py3-none-any.whl\n",
1020 "Installing collected packages: graphviz\n",
1021 "Successfully installed graphviz-0.13.2\n",
1022 "\u001b[33mWARNING: You are using pip version 19.1.1, however version 19.3.1 is available.\n",
1023 "You should consider upgrading via the 'pip install --upgrade pip' command.\u001b[0m\n",
1024 "Note: you may need to restart the kernel to use updated packages.\n"
1025 ]
1026 }
1027 ],
1028 "source": [
1029 "pip install graphviz"
1030 ]
1031 },
1032 {
1033 "cell_type": "code",
1034 "execution_count": 18,
1035 "metadata": {},
1036 "outputs": [
1037 {
1038 "data": {
1039 "image/svg+xml": [
1040 "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
1041 "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
1042 " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
1043 "<!-- Generated by graphviz version 2.38.0 (20140413.2041)\n",
1044 " -->\n",
1045 "<!-- Title: Tree Pages: 1 -->\n",
1046 "<svg width=\"353pt\" height=\"314pt\"\n",
1047 " viewBox=\"0.00 0.00 352.50 314.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
1048 "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 310)\">\n",
1049 "<title>Tree</title>\n",
1050 "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-310 348.5,-310 348.5,4 -4,4\"/>\n",
1051 "<!-- 0 -->\n",
1052 "<g id=\"node1\" class=\"node\"><title>0</title>\n",
1053 "<path fill=\"#fcfffd\" stroke=\"black\" d=\"M181.5,-306C181.5,-306 73.5,-306 73.5,-306 67.5,-306 61.5,-300 61.5,-294 61.5,-294 61.5,-235 61.5,-235 61.5,-229 67.5,-223 73.5,-223 73.5,-223 181.5,-223 181.5,-223 187.5,-223 193.5,-229 193.5,-235 193.5,-235 193.5,-294 193.5,-294 193.5,-300 187.5,-306 181.5,-306\"/>\n",
1054 "<text text-anchor=\"start\" x=\"78.5\" y=\"-290.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">花瓣长度 ≤ 2.45</text>\n",
1055 "<text text-anchor=\"start\" x=\"92\" y=\"-275.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.666</text>\n",
1056 "<text text-anchor=\"start\" x=\"82.5\" y=\"-260.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 100</text>\n",
1057 "<text text-anchor=\"start\" x=\"69.5\" y=\"-245.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [31, 35, 34]</text>\n",
1058 "<text text-anchor=\"start\" x=\"77\" y=\"-230.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 杂色鸢尾</text>\n",
1059 "</g>\n",
1060 "<!-- 1 -->\n",
1061 "<g id=\"node2\" class=\"node\"><title>1</title>\n",
1062 "<path fill=\"#e58139\" stroke=\"black\" d=\"M105,-179.5C105,-179.5 12,-179.5 12,-179.5 6,-179.5 -7.10543e-15,-173.5 -7.10543e-15,-167.5 -7.10543e-15,-167.5 -7.10543e-15,-123.5 -7.10543e-15,-123.5 -7.10543e-15,-117.5 6,-111.5 12,-111.5 12,-111.5 105,-111.5 105,-111.5 111,-111.5 117,-117.5 117,-123.5 117,-123.5 117,-167.5 117,-167.5 117,-173.5 111,-179.5 105,-179.5\"/>\n",
1063 "<text text-anchor=\"start\" x=\"30.5\" y=\"-164.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.0</text>\n",
1064 "<text text-anchor=\"start\" x=\"17.5\" y=\"-149.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 31</text>\n",
1065 "<text text-anchor=\"start\" x=\"8\" y=\"-134.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [31, 0, 0]</text>\n",
1066 "<text text-anchor=\"start\" x=\"14.5\" y=\"-119.3\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 山鸢尾</text>\n",
1067 "</g>\n",
1068 "<!-- 0&#45;&gt;1 -->\n",
1069 "<g id=\"edge1\" class=\"edge\"><title>0&#45;&gt;1</title>\n",
1070 "<path fill=\"none\" stroke=\"black\" d=\"M103.561,-222.907C96.9882,-211.763 89.8496,-199.658 83.2334,-188.439\"/>\n",
1071 "<polygon fill=\"black\" stroke=\"black\" points=\"86.1547,-186.503 78.06,-179.667 80.1251,-190.059 86.1547,-186.503\"/>\n",
1072 "<text text-anchor=\"middle\" x=\"71.813\" y=\"-200.174\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">True</text>\n",
1073 "</g>\n",
1074 "<!-- 2 -->\n",
1075 "<g id=\"node3\" class=\"node\"><title>2</title>\n",
1076 "<path fill=\"#f9fefb\" stroke=\"black\" d=\"M248,-187C248,-187 147,-187 147,-187 141,-187 135,-181 135,-175 135,-175 135,-116 135,-116 135,-110 141,-104 147,-104 147,-104 248,-104 248,-104 254,-104 260,-110 260,-116 260,-116 260,-175 260,-175 260,-181 254,-187 248,-187\"/>\n",
1077 "<text text-anchor=\"start\" x=\"148.5\" y=\"-171.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">花瓣宽度 ≤ 1.75</text>\n",
1078 "<text text-anchor=\"start\" x=\"169.5\" y=\"-156.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.5</text>\n",
1079 "<text text-anchor=\"start\" x=\"156.5\" y=\"-141.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 69</text>\n",
1080 "<text text-anchor=\"start\" x=\"143\" y=\"-126.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 35, 34]</text>\n",
1081 "<text text-anchor=\"start\" x=\"147\" y=\"-111.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 杂色鸢尾</text>\n",
1082 "</g>\n",
1083 "<!-- 0&#45;&gt;2 -->\n",
1084 "<g id=\"edge2\" class=\"edge\"><title>0&#45;&gt;2</title>\n",
1085 "<path fill=\"none\" stroke=\"black\" d=\"M151.786,-222.907C157.053,-214.105 162.678,-204.703 168.117,-195.612\"/>\n",
1086 "<polygon fill=\"black\" stroke=\"black\" points=\"171.126,-197.399 173.257,-187.021 165.119,-193.805 171.126,-197.399\"/>\n",
1087 "<text text-anchor=\"middle\" x=\"179.351\" y=\"-207.567\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">False</text>\n",
1088 "</g>\n",
1089 "<!-- 3 -->\n",
1090 "<g id=\"node4\" class=\"node\"><title>3</title>\n",
1091 "<path fill=\"#50e890\" stroke=\"black\" d=\"M170,-68C170,-68 77,-68 77,-68 71,-68 65,-62 65,-56 65,-56 65,-12 65,-12 65,-6 71,-0 77,-0 77,-0 170,-0 170,-0 176,-0 182,-6 182,-12 182,-12 182,-56 182,-56 182,-62 176,-68 170,-68\"/>\n",
1092 "<text text-anchor=\"start\" x=\"88\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.188</text>\n",
1093 "<text text-anchor=\"start\" x=\"82.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 38</text>\n",
1094 "<text text-anchor=\"start\" x=\"73\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 34, 4]</text>\n",
1095 "<text text-anchor=\"start\" x=\"73\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 杂色鸢尾</text>\n",
1096 "</g>\n",
1097 "<!-- 2&#45;&gt;3 -->\n",
1098 "<g id=\"edge3\" class=\"edge\"><title>2&#45;&gt;3</title>\n",
1099 "<path fill=\"none\" stroke=\"black\" d=\"M169.945,-103.726C163.966,-94.879 157.635,-85.51 151.634,-76.6303\"/>\n",
1100 "<polygon fill=\"black\" stroke=\"black\" points=\"154.503,-74.6253 146.004,-68.2996 148.703,-78.5448 154.503,-74.6253\"/>\n",
1101 "</g>\n",
1102 "<!-- 4 -->\n",
1103 "<g id=\"node5\" class=\"node\"><title>4</title>\n",
1104 "<path fill=\"#8540e6\" stroke=\"black\" d=\"M332.5,-68C332.5,-68 212.5,-68 212.5,-68 206.5,-68 200.5,-62 200.5,-56 200.5,-56 200.5,-12 200.5,-12 200.5,-6 206.5,-0 212.5,-0 212.5,-0 332.5,-0 332.5,-0 338.5,-0 344.5,-6 344.5,-12 344.5,-12 344.5,-56 344.5,-56 344.5,-62 338.5,-68 332.5,-68\"/>\n",
1105 "<text text-anchor=\"start\" x=\"237\" y=\"-52.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">gini = 0.062</text>\n",
1106 "<text text-anchor=\"start\" x=\"231.5\" y=\"-37.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">samples = 31</text>\n",
1107 "<text text-anchor=\"start\" x=\"222\" y=\"-22.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">value = [0, 1, 30]</text>\n",
1108 "<text text-anchor=\"start\" x=\"208.5\" y=\"-7.8\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">class = 维吉尼亚鸢尾</text>\n",
1109 "</g>\n",
1110 "<!-- 2&#45;&gt;4 -->\n",
1111 "<g id=\"edge4\" class=\"edge\"><title>2&#45;&gt;4</title>\n",
1112 "<path fill=\"none\" stroke=\"black\" d=\"M225.427,-103.726C231.487,-94.879 237.904,-85.51 243.986,-76.6303\"/>\n",
1113 "<polygon fill=\"black\" stroke=\"black\" points=\"246.929,-78.5277 249.692,-68.2996 241.154,-74.5721 246.929,-78.5277\"/>\n",
1114 "</g>\n",
1115 "</g>\n",
1116 "</svg>\n"
1117 ],
1118 "text/plain": [
1119 "<graphviz.files.Source at 0x7ffa4e1cf0f0>"
1120 ]
1121 },
1122 "execution_count": 18,
1123 "metadata": {},
1124 "output_type": "execute_result"
1125 }
1126 ],
1127 "source": [
1128 "import graphviz\n",
1129 "feature_names = ['萼片长度','萼片宽度','花瓣长度','花瓣宽度']\n",
1130 "target_names = ['山鸢尾', '杂色鸢尾', '维吉尼亚鸢尾']\n",
1131 "# 可视化生成的决策树\n",
1132 "dot_data = tree.export_graphviz(clf, out_file=None, \n",
1133 " feature_names=feature_names, \n",
1134 " class_names=target_names, \n",
1135 " filled=True, rounded=True, \n",
1136 " special_characters=True) \n",
1137 "graph = graphviz.Source(dot_data) \n",
1138 "graph "
1139 ]
1140 },
1141 {
1142 "cell_type": "markdown",
1143 "metadata": {},
1144 "source": [
1145     "我们看看模型在测试集上的表现。"
1146 ]
1147 },
1148 {
1149 "cell_type": "code",
1150 "execution_count": 19,
1151 "metadata": {},
1152 "outputs": [
1153 {
1154 "data": {
1155 "text/plain": [
1156 "0.98"
1157 ]
1158 },
1159 "execution_count": 19,
1160 "metadata": {},
1161 "output_type": "execute_result"
1162 }
1163 ],
1164 "source": [
1165 "from sklearn.metrics import accuracy_score\n",
1166 "y_test_predict = clf.predict(X_test)\n",
1167 "accuracy_score(y_test,y_test_predict)"
1168 ]
1169 },
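  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "前面初始化模型时的注释提到,可以调整 `max_depth` 来观察模型的表现,下面是一个简单的尝试(沿用上面切分好的训练集和测试集):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 尝试不同的 max_depth,观察测试集准确率\n",
    "for depth in [1, 2, 3, 4]:\n",
    "    clf_d = tree.DecisionTreeClassifier(random_state=42, max_depth=depth)\n",
    "    clf_d.fit(X_train, y_train)\n",
    "    print(depth, accuracy_score(y_test, clf_d.predict(X_test)))"
   ]
  },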
1170 {
1171 "cell_type": "markdown",
1172 "metadata": {},
1173 "source": [
1174 "### 实践与体验\n",
1175 "\n",
1176 "**计算文章的信息熵**\n",
1177 "\n",
1178     "收集中英文对照的短文,在计算短文内中文单词和英文单词出现概率的基础上,分别计算这两篇短文的信息熵,并比较中文短文信息熵和英文短文信息熵的大小。"
1179 ]
1180 },
1181 {
1182 "cell_type": "markdown",
1183 "metadata": {},
1184 "source": [
1185 "首先定义一个方法来辅助读取文件的内容。"
1186 ]
1187 },
1188 {
1189 "cell_type": "code",
1190 "execution_count": 20,
1191 "metadata": {},
1192 "outputs": [],
1193 "source": [
1194 "def read_file(path):\n",
1195 " \"\"\"\n",
1196 " 读取某文件的内容\n",
1197 " :param path: 文件的路径\n",
1198 " :return: 文件的内容\n",
1199 " \"\"\"\n",
1200 " contents = \"\"\n",
1201 " with open(path) as f:\n",
1202 " # 读取每一行的内容\n",
1203 " for line in f.readlines():\n",
1204 " contents += line\n",
1205 " return contents"
1206 ]
1207 },
1208 {
1209 "cell_type": "markdown",
1210 "metadata": {},
1211 "source": [
1212 "使用上面定义的方法读取英文短文及其对应的中文短文。"
1213 ]
1214 },
1215 {
1216 "cell_type": "code",
1217 "execution_count": 21,
1218 "metadata": {},
1219 "outputs": [],
1220 "source": [
1221 "# 读取英文短文\n",
1222 "en_essay = read_file('essay3_en.txt')\n",
1223 "# 读取中文短文\n",
1224 "ch_essay = read_file('essay3_ch.txt')\n"
1225 ]
1226 },
1227 {
1228 "cell_type": "markdown",
1229 "metadata": {},
1230 "source": [
1231 "处理文本,统计单词出现的概率,并计算信息熵。"
1232 ]
1233 },
1234 {
1235 "cell_type": "code",
1236 "execution_count": 22,
1237 "metadata": {},
1238 "outputs": [],
1239 "source": [
1240 "from collections import Counter\n",
1241 "import re\n",
"from math import log\n",
1242 "\n",
1243 "\n",
1244 "def cal_essay_entropy(essay, split_by=None):\n",
1245 " \"\"\"\n",
1246 " 计算文章的信息熵\n",
1247 " :param essay: 文章内容\n",
1248 " :param split_by: 切分方式,对于中文文章,不需传入,按字符切分,\n",
1249 " 对于英文文章,需传入空格字符来进行切分\n",
1250 " :return: 文章的信息熵\n",
1251 " \"\"\"\n",
1252 " # 把英文全部转为小写\n",
1253 " essay = essay.lower()\n",
1254 " # 去除标点符号\n",
1255 " essay = re.sub(\n",
1256 " \"[\\f+\\n+\\r+\\t+\\v+\\?\\.\\!\\/_,$%^*(+\\\"\\']+|[+——!,。?、~@#《》¥%……&*()]\", \"\",\n",
1257 " essay)\n",
1258 " # print(essay)\n",
1259 " # 把文本分割为词\n",
1260 " if split_by:\n",
1261 " word_list = essay.split(split_by)\n",
1262 " else:\n",
1263 " word_list = list(essay)\n",
1264 " # 统计总的单词数\n",
1265 " word_number = len(word_list)\n",
1266 " print('此文章共有 %s 个单词' % word_number)\n",
1267 " # 得到每个单词出现的次数\n",
1268 " word_counter = Counter(word_list)\n",
1269 " # print('每个单词出现的次数为:%s' % word_counter)\n",
1270 " # 使用信息熵公式计算信息熵\n",
1271 " ent = -sum([(p / word_number) * log(p / word_number, 2) for p in\n",
1272 " word_counter.values()])\n",
1273 " print('信息熵为:%.2f' % ent)\n",
1274 " return ent"
1275 ]
1276 },
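在将上面的函数应用到整篇文章之前,可以先用一个玩具字符串检验信息熵公式:4 个等概率字符的信息熵应恰好为 2 比特。

```python
from collections import Counter
from math import log

# 玩具示例:字符串 "abcd" 中 4 个字符各出现 1 次,概率均为 1/4
text = "abcd"
chars = list(text)
n = len(chars)
counter = Counter(chars)

# 信息熵公式:-sum(p * log2(p)),4 个 1/4 概率的字符给出 2 比特
ent = -sum((c / n) * log(c / n, 2) for c in counter.values())
print(ent)  # 2.0
```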
1277 {
1278 "cell_type": "code",
1279 "execution_count": 23,
1280 "metadata": {},
1281 "outputs": [
1282 {
1283 "name": "stdout",
1284 "output_type": "stream",
1285 "text": [
1286 "此文章共有 1232 个单词\n",
1287 "信息熵为:8.05\n"
1288 ]
1289 }
1290 ],
1291 "source": [
1292 "ent = cal_essay_entropy(ch_essay)"
1293 ]
1294 },
1295 {
1296 "cell_type": "code",
1297 "execution_count": 24,
1298 "metadata": {},
1299 "outputs": [
1300 {
1301 "name": "stdout",
1302 "output_type": "stream",
1303 "text": [
1304 "此文章共有 652 个单词\n",
1305 "信息熵为:7.77\n"
1306 ]
1307 }
1308 ],
1309 "source": [
1310 "ent = cal_essay_entropy(en_essay, split_by = ' ')"
1311 ]
1312 }
1313 ],
1314 "metadata": {
1315 "kernelspec": {
1316 "display_name": "Python 3",
1317 "language": "python",
1318 "name": "python3"
1319 },
1320 "language_info": {
1321 "codemirror_mode": {
1322 "name": "ipython",
1323 "version": 3
1324 },
1325 "file_extension": ".py",
1326 "mimetype": "text/x-python",
1327 "name": "python",
1328 "nbconvert_exporter": "python",
1329 "pygments_lexer": "ipython3",
1330 "version": "3.5.2"
1331 }
1332 },
1333 "nbformat": 4,
1334 "nbformat_minor": 2
1335 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.3 回归分析\n",
7 "\n",
8 "**回归分析**:研究不同变量之间关联关系的一种分析方法。 \n",
9 "**回归模型**:刻画不同变量之间关系的模型。"
10 ]
11 },
12 {
13 "cell_type": "markdown",
14 "metadata": {},
15 "source": [
16 "## 5.3.1 回归分析的基本概念\n",
17 "\n",
18 "**数据**:下表给出了莫纳罗亚山从 1970 年到 2005 年间每 5 年的二氧化碳浓度,单位是百万分比浓度(parts per million,简称ppm)\n",
19 "\n",
20 "<table>\n",
21 " <h4 align=\"center\">莫纳罗亚山从 1970 年到 2005 年间每 5 年的二氧化碳浓度</h4>\n",
22 "<tbody>\n",
23 " <tr>\n",
24 " <th align=\"left\">**年份 $x$ ** </th>\n",
25 " <td align=\"center\">1970</td>\n",
26 " <td align=\"center\">1975</td>\n",
27 " <td align=\"center\">1980</td> \n",
28 " <td align=\"center\">1985</td>\n",
29 " <td align=\"center\">1990</td>\n",
30 " <td align=\"center\">1995</td>\n",
31 " <td align=\"center\">2000</td>\n",
32 " <td align=\"center\">2005</td>\n",
33 " </tr>\n",
34 " <tr>\n",
35 " <th align=\"left\">**$CO_2$(ppm) $y$**</th>\n",
36 " <td align=\"center\">325.68</td>\n",
37 " <td align=\"center\">331.15</td>\n",
38 " <td align=\"center\">338.69</td> \n",
39 " <td align=\"center\">345.90</td>\n",
40 " <td align=\"center\">354.19</td>\n",
41 " <td align=\"center\">360.88</td>\n",
42 " <td align=\"center\">369.48</td>\n",
43 " <td align=\"center\">379.67</td>\n",
44 " </tr>\n",
45 "</tbody>\n",
46 "</table>\n",
47 "\n",
48 "\n",
49 "**目标**:分析时间年份和二氧化碳浓度之间的关联关系,由此预测2010年二氧化碳浓度。\n"
50 ]
51 },
52 {
53 "cell_type": "code",
54 "execution_count": null,
55 "metadata": {},
56 "outputs": [],
57 "source": [
58 "import numpy as np\n",
59 "import matplotlib.pyplot as plt\n",
60 "%matplotlib inline\n",
61 "\n",
62 "x = np.array([1970, 1975, 1980, 1985, 1990, 1995, 2000, 2005])\n",
63 "y = np.array([325.68, 331.15, 338.69, 345.90, 354.19, 360.88, 369.48, 379.67])\n",
64 "fig = plt.figure()\n",
65 "plt.xlabel(\"Year\")\n",
66 "plt.ylabel(\"Co2\")\n",
67 "plt.scatter(x, y, c='r')\n",
68 "plt.show()"
69 ]
70 },
71 {
72 "cell_type": "markdown",
73 "metadata": {},
74 "source": [
75 "该地区二氧化碳浓度在逐年缓慢增加,因此我们使用简单的**线性模型**来刻画时间年份和二氧化碳浓度两者之间的关系,即 $二氧化碳浓度 = a × 时间 + b$。 \n",
76 "\n",
77 "设时间年份为 $x$,二氧化碳浓度为 $y$,即 $y = ax + b$ 。\n",
78 "\n",
79 "通过上述数据来确定模型中 $a$ 和 $b$ 的值,一旦求解出 $a$ 和 $b$ 的值,输入任意的时间年份即可估算出该年份对应的二氧化碳浓度值。\n"
80 ]
81 },
82 {
83 "cell_type": "markdown",
84 "metadata": {},
85 "source": [
86 "## 5.3.2 回归分析中参数计算\n",
87 "\n",
88 "最简单的线性回归是**一元线性回归模型**,只包含一个自变量 $x$ 和一个因变量 $y$,并且假定自变量和因变量之间存在 $y=ax+b$ 的线性关系。求解参数 $a$ 和 $b$,需要给定若干组 $(x,y)$ 数据,然后从这些数据出发来计算参数 $a$ 和 $b$。\n"
89 ]
90 },
91 {
92 "cell_type": "markdown",
93 "metadata": {},
94 "source": [
95 "在一元线性回归模型中,最关键的问题是如何计算参数 $a$ 和参数 $b$ 使误差最小化。\n",
96 "\n",
97 "最佳拟合直线 $y=ax+b$ 应该与这 8 组样本数据点的距离都很近:最理想的情况是所有样本数据点都落在该直线上(通常不现实),因此退而求其次,让所有样本数据点离直线尽可能近,这里的“距离”用预测数值和实际数值之间的差来度量。\n",
98 "\n",
99 "**预测值**:通过给定参数 $a$ 和 $b$ 计算 $ax+b$ 得到的值记为 $\\widetilde{y}=ax+b$\n",
100 "\n",
101 "**真实值**:每组数据中 $(x,y)$ 中对应的 $y$ 值\n",
102 "\n",
103 "**残差**:$x$ 所对应的真实值 $y$ 与模型预测值 $\\widetilde{y}$ 之间的误差;实际中一般使用平方误差 $(y-\\widetilde{y})^2$ 作为残差。\n",
104 "\n",
105 "回归分析中,在所有可能的参数取值中,最佳回归模型是使残差的平均值最小的模型,即要求 N 组 $(x,y)$ 数据得到的残差平均值 $\\frac{1}{N}\\sum{(y-\\widetilde{y})^2}$ 最小。\n",
106 "\n",
107 "因此,给定 8 组 $(x,y)$ 数据,可通过最小二乘法来求解使残差最小的 $a$ 和 $b$。\n",
108 "\n",
109 "8组 $(x,y)$ 样本数据点记为 $(x_1,y_1)$, $(x_2,y_2)$, ..., $(x_8,y_8)$, 时间年份变量 $x$ 的平均值 $\\overline{x}=\\frac{x_1+x_2+...+x_8}{8}$, 因变量 $y$ 的平均值为$\\overline{y}=\\frac{y_1+y_2+...+y_8}{8}$, 则:\n",
110 "\n",
111 "$a=\\frac{x_1y_1+x_2y_2+...+x_8y_8-8\\overline{x}\\overline{y}}{x_{1}^{2}+x_{2}^{2}+...+x_{8}^{2}-8\\overline{x}^2}=1.5344$\n",
112 "\n",
113 "$b = \\overline{y}-a\\overline{x}=-2698.9$"
114 ]
115 },
116 {
117 "cell_type": "markdown",
118 "metadata": {},
119 "source": [
120 "我们根据上面的公式编写如下的方法来求解 $a$ 和 $b$。"
121 ]
122 },
123 {
124 "cell_type": "code",
125 "execution_count": null,
126 "metadata": {},
127 "outputs": [],
128 "source": [
129 "def cal_a_b(x, y):\n",
130 " \"\"\"\n",
131 " 计算 x 和 y 的线性系数\n",
132 " :param x: np array 格式的自变量\n",
133 " :param y: np array 格式的因变量\n",
134 " :return: 系数 a 和 b\n",
135 " \"\"\"\n",
136 " # 计算 x 和 y 的平均数\n",
137 "    x_average = np.sum(x) / len(x)\n",
138 "    y_average = np.sum(y) / len(y)\n",
139 "\n",
140 "    # 两个临时变量用于后续计算参数 a 和 b\n",
141 "    # m 为 x1*y1 + x2*y2 + ... \n",
142 "    # n 为 x1*x1 + x2*x2 + ... \n",
143 "    m = np.sum(x * y)\n",
144 "    n = np.sum(x ** 2)\n",
145 "\n",
146 "    # 计算参数 a 和 b\n",
147 "    a = (m - len(x) * x_average * y_average) / (\n",
148 "            n - len(x) * x_average * x_average)\n",
149 "    b = y_average - a * x_average\n",
150 " return a, b\n",
151 "a, b = cal_a_b(x, y)\n",
152 "print(a, b)"
153 ]
154 },
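作为交叉验证,可以用 NumPy 的 `np.polyfit`(一阶多项式拟合,同样基于最小二乘)核对上面闭式公式的结果;下面直接使用正文中的年份与二氧化碳浓度数据:

```python
import numpy as np

x = np.array([1970, 1975, 1980, 1985, 1990, 1995, 2000, 2005])
y = np.array([325.68, 331.15, 338.69, 345.90, 354.19, 360.88, 369.48, 379.67])

# 闭式最小二乘解(与正文公式一致)
x_avg, y_avg = x.mean(), y.mean()
a = (np.sum(x * y) - len(x) * x_avg * y_avg) / (np.sum(x ** 2) - len(x) * x_avg ** 2)
b = y_avg - a * x_avg

# np.polyfit 拟合一次多项式,返回 [斜率, 截距],应与闭式解一致
a2, b2 = np.polyfit(x, y, 1)
print(round(a, 4), round(a2, 4))  # 均约为 1.5344
```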
155 {
156 "cell_type": "markdown",
157 "metadata": {},
158 "source": [
159 "综上:预测莫纳罗亚山地区二氧化碳浓度的一元线性回归模型为:$y=1.5344x-2698.9$。 \n",
160 "我们可以据此绘制出拟合直线。"
161 ]
162 },
163 {
164 "cell_type": "code",
165 "execution_count": null,
166 "metadata": {},
167 "outputs": [],
168 "source": [
169 "# 构造 y = ax + b 直线\n",
170 "x_predict = np.linspace(1965, 2010, 1000)\n",
171 "y_predict = a * x_predict + b\n",
172 "\n",
173 "# 绘图\n",
174 "fig = plt.figure()\n",
175 "plt.xlabel(\"Year\")\n",
176 "plt.ylabel(\"Co2\")\n",
177 "plt.scatter(x, y, c='r')\n",
178 "plt.plot(x_predict, y_predict, c='b')\n",
179 "plt.show()"
180 ]
181 },
182 {
183 "cell_type": "markdown",
184 "metadata": {},
185 "source": [
186 "然后我们可以对该地区1970年之前和2005年之后的二氧化碳浓度进行估算。"
187 ]
188 },
189 {
190 "cell_type": "code",
191 "execution_count": null,
192 "metadata": {},
193 "outputs": [],
194 "source": [
195 "# 例如,预测 2015 年的二氧化碳浓度\n",
196 "a * 2015 + b"
197 ]
198 },
199 {
200 "cell_type": "markdown",
201 "metadata": {},
202 "source": [
203 "最终的预测结果汇总如下: \n",
204 "\n",
205 "<table>\n",
206 "<tbody>\n",
207 " <tr>\n",
208 " <th align=\"left\">**年份 $x$ ** </th>\n",
209 " <td align=\"center\">1960</td>\n",
210 " <td align=\"center\">1965</td>\n",
211 " <td align=\"center\">1970-2005</td> \n",
212 " <td align=\"center\">2010</td>\n",
213 " <td align=\"center\">2015</td>\n",
214 " </tr>\n",
215 " <tr>\n",
216 " <th align=\"left\">**$CO_2$(ppm) $y$**</th>\n",
217 " <td align=\"center\">308.51</td>\n",
218 " <td align=\"center\">316.18</td>\n",
219 " <td align=\"center\">已有数据</td> \n",
220 " <td align=\"center\">385.23</td>\n",
221 " <td align=\"center\">392.90</td>\n",
222 " </tr>\n",
223 "</tbody>\n",
224 "</table>"
225 ]
226 },
227 {
228 "cell_type": "markdown",
229 "metadata": {},
230 "source": [
231 "## 探究莫纳罗亚山地区二氧化碳与温度之间的关系\n",
232 "\n",
233 "该地区 1970 年到 2005 年间每 5 年的二氧化碳浓度以及全球温度(相对于 1961 - 1990 年经过平滑处理的平均温度增长量)\n",
234 "\n",
235 "<table>\n",
236 "<tbody>\n",
237 " <tr>\n",
238 " <th align=\"left\">$CO_2$(ppm) $x$</th>\n",
239 " <td align=\"center\">325.68</td>\n",
240 " <td align=\"center\">331.15</td>\n",
241 " <td align=\"center\">338.69</td> \n",
242 " <td align=\"center\">345.90</td>\n",
243 " <td align=\"center\">354.19</td>\n",
244 " <td align=\"center\">360.88</td>\n",
245 " <td align=\"center\">369.48</td>\n",
246 " <td align=\"center\">379.67</td>\n",
247 " </tr>\n",
248 " <tr>\n",
249 " <th align=\"left\">温度 $y$ </th>\n",
250 " <td align=\"center\">-0.108</td>\n",
251 " <td align=\"center\">-0.082</td>\n",
252 " <td align=\"center\">0.015</td>\n",
253 " <td align=\"center\">0.080</td>\n",
254 " <td align=\"center\">0.149</td>\n",
255 " <td align=\"center\">0.240</td>\n",
256 " <td align=\"center\">0.370</td>\n",
257 " <td align=\"center\">0.420</td>\n",
258 "\n",
259 " </tr>\n",
260 "</tbody>\n",
261 "</table>"
262 ]
263 },
264 {
265 "cell_type": "markdown",
266 "metadata": {},
267 "source": [
268 "我们可以使用上面同样的方法来求解得到参数 $a$ 和 $b$。并绘制出拟合直线。"
269 ]
270 },
271 {
272 "cell_type": "code",
273 "execution_count": null,
274 "metadata": {},
275 "outputs": [],
276 "source": [
277 "# 数据\n",
278 "x = np.array([325.68, 331.15, 338.69, 345.90, 354.19, 360.88, 369.48, 379.67])\n",
279 "y = np.array([-0.108, -0.082, 0.015, 0.080, 0.149, 0.24, 0.370, 0.420])\n",
280 "\n",
281 "# 计算参数 a 和 b\n",
282 "a, b = cal_a_b(x, y)\n",
283 "\n",
284 "# 构造 y = ax + b 直线\n",
285 "x_predict = np.linspace(325, 380, 1000)\n",
286 "y_predict = a * x_predict + b\n",
287 "\n",
288 "# 绘图\n",
289 "fig = plt.figure()\n",
290 "plt.xlabel(\"Co2\")\n",
291 "plt.ylabel(\"Temperature\")\n",
292 "plt.scatter(x, y, c='r')\n",
293 "plt.plot(x_predict, y_predict, c='b')\n",
294 "plt.show()"
295 ]
296 },
297 {
298 "cell_type": "markdown",
299 "metadata": {},
300 "source": [
301 "### 思考与练习"
302 ]
303 },
304 {
305 "cell_type": "markdown",
306 "metadata": {},
307 "source": [
308 "1. 摄氏温度(℃)和华氏温度(℉)是两种计量温度的标准,下表给出了两种温度之间的若干关系,如摄氏温度 0℃ 等于华氏温度 32℉。\n"
309 ]
310 },
311 {
312 "cell_type": "markdown",
313 "metadata": {},
314 "source": [
315 "<table>\n",
316 " <h4 align=\"center\">不同温度下测得摄氏/华氏温度表</h4>\n",
317 "<tbody>\n",
318 " <tr>\n",
319 " <th align=\"left\">摄氏温度(℃) </th>\n",
320 " <td align=\"center\">0</td>\n",
321 " <td align=\"center\">10</td>\n",
322 " <td align=\"center\">15</td> \n",
323 " <td align=\"center\">20</td>\n",
324 " <td align=\"center\">25</td>\n",
325 " <td align=\"center\">30</td>\n",
326 " </tr>\n",
327 " <tr>\n",
328 " <th align=\"left\">华氏温度(℉)</th>\n",
329 " <td align=\"center\">32</td>\n",
330 " <td align=\"center\">50</td>\n",
331 " <td align=\"center\">59</td> \n",
332 " <td align=\"center\">68</td>\n",
333 " <td align=\"center\">77</td>\n",
334 " <td align=\"center\">86</td>\n",
335 " </tr>\n",
336 "</tbody>\n",
337 "</table>"
338 ]
339 },
340 {
341 "cell_type": "markdown",
342 "metadata": {},
343 "source": [
344 "试判断摄氏温度和华氏温度之间是否符合线性关系。如符合,请通过线性回归分析计算出摄氏温度和华氏温度之间的线性回归方程。"
345 ]
346 },
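事实上,摄氏温度与华氏温度之间存在精确的换算关系 $F = \frac{9}{5}C + 32$。可以先用表中数据直接验证这一线性关系:

```python
# 验证表中每组 (摄氏, 华氏) 数据都满足 F = 9/5 * C + 32
celsius = [0, 10, 15, 20, 25, 30]
fahrenheit = [32, 50, 59, 68, 77, 86]

for c, f in zip(celsius, fahrenheit):
    assert f == 9 * c / 5 + 32
print("表中数据严格满足 F = 9/5 * C + 32")
```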
347 {
348 "cell_type": "markdown",
349 "metadata": {},
350 "source": [
351 "首先:我们观察一下摄氏华氏温度的散点图"
352 ]
353 },
354 {
355 "cell_type": "code",
356 "execution_count": null,
357 "metadata": {},
358 "outputs": [],
359 "source": [
360 "# 数据\n",
361 "x = np.array([0, 10, 15, 20, 25, 30])\n",
362 "y = np.array([32, 50, 59, 68, 77, 86])\n",
363 "fig = plt.figure()\n",
364 "plt.xlabel(\"摄氏温度\")\n",
365 "plt.ylabel(\"华氏温度\")\n",
366 "plt.scatter(x, y, c='r')\n",
367 "plt.show()"
368 ]
369 },
370 {
371 "cell_type": "markdown",
372 "metadata": {},
373 "source": [
374 "观察上图,摄氏温度和华氏温度是否符合线性关系? 如果是,使用我们上面写好求解参数的方法来快速求解系数。"
375 ]
376 },
377 {
378 "cell_type": "code",
379 "execution_count": null,
380 "metadata": {},
381 "outputs": [],
382 "source": [
383 "# 使用前面定义的 cal_a_b 求解系数\n",
384 "a, b = cal_a_b(x, y)\n",
385 "print('参数 a 的值为:{:g},参数 b 的值为:{:g}'.format(a, b))"
386 ]
387 },
388 {
389 "cell_type": "code",
390 "execution_count": null,
391 "metadata": {},
392 "outputs": [],
393 "source": [
394 "# 构造 y = ax + b 直线\n",
395 "x_predict = np.linspace(0, 30, 1000)\n",
396 "y_predict = a * x_predict + b\n",
397 "\n",
398 "# 绘图\n",
399 "fig = plt.figure()\n",
400 "plt.xlabel(\"摄氏温度\")\n",
401 "plt.ylabel(\"华氏温度\")\n",
402 "plt.scatter(x, y, c='r')\n",
403 "plt.plot(x_predict, y_predict, c='b')\n",
404 "plt.show()"
405 ]
406 },
407 {
408 "cell_type": "markdown",
409 "metadata": {},
410 "source": [
411 "2. 摩尔定律是由英特尔创始人之一的戈登·摩尔提出的,其基本内容为:当价格不变时,集成电路上可容纳的元器件的数目,大约每隔 18-24 个月便会增加一倍,性能也将提升一倍。下表记录了 1971-2004 年英特尔微处理器晶体管数量的增长。需要注意的是,随着单位面积上晶体管体积越来越小,摩尔定律所描述的晶体管增长在不久的将来会面临发展的极限。"
412 ]
413 },
414 {
415 "cell_type": "markdown",
416 "metadata": {},
417 "source": [
418 "|微处理器|推出年份($x$)|晶体管数量($y$)|$z$=log<sub>2</sub>$y$|\n",
419 "|--|--|--|--|\n",
420 "|4004|1971|2300|11.17|\n",
421 "|8008|1972|2500|11.29|\n",
422 "|8080|1974|4500|12.14|\n",
423 "|8086|1978|29000|14.82|\n",
424 "|Intel286|1982|134000|17.03|\n",
425 "|Intel386 processor|1985|275000|18.07|\n",
426 "|Intel486 processor|1989|1200000|20.19|\n",
427 "|Intel Pentium processor|1993|3100000|21.56|\n",
428 "|Intel Pentium Ⅱ processor|1997|7500000|22.84|\n",
429 "|Intel Pentium Ⅲ processor|1999|9500000|23.18|\n",
430 "|Intel Pentium 4 processor|2000|42000000|25.32|\n",
431 "|Intel Itanium processor|2001|25000000|24.58|\n",
432 "|Intel Itanium 2 processor|2003|220000000|27.72|\n",
433 "|Intel Itanium 2 processor(9MB cache)|2004|592000000|29.14|"
434 ]
435 },
436 {
437 "cell_type": "markdown",
438 "metadata": {},
439 "source": [
440 "摩尔定律刻画了晶体管数量与时间之间的指数关系,可用非线性回归来拟合,但这超出了本教程的范围。不过我们可以对晶体管数量取以 2 为底的对数(记为 $z$),通过判断 $z$ 与时间 $x$ 之间是否存在线性关系,来验证摩尔定律。如果上述线性关系存在,就使用线性回归方法计算两者之间的最佳拟合直线。"
441 ]
442 },
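表中最后一列 $z$ 即对晶体管数量取以 2 为底的对数。可以抽取几行数据核对(保留两位小数):

```python
from math import log2

# 核对表中的 z = log2(y):抽取 4004、Intel486 和最后一行的数据
for y, z in [(2300, 11.17), (1200000, 20.19), (592000000, 29.14)]:
    assert round(log2(y), 2) == z
print("核对通过")
```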
443 {
444 "cell_type": "code",
445 "execution_count": null,
446 "metadata": {},
447 "outputs": [],
448 "source": [
449 "# 年份\n",
450 "x = np.array(\n",
451 " [1971, 1972, 1974, 1978, 1982, 1985, 1989, 1993, 1997, 1999, 2000, 2001,\n",
452 " 2003, 2004])\n",
453 "# 晶体管取以 2 为底的对数\n",
454 "z = np.array(\n",
455 " [11.17, 11.29, 12.14, 14.82, 17.03, 18.07, 20.19, 21.56, 22.84, 23.18,\n",
456 " 25.32, 24.58, 27.72, 29.14])"
457 ]
458 },
459 {
460 "cell_type": "markdown",
461 "metadata": {},
462 "source": [
463 "我们绘图观察 $x$ 和 $z$ 之间的关系"
464 ]
465 },
466 {
467 "cell_type": "code",
468 "execution_count": null,
469 "metadata": {},
470 "outputs": [],
471 "source": [
472 "fig = plt.figure()\n",
473 "plt.xlabel(\"年份\")\n",
474 "plt.ylabel(\"晶体管取以 2 为底的对数\")\n",
475 "plt.scatter(x, z, c='r')\n",
476 "plt.show()"
477 ]
478 },
479 {
480 "cell_type": "markdown",
481 "metadata": {},
482 "source": [
483 "观察上图,$z$ 与时间 $x$ 之间是否存在线性关系?如果是,我们用上面写好的方法来求解系数。"
484 ]
485 },
486 {
487 "cell_type": "code",
488 "execution_count": null,
489 "metadata": {},
490 "outputs": [],
491 "source": [
492 "# 使用前面定义的 cal_a_b 求解系数\n",
493 "a, b = cal_a_b(x, z)\n",
494 "print('参数 a 的值为:{:g},参数 b 的值为:{:g}'.format(a, b))"
495 ]
496 },
497 {
498 "cell_type": "code",
499 "execution_count": null,
500 "metadata": {},
501 "outputs": [],
502 "source": [
503 "# 构造 y = ax + b 直线\n",
504 "x_predict = np.linspace(1970, 2005, 1000)\n",
505 "z_predict = a * x_predict + b\n",
506 "\n",
507 "# 绘图\n",
508 "fig = plt.figure()\n",
509 "plt.xlabel(\"年份\")\n",
510 "plt.ylabel(\"晶体管取以 2 为底的对数\")\n",
511 "plt.scatter(x, z, c='r')\n",
512 "plt.plot(x_predict, z_predict, c='b')\n",
513 "plt.show()"
514 ]
515 }
516 ],
517 "metadata": {
518 "kernelspec": {
519 "display_name": "Python 3",
520 "language": "python",
521 "name": "python3"
522 },
523 "language_info": {
524 "codemirror_mode": {
525 "name": "ipython",
526 "version": 3
527 },
528 "file_extension": ".py",
529 "mimetype": "text/x-python",
530 "name": "python",
531 "nbconvert_exporter": "python",
532 "pygments_lexer": "ipython3",
533 "version": "3.5.2"
534 }
535 },
536 "nbformat": 4,
537 "nbformat_minor": 2
538 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.3 回归分析\n",
7 "\n",
8 "**回归分析**:研究不同变量之间关联关系的一种分析方法。 \n",
9 "**回归模型**:刻画不同变量之间关系的模型。"
10 ]
11 },
12 {
13 "cell_type": "markdown",
14 "metadata": {},
15 "source": [
16 "## 5.3.1 回归分析的基本概念\n",
17 "\n",
18 "**数据**:下表给出了莫纳罗亚山从 1970 年到 2005 年间每 5 年的二氧化碳浓度,单位是百万分比浓度(parts per million,简称ppm)\n",
19 "\n",
20 "<table>\n",
21 " <h4 align=\"center\">莫纳罗亚山从 1970 年到 2005 年间每 5 年的二氧化碳浓度</h4>\n",
22 "<tbody>\n",
23 " <tr>\n",
24 " <th align=\"left\">**年份 $x$ ** </th>\n",
25 " <td align=\"center\">1970</td>\n",
26 " <td align=\"center\">1975</td>\n",
27 " <td align=\"center\">1980</td> \n",
28 " <td align=\"center\">1985</td>\n",
29 " <td align=\"center\">1990</td>\n",
30 " <td align=\"center\">1995</td>\n",
31 " <td align=\"center\">2000</td>\n",
32 " <td align=\"center\">2005</td>\n",
33 " </tr>\n",
34 " <tr>\n",
35 " <th align=\"left\">**$CO_2$(ppm) $y$**</th>\n",
36 " <td align=\"center\">325.68</td>\n",
37 " <td align=\"center\">331.15</td>\n",
38 " <td align=\"center\">338.69</td> \n",
39 " <td align=\"center\">345.90</td>\n",
40 " <td align=\"center\">354.19</td>\n",
41 " <td align=\"center\">360.88</td>\n",
42 " <td align=\"center\">369.48</td>\n",
43 " <td align=\"center\">379.67</td>\n",
44 " </tr>\n",
45 "</tbody>\n",
46 "</table>\n",
47 "\n",
48 "\n",
49 "**目标**:分析时间年份和二氧化碳浓度之间的关联关系,由此预测2010年二氧化碳浓度。\n"
50 ]
51 },
52 {
53 "cell_type": "code",
54 "execution_count": null,
55 "metadata": {},
56 "outputs": [],
57 "source": [
58 "import numpy as np\n",
59 "import matplotlib.pyplot as plt\n",
60 "%matplotlib inline\n",
61 "\n",
62 "x = np.array([1970, 1975, 1980, 1985, 1990, 1995, 2000, 2005])\n",
63 "y = np.array([325.68, 331.15, 338.69, 345.90, 354.19, 360.88, 369.48, 379.67])\n",
64 "fig = plt.figure()\n",
65 "plt.xlabel(\"Year\")\n",
66 "plt.ylabel(\"Co2\")\n",
67 "plt.scatter(x, y, c='r')\n",
68 "plt.show()"
69 ]
70 },
71 {
72 "cell_type": "markdown",
73 "metadata": {},
74 "source": [
75 "该地区二氧化碳浓度在逐年缓慢增加,因此我们使用简单的**线性模型**来刻画时间年份和二氧化碳浓度两者之间的关系,即 $二氧化碳浓度 = a × 时间 + b$。 \n",
76 "\n",
77 "设时间年份为 $x$,二氧化碳浓度为 $y$,即 $y = ax + b$ 。\n",
78 "\n",
79 "通过上述数据来确定模型中 $a$ 和 $b$ 的值,一旦求解出 $a$ 和 $b$ 的值,输入任意的时间年份即可估算出该年份对应的二氧化碳浓度值。\n"
80 ]
81 },
82 {
83 "cell_type": "markdown",
84 "metadata": {},
85 "source": [
86 "## 5.3.2 回归分析中参数计算\n",
87 "\n",
88 "最简单的线性回归是**一元线性回归模型**,只包含一个自变量 $x$ 和一个因变量 $y$,并且假定自变量和因变量之间存在 $y=ax+b$ 的线性关系。求解参数 $a$ 和 $b$,需要给定若干组 $(x,y)$ 数据,然后从这些数据出发来计算参数 $a$ 和 $b$。\n"
89 ]
90 },
91 {
92 "cell_type": "markdown",
93 "metadata": {},
94 "source": [
95 "在一元线性回归模型中,最关键的问题是如何计算参数 $a$ 和参数 $b$ 使误差最小化。\n",
96 "\n",
97 "最佳拟合直线 $y=ax+b$ 应该与这 8 组样本数据点的距离都很近:最理想的情况是所有样本数据点都落在该直线上(通常不现实),因此退而求其次,让所有样本数据点离直线尽可能近,这里的“距离”用预测数值和实际数值之间的差来度量。\n",
98 "\n",
99 "**预测值**:通过给定参数 $a$ 和 $b$ 计算 $ax+b$ 得到的值记为 $\\widetilde{y}=ax+b$\n",
100 "\n",
101 "**真实值**:每组数据中 $(x,y)$ 中对应的 $y$ 值\n",
102 "\n",
103 "**残差**:$x$ 所对应的真实值 $y$ 与模型预测值 $\\widetilde{y}$ 之间的误差;实际中一般使用平方误差 $(y-\\widetilde{y})^2$ 作为残差。\n",
104 "\n",
105 "回归分析中,在所有可能的参数取值中,最佳回归模型是使残差的平均值最小的模型,即要求 N 组 $(x,y)$ 数据得到的残差平均值 $\\frac{1}{N}\\sum{(y-\\widetilde{y})^2}$ 最小。\n",
106 "\n",
107 "因此,给定 8 组 $(x,y)$ 数据,可通过最小二乘法来求解使残差最小的 $a$ 和 $b$。\n",
108 "\n",
109 "8组 $(x,y)$ 样本数据点记为 $(x_1,y_1)$, $(x_2,y_2)$, ..., $(x_8,y_8)$, 时间年份变量 $x$ 的平均值 $\\overline{x}=\\frac{x_1+x_2+...+x_8}{8}$, 因变量 $y$ 的平均值为$\\overline{y}=\\frac{y_1+y_2+...+y_8}{8}$, 则:\n",
110 "\n",
111 "$a=\\frac{x_1y_1+x_2y_2+...+x_8y_8-8\\overline{x}\\overline{y}}{x_{1}^{2}+x_{2}^{2}+...+x_{8}^{2}-8\\overline{x}^2}=1.5344$\n",
112 "\n",
113 "$b = \\overline{y}-a\\overline{x}=-2698.9$"
114 ]
115 },
116 {
117 "cell_type": "markdown",
118 "metadata": {},
119 "source": [
120 "我们根据上面的公式编写如下的方法来求解 $a$ 和 $b$。"
121 ]
122 },
123 {
124 "cell_type": "code",
125 "execution_count": null,
126 "metadata": {},
127 "outputs": [],
128 "source": [
129 "def cal_a_b(x, y):\n",
130 " \"\"\"\n",
131 " 计算 x 和 y 的线性系数\n",
132 " :param x: np array 格式的自变量\n",
133 " :param y: np array 格式的因变量\n",
134 " :return: 系数 a 和 b\n",
135 " \"\"\"\n",
136 " # 计算 x 和 y 的平均数\n",
137 "    x_average = np.sum(x) / len(x)\n",
138 "    y_average = np.sum(y) / len(y)\n",
139 "\n",
140 "    # 两个临时变量用于后续计算参数 a 和 b\n",
141 "    # m 为 x1*y1 + x2*y2 + ... \n",
142 "    # n 为 x1*x1 + x2*x2 + ... \n",
143 "    m = np.sum(x * y)\n",
144 "    n = np.sum(x ** 2)\n",
145 "\n",
146 "    # 计算参数 a 和 b\n",
147 "    a = (m - len(x) * x_average * y_average) / (\n",
148 "            n - len(x) * x_average * x_average)\n",
149 "    b = y_average - a * x_average\n",
150 " return a, b\n",
151 "a, b = cal_a_b(x, y)\n",
152 "print(a, b)"
153 ]
154 },
155 {
156 "cell_type": "markdown",
157 "metadata": {},
158 "source": [
159 "综上:预测莫纳罗亚山地区二氧化碳浓度的一元线性回归模型为:$y=1.5344x-2698.9$。 \n",
160 "我们可以据此绘制出拟合直线。"
161 ]
162 },
163 {
164 "cell_type": "code",
165 "execution_count": null,
166 "metadata": {},
167 "outputs": [],
168 "source": [
169 "# 构造 y = ax + b 直线\n",
170 "x_predict = np.linspace(1965, 2010, 1000)\n",
171 "y_predict = a * x_predict + b\n",
172 "\n",
173 "# 绘图\n",
174 "fig = plt.figure()\n",
175 "plt.xlabel(\"Year\")\n",
176 "plt.ylabel(\"Co2\")\n",
177 "plt.scatter(x, y, c='r')\n",
178 "plt.plot(x_predict, y_predict, c='b')\n",
179 "plt.show()"
180 ]
181 },
182 {
183 "cell_type": "markdown",
184 "metadata": {},
185 "source": [
186 "然后我们可以对该地区1970年之前和2005年之后的二氧化碳浓度进行估算。"
187 ]
188 },
189 {
190 "cell_type": "code",
191 "execution_count": null,
192 "metadata": {},
193 "outputs": [],
194 "source": [
195 "# 例如,预测 2015 年的二氧化碳浓度\n",
196 "a * 2015 + b"
197 ]
198 },
199 {
200 "cell_type": "markdown",
201 "metadata": {},
202 "source": [
203 "最终的预测结果汇总如下: \n",
204 "\n",
205 "<table>\n",
206 "<tbody>\n",
207 " <tr>\n",
208 " <th align=\"left\">**年份 $x$ ** </th>\n",
209 " <td align=\"center\">1960</td>\n",
210 " <td align=\"center\">1965</td>\n",
211 " <td align=\"center\">1970-2005</td> \n",
212 " <td align=\"center\">2010</td>\n",
213 " <td align=\"center\">2015</td>\n",
214 " </tr>\n",
215 " <tr>\n",
216 " <th align=\"left\">**$CO_2$(ppm) $y$**</th>\n",
217 " <td align=\"center\">308.51</td>\n",
218 " <td align=\"center\">316.18</td>\n",
219 " <td align=\"center\">已有数据</td> \n",
220 " <td align=\"center\">385.23</td>\n",
221 " <td align=\"center\">392.90</td>\n",
222 " </tr>\n",
223 "</tbody>\n",
224 "</table>"
225 ]
226 },
227 {
228 "cell_type": "markdown",
229 "metadata": {},
230 "source": [
231 "## 探究莫纳罗亚山地区二氧化碳与温度之间的关系\n",
232 "\n",
233 "该地区 1970 年到 2005 年间每 5 年的二氧化碳浓度以及全球温度(相对于 1961 - 1990 年经过平滑处理的平均温度增长量)\n",
234 "\n",
235 "<table>\n",
236 "<tbody>\n",
237 " <tr>\n",
238 " <th align=\"left\">$CO_2$(ppm) $x$</th>\n",
239 " <td align=\"center\">325.68</td>\n",
240 " <td align=\"center\">331.15</td>\n",
241 " <td align=\"center\">338.69</td> \n",
242 " <td align=\"center\">345.90</td>\n",
243 " <td align=\"center\">354.19</td>\n",
244 " <td align=\"center\">360.88</td>\n",
245 " <td align=\"center\">369.48</td>\n",
246 " <td align=\"center\">379.67</td>\n",
247 " </tr>\n",
248 " <tr>\n",
249 " <th align=\"left\">温度 $y$ </th>\n",
250 " <td align=\"center\">-0.108</td>\n",
251 " <td align=\"center\">-0.082</td>\n",
252 " <td align=\"center\">0.015</td>\n",
253 " <td align=\"center\">0.080</td>\n",
254 " <td align=\"center\">0.149</td>\n",
255 " <td align=\"center\">0.240</td>\n",
256 " <td align=\"center\">0.370</td>\n",
257 " <td align=\"center\">0.420</td>\n",
258 "\n",
259 " </tr>\n",
260 "</tbody>\n",
261 "</table>"
262 ]
263 },
264 {
265 "cell_type": "markdown",
266 "metadata": {},
267 "source": [
268 "我们可以使用上面同样的方法来求解得到参数 $a$ 和 $b$。并绘制出拟合直线。"
269 ]
270 },
271 {
272 "cell_type": "code",
273 "execution_count": null,
274 "metadata": {},
275 "outputs": [],
276 "source": [
277 "# 数据\n",
278 "x = np.array([325.68, 331.15, 338.69, 345.90, 354.19, 360.88, 369.48, 379.67])\n",
279 "y = np.array([-0.108, -0.082, 0.015, 0.080, 0.149, 0.24, 0.370, 0.420])\n",
280 "\n",
281 "# 计算参数 a 和 b\n",
282 "a, b = cal_a_b(x, y)\n",
283 "\n",
284 "# 构造 y = ax + b 直线\n",
285 "x_predict = np.linspace(325, 380, 1000)\n",
286 "y_predict = a * x_predict + b\n",
287 "\n",
288 "# 绘图\n",
289 "fig = plt.figure()\n",
290 "plt.xlabel(\"Co2\")\n",
291 "plt.ylabel(\"Temperature\")\n",
292 "plt.scatter(x, y, c='r')\n",
293 "plt.plot(x_predict, y_predict, c='b')\n",
294 "plt.show()"
295 ]
296 },
297 {
298 "cell_type": "markdown",
299 "metadata": {},
300 "source": [
301 "### 思考与练习"
302 ]
303 },
304 {
305 "cell_type": "markdown",
306 "metadata": {},
307 "source": [
308 "1. 摄氏温度(℃)和华氏温度(℉)是两种计量温度的标准,下表给出了两种温度之间的若干关系,如摄氏温度 0℃ 等于华氏温度 32℉。\n"
309 ]
310 },
311 {
312 "cell_type": "markdown",
313 "metadata": {},
314 "source": [
315 "<table>\n",
316 " <h4 align=\"center\">不同温度下测得摄氏/华氏温度表</h4>\n",
317 "<tbody>\n",
318 " <tr>\n",
319 " <th align=\"left\">摄氏温度(℃) </th>\n",
320 " <td align=\"center\">0</td>\n",
321 " <td align=\"center\">10</td>\n",
322 " <td align=\"center\">15</td> \n",
323 " <td align=\"center\">20</td>\n",
324 " <td align=\"center\">25</td>\n",
325 " <td align=\"center\">30</td>\n",
326 " </tr>\n",
327 " <tr>\n",
328 " <th align=\"left\">华氏温度(℉)</th>\n",
329 " <td align=\"center\">32</td>\n",
330 " <td align=\"center\">50</td>\n",
331 " <td align=\"center\">59</td> \n",
332 " <td align=\"center\">68</td>\n",
333 " <td align=\"center\">77</td>\n",
334 " <td align=\"center\">86</td>\n",
335 " </tr>\n",
336 "</tbody>\n",
337 "</table>"
338 ]
339 },
340 {
341 "cell_type": "markdown",
342 "metadata": {},
343 "source": [
344 "试判断摄氏温度和华氏温度之间是否符合线性关系。如符合,请通过线性回归分析计算出摄氏温度和华氏温度之间的线性回归方程。"
345 ]
346 },
347 {
348 "cell_type": "markdown",
349 "metadata": {},
350 "source": [
351 "首先:我们观察一下摄氏华氏温度的散点图"
352 ]
353 },
354 {
355 "cell_type": "code",
356 "execution_count": null,
357 "metadata": {},
358 "outputs": [],
359 "source": [
360 "# 数据\n",
361 "x = np.array([0, 10, 15, 20, 25, 30])\n",
362 "y = np.array([32, 50, 59, 68, 77, 86])\n",
363 "fig = plt.figure()\n",
364 "plt.xlabel(\"摄氏温度\")\n",
365 "plt.ylabel(\"华氏温度\")\n",
366 "plt.scatter(x, y, c='r')\n",
367 "plt.show()"
368 ]
369 },
370 {
371 "cell_type": "markdown",
372 "metadata": {},
373 "source": [
374 "通过散点图,我们观察到摄氏温度和华氏温度是符合线性关系的。使用我们上面写好求解参数的方法,即可快速求解。"
375 ]
376 },
377 {
378 "cell_type": "code",
379 "execution_count": null,
380 "metadata": {},
381 "outputs": [],
382 "source": [
383 "a, b = cal_a_b(x, y)\n",
384 "print('参数 a 的值为:{:g},参数 b 的值为:{:g}'.format(a, b))"
385 ]
386 },
387 {
388 "cell_type": "code",
389 "execution_count": null,
390 "metadata": {},
391 "outputs": [],
392 "source": [
393 "# 构造 y = ax + b 直线\n",
394 "x_predict = np.linspace(0, 30, 1000)\n",
395 "y_predict = a * x_predict + b\n",
396 "\n",
397 "# 绘图\n",
398 "fig = plt.figure()\n",
399 "plt.xlabel(\"摄氏温度\")\n",
400 "plt.ylabel(\"华氏温度\")\n",
401 "plt.scatter(x, y, c='r')\n",
402 "plt.plot(x_predict, y_predict, c='b')\n",
403 "plt.show()"
404 ]
405 },
406 {
407 "cell_type": "markdown",
408 "metadata": {},
409 "source": [
410 "2. 摩尔定律是由英特尔创始人之一的戈登·摩尔提出的,其基本内容为:当价格不变时,集成电路上可容纳的元器件的数目,大约每隔 18-24 个月便会增加一倍,性能也将提升一倍。下表记录了 1971-2004 年英特尔微处理器晶体管数量的增长。需要注意的是,随着单位面积上晶体管体积越来越小,摩尔定律所描述的晶体管增长在不久的将来会面临发展的极限。"
411 ]
412 },
413 {
414 "cell_type": "markdown",
415 "metadata": {},
416 "source": [
417 "|微处理器|推出年份($x$)|晶体管数量($y$)|$z$=log<sub>2</sub>$y$|\n",
418 "|--|--|--|--|\n",
419 "|4004|1971|2300|11.17|\n",
420 "|8008|1972|2500|11.29|\n",
421 "|8080|1974|4500|12.14|\n",
422 "|8086|1978|29000|14.82|\n",
423 "|Intel286|1982|134000|17.03|\n",
424 "|Intel386 processor|1985|275000|18.07|\n",
425 "|Intel486 processor|1989|1200000|20.19|\n",
426 "|Intel Pentium processor|1993|3100000|21.56|\n",
427 "|Intel Pentium Ⅱ processor|1997|7500000|22.84|\n",
428 "|Intel Pentium Ⅲ processor|1999|9500000|23.18|\n",
429 "|Intel Pentium 4 processor|2000|42000000|25.32|\n",
430 "|Intel Itanium processor|2001|25000000|24.58|\n",
431 "|Intel Itanium 2 processor|2003|220000000|27.72|\n",
432 "|Intel Itanium 2 processor(9MB cache)|2004|592000000|29.14|"
433 ]
434 },
435 {
436 "cell_type": "markdown",
437 "metadata": {},
438 "source": [
439 "摩尔定律刻画了晶体管数量与时间之间的指数关系,可用非线性回归来拟合,但这超出了本教程的范围。不过我们可以对晶体管数量取以 2 为底的对数(记为 $z$),通过判断 $z$ 与时间 $x$ 之间是否存在线性关系,来验证摩尔定律。如果上述线性关系存在,就使用线性回归方法计算两者之间的最佳拟合直线。"
440 ]
441 },
442 {
443 "cell_type": "code",
444 "execution_count": null,
445 "metadata": {},
446 "outputs": [],
447 "source": [
448 "# 年份\n",
449 "x = np.array(\n",
450 " [1971, 1972, 1974, 1978, 1982, 1985, 1989, 1993, 1997, 1999, 2000, 2001,\n",
451 " 2003, 2004])\n",
452 "# 晶体管取以 2 为底的对数\n",
453 "z = np.array(\n",
454 " [11.17, 11.29, 12.14, 14.82, 17.03, 18.07, 20.19, 21.56, 22.84, 23.18,\n",
455 " 25.32, 24.58, 27.72, 29.14])"
456 ]
457 },
458 {
459 "cell_type": "markdown",
460 "metadata": {},
461 "source": [
462 "我们绘图观察 $x$ 和 $z$ 之间的关系"
463 ]
464 },
465 {
466 "cell_type": "code",
467 "execution_count": null,
468 "metadata": {},
469 "outputs": [],
470 "source": [
471 "fig = plt.figure()\n",
472 "plt.xlabel(\"年份\")\n",
473 "plt.ylabel(\"晶体管取以 2 为底的对数\")\n",
474 "plt.scatter(x, z, c='r')\n",
475 "plt.show()"
476 ]
477 },
478 {
479 "cell_type": "markdown",
480 "metadata": {},
481 "source": [
482 "观察上图可知,$z$ 与时间 $x$ 之间近似存在线性关系,我们用上面写好的方法来求解系数。"
483 ]
484 },
485 {
486 "cell_type": "code",
487 "execution_count": null,
488 "metadata": {},
489 "outputs": [],
490 "source": [
491 "a, b = cal_a_b(x, z)\n",
492 "print('参数 a 的值为:{:g},参数 b 的值为:{:g}'.format(a, b))"
493 ]
494 },
495 {
496 "cell_type": "code",
497 "execution_count": null,
498 "metadata": {},
499 "outputs": [],
500 "source": [
501 "# 构造 y = ax + b 直线\n",
502 "x_predict = np.linspace(1970, 2005, 1000)\n",
503 "z_predict = a * x_predict + b\n",
504 "\n",
505 "# 绘图\n",
506 "fig = plt.figure()\n",
507 "plt.xlabel(\"年份\")\n",
508 "plt.ylabel(\"晶体管取以 2 为底的对数\")\n",
509 "plt.scatter(x, z, c='r')\n",
510 "plt.plot(x_predict, z_predict, c='b')\n",
511 "plt.show()"
512 ]
513 }
514 ],
515 "metadata": {
516 "kernelspec": {
517 "display_name": "Python 3",
518 "language": "python",
519 "name": "python3"
520 },
521 "language_info": {
522 "codemirror_mode": {
523 "name": "ipython",
524 "version": 3
525 },
526 "file_extension": ".py",
527 "mimetype": "text/x-python",
528 "name": "python",
529 "nbconvert_exporter": "python",
530 "pygments_lexer": "ipython3",
531 "version": "3.5.2"
532 }
533 },
534 "nbformat": 4,
535 "nbformat_minor": 2
536 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.4 贝叶斯分析\n",
7 "贝叶斯分析是一种根据概率统计知识对数据进行分析的方法,属于统计学的范畴。"
8 ]
9 },
10 {
11 "cell_type": "markdown",
12 "metadata": {},
13 "source": [
14 "## 5.4.1 贝叶斯公式\n",
15 "\n",
16 "- **频率学派**:从历史数据中计算某个事件的概率,认为只要采样足够多,则事件发生的频率就可以无限逼近真实概率。\n",
17 "- **贝叶斯学派**:认为某个事件发生的概率不仅与先前这个事件发生的概率相关(称为**先验概率**),也与后期计算该事件概率时所观测的“新近”信息有关(称为**似然概率**)。\n",
18 "\n",
19 "贝叶斯概率计算公式表达:\n",
20 "后验概率 ∝ 先验概率 × 似然概率\n",
21 "\n",
22 "**条件概率**:\n",
23 "\n",
24 "$P(A|B)$:表示事件 $B$ 发生的前提下,事件 $A$ 发生的概率\n",
25 "\n",
26 "$$P(A|B)=\\frac{P(A ∩ B)}{P(B)}$$\n",
27 "\n",
28 "$P(B|A)$:表示事件 $A$ 发生的前提下,事件 $B$ 发生的概率\n",
29 "\n",
30 "$$P(B|A)=\\frac{P(A ∩ B)}{P(A)}$$\n",
31 "\n",
32 "由于:\n",
33 "\n",
34 "$$P(B|A){P(A)}=P(A|B){P(B)}={P(A ∩ B)}$$\n",
35 "\n",
36 "可得**贝叶斯公式**:\n",
37 "\n",
38 "$$P(A|B) = \\frac{P(B|A)P(A)}{P(B)}$$\n",
39 "\n",
40 "其中:\n",
41 "- $P(A)$ 是事件 $A$ 发生的先验概率,与事件 $B$ 是否发生无关;\n",
42 "- $P(B|A)$是事件 $A$ 发生前提下,事件 $B$ 发生的概率,也称为**似然概率**; \n",
43 "- $P(B)$ 是事件 $B$ 发生的先验概率,也称为**标准化常量**;\n",
44 "- $P(A|B)$是事件 $B$ 发生前提下,事件 $A$ 发生的概率,也是 $A$ 的后验概率。 "
45 ]
46 },
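上述贝叶斯公式可以用一段简短的 Python 代码做数值验证(其中的概率值为假设的示例数据,不对应任何真实问题):

```python
# 假设的示例数值,仅用于验证贝叶斯公式
P_A = 0.4           # 先验概率 P(A)
P_B_given_A = 0.2   # 似然概率 P(B|A)
P_B = 0.16          # 标准化常量 P(B)

# 贝叶斯公式:P(A|B) = P(B|A) * P(A) / P(B)
P_A_given_B = P_B_given_A * P_A / P_B
print("后验概率 P(A|B) =", P_A_given_B)
```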
47 {
48 "cell_type": "markdown",
49 "metadata": {},
50 "source": [
51 "## 5.4.2 贝叶斯推断\n",
52 "贝叶斯推断是一种基于贝叶斯公式进行分析的统计学方法。"
53 ]
54 },
55 {
56 "cell_type": "markdown",
57 "metadata": {},
58 "source": [
59 "小例子:根据邮件中的 “红包” 字样判别该邮件是不是垃圾邮件"
60 ]
61 },
62 {
63 "cell_type": "code",
64 "execution_count": null,
65 "metadata": {},
66 "outputs": [],
67 "source": [
68 "# 广告邮件的数量 \n",
69 "ad_number = 4000\n",
70 "# 正常邮件的数量\n",
71 "normal_number = 6000\n",
72 "\n",
73 "# 所有广告邮件中,出现 “红包” 关键词的邮件的数量\n",
74 "ad_hongbao_number = 1000\n",
75 "# 所有正常邮件中,出现 “红包” 关键词的邮件的数量\n",
76 "normal_hongbao_number = 6\n",
77 "\n",
78 "# 用户收到广告邮件的先验概率为\n",
79 "P_ad = ad_number / (ad_number + normal_number)\n",
80 "print(\"用户收到广告邮件的先验概率为 \" + str(P_ad))\n",
81 "\n",
82 "# 用户收到正常邮件的先验概率为\n",
83 "P_normal = normal_number / (ad_number + normal_number)\n",
84 "print(\"用户收到正常邮件的先验概率为 \" + str(P_normal))\n",
85 "\n",
86 "# 红包出现的概率\n",
87 "P_hongbao = (normal_hongbao_number + ad_hongbao_number) / (\n",
88 " ad_number + normal_number)\n",
89 "print(\"邮件包含红包的先验概率为 \" + str(P_hongbao))\n",
90 "\n",
91 "# 广告邮件中出现 “红包” 关键词的条件概率\n",
92 "P_hongbao_ad = ad_hongbao_number / ad_number\n",
93 "print(\"广告邮件中出现 “红包” 关键词的条件概率为 \" + str(P_hongbao_ad))\n",
94 "\n",
95 "# 正常邮件中出现 “红包” 关键词的条件概率\n",
96 "P_hongbao_normal = normal_hongbao_number / normal_number\n",
97 "print(\"正常邮件中出现 “红包” 关键词的条件概率为 \" + str(P_hongbao_normal))\n",
98 "\n",
99 "# 根据贝叶斯定理可得\n",
100 "# 当邮件中出现 “红包” ,其为广告邮件的后验概率\n",
101 "P_ad_hongbao = P_ad * P_hongbao_ad / P_hongbao\n",
102 "print(\"当邮件中出现 “红包” ,其为广告邮件的后验概率为 \" + str(P_ad_hongbao))\n",
103 "\n",
104 "# 当邮件中出现 “红包” ,其为正常邮件的后验概率\n",
105 "P_normal_hongbao = P_normal * P_hongbao_normal / P_hongbao\n",
106 "print(\"当邮件中出现 “红包” ,其为正常邮件的后验概率为 \" + str(P_normal_hongbao))"
107 ]
108 },
109 {
110 "cell_type": "markdown",
111 "metadata": {},
112 "source": [
113 "## 5.4.3 朴素贝叶斯分类器 \n",
114 "一种常用的分类算法,其假设**样本各个特征之间相互独立、互不影响**。"
115 ]
116 },
117 {
118 "cell_type": "markdown",
119 "metadata": {},
120 "source": [
121 "小例子:预测同学会不会在某店铺订餐。\n",
122 "\n",
123 "**目标**:根据某同学的订单记录,如果向他推荐一家“价位低、口味偏甜、距离远”的店铺,判断他会下单吗?\n",
124 "\n",
125 "**数据**:该同学的下单记录如下\n",
126 "\n",
127 "|店铺价位|店铺口味|店铺距离|是否下单|\n",
128 "|:--:|:--:|:--:|:--:|\n",
129 "|高|偏甜|近|是|\n",
130 "|高|清淡|近|否|\n",
131 "|高|偏辣|远|否|\n",
132 "|高|偏甜|远|否|\n",
133 "|低|偏甜|近|是|\n",
134 "|低|偏甜|近|是|\n",
135 "|低|清淡|远|否|\n",
136 "|低|偏辣|远|是|\n"
137 ]
138 },
139 {
140 "cell_type": "markdown",
141 "metadata": {},
142 "source": [
143 "该同学在收到 8 次推荐后,下单 4 次,未下单 4 次,则其“下单”、“不下单”的先验概率为: \n",
144 "$$P(下单) = \\frac{4}{8}=0.5$$ \n",
145 "$$P(不下单) = \\frac{4}{8}=0.5$$"
146 ]
147 },
148 {
149 "cell_type": "markdown",
150 "metadata": {},
151 "source": [
152 "该同学对 “价位低、口味偏甜、距离远” 这次推荐的 “下单” 或 “不下单” 的似然概率为(注意基本假设是店铺价位、口味、距离这些特征之间相互独立、互不影响):\n",
153 "\n",
154 "$$\n",
155 "\\begin{align}\n",
156 "&P(价位=低,口味=偏甜,距离=远|下单)\\\\\n",
157 "=&P(价位=低|下单)×P(口味=偏甜|下单)×P(距离=远|下单)\\\\\n",
158 "=&\\frac{3}{4}×\\frac{3}{4}×\\frac{1}{4}\\\\\n",
159 "≈ & 0.141\n",
160 "& \\\\\n",
161 "& \\\\\n",
162 "& P(价位=低,口味=偏甜,距离=远|不下单)\\\\\n",
163 "=&P(价位=低|不下单)×P(口味=偏甜|不下单)×P(距离=远|不下单)\\\\\n",
164 "=&\\frac{1}{4}×\\frac{1}{4}×\\frac{3}{4}\\\\\n",
165 "≈ &0.047\n",
166 "\\end{align}\n",
167 "$$\n"
168 ]
169 },
170 {
171 "cell_type": "markdown",
172 "metadata": {},
173 "source": [
174 "根据贝叶斯公式,可以得到该同学在一家“价格低、口味偏甜、距离远”的店铺,\n",
175 "\n",
176 "下单的后验概率为:\n",
177 "\n",
178 "$$\n",
179 "\\begin{align}\n",
180 "&P(下单|价位=低,口味=偏甜,距离=远)\\\\\n",
181 "=&P(下单)×P(价位=低,口味=偏甜,距离=远|下单)\\\\\n",
182 "=&0.5×0.141\\\\\n",
183 "= &0.0705\n",
184 "\\end{align}\n",
185 "$$\n",
186 "\n",
187 "不下单的后验概率为:\n",
188 "$$\n",
189 "\\begin{align}\n",
190 "&P(不下单|价位=低,口味=偏甜,距离=远)\\\\\n",
191 "=&P(不下单)×P(价位=低,口味=偏甜,距离=远|不下单)\\\\\n",
192 "=&0.5×0.047\\\\\n",
193 "=&0.0235\n",
194 "\\end{align}\n",
195 "$$\n",
196 "\n",
197 "\n",
198 "由此可见,该同学这次会下单的概率大于不下单的概率。\n",
199 "\n",
200 "上面的计算过程进行了一些简化,本来应该计算如下两个公式:\n",
201 "\n",
202 "$$\n",
203 "\\begin{align}\n",
204 "&P(下单|价位=低,口味=偏甜,距离=远)\\\\\n",
205 "=&\\frac{P(下单)×P(价位=低,口味=偏甜,距离=远|下单)}{P(价位=低,口味=偏甜,距离=远)}\\\\\n",
206 "\\end{align}\n",
207 "$$\n",
208 "\n",
209 "$$\n",
210 "\\begin{align}\n",
211 "&P(不下单|价位=低,口味=偏甜,距离=远)\\\\\n",
212 "=&\\frac{P(不下单)×P(价位=低,口味=偏甜,距离=远|不下单)}{P(价位=低,口味=偏甜,距离=远)}\\\\\n",
213 "\\end{align}\n",
214 "$$\n",
215 "\n",
216 "上述两个计算公式分母相同,不影响两个后验概率之间的大小比较,因此在计算过程中略去了。"
217 ]
218 },
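上面的手工推算也可以用一小段代码复现。下面的示例(变量名为示意)根据下单记录表统计先验概率与各特征的条件概率,并在朴素假设下比较两个未归一化的后验:

```python
# 该同学的 8 条下单记录:(价位, 口味, 距离, 是否下单)
records = [
    ("高", "偏甜", "近", "是"),
    ("高", "清淡", "近", "否"),
    ("高", "偏辣", "远", "否"),
    ("高", "偏甜", "远", "否"),
    ("低", "偏甜", "近", "是"),
    ("低", "偏甜", "近", "是"),
    ("低", "清淡", "远", "否"),
    ("低", "偏辣", "远", "是"),
]

def likelihood(label, price, taste, distance):
    """在朴素(特征相互独立)假设下计算 P(价位,口味,距离 | label)"""
    rows = [r for r in records if r[3] == label]
    n = len(rows)
    p_price = sum(r[0] == price for r in rows) / n
    p_taste = sum(r[1] == taste for r in rows) / n
    p_dist = sum(r[2] == distance for r in rows) / n
    return p_price * p_taste * p_dist

# 先验概率
p_yes = len([r for r in records if r[3] == "是"]) / len(records)
p_no = 1 - p_yes

# 未归一化的后验:先验 × 似然
score_yes = p_yes * likelihood("是", "低", "偏甜", "远")
score_no = p_no * likelihood("否", "低", "偏甜", "远")
print("下单:", score_yes, " 不下单:", score_no)
```

运行结果与手工计算一致:下单一侧的得分(0.5 × 9/64)大于不下单一侧(0.5 × 3/64)。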
219 {
220 "cell_type": "markdown",
221 "metadata": {},
222 "source": [
223 "### 实践与体验"
224 ]
225 },
226 {
227 "cell_type": "markdown",
228 "metadata": {},
229 "source": [
230 "#### 利用朴素贝叶斯分类器解决 MNIST 手写体数字识别问题\n",
231 "\n",
232 "**MNIST** 是一个手写体数据集,它包含了各种各样的手写体数字图像及其对应的数字标签。其中每幅手写体图像的大小为 **28×28** ,共有 **784** 个像素点,可记为一个 **784** 维的向量,每个 **784** 维向量对应着一个标签。"
233 ]
234 },
235 {
236 "cell_type": "markdown",
237 "metadata": {},
238 "source": [
239 "本次实验我们利用 **tensorflow** 库来进行原始数据集的解析和读取,利用 **sklearn** 库来进行特征提取和分类。更多内容可参考 **tensorflow** 的[数据集部分](https://www.tensorflow.org/datasets/),sklearn 的 [bayes 部分](https://scikit-learn.org/stable/modules/naive_bayes.html)。\n",
240 " \n",
241 "1.在 **Python** 中导入相应库。"
242 ]
243 },
244 {
245 "cell_type": "code",
246 "execution_count": null,
247 "metadata": {},
248 "outputs": [],
249 "source": [
250 "import warnings\n",
251 "warnings.filterwarnings(\"ignore\")\n",
252 "import numpy as np\n",
253 "from tensorflow.keras.datasets import mnist\n",
254 "from sklearn.naive_bayes import BernoulliNB"
255 ]
256 },
257 {
258 "cell_type": "markdown",
259 "metadata": {},
260 "source": [
261 "2.读取 **MNIST** 训练集和测试集。"
262 ]
263 },
264 {
265 "cell_type": "code",
266 "execution_count": null,
267 "metadata": {},
268 "outputs": [],
269 "source": [
270 "print(\"读取数据中 ...\")\n",
271 "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n",
272 "train_images = train_images.reshape(train_images.shape[0], 784)\n",
273 "test_images = test_images.reshape(test_images.shape[0], 784)\n",
274 "print('读取完毕!')\n"
275 ]
276 },
277 {
278 "cell_type": "markdown",
279 "metadata": {},
280 "source": [
281 "我们使用下面的方法来查看其中几张图片。"
282 ]
283 },
284 {
285 "cell_type": "code",
286 "execution_count": null,
287 "metadata": {},
288 "outputs": [],
289 "source": [
290 "def plot_images(imgs):\n",
291 "    \"\"\"绘制前几张样本图片\n",
292 "    :param imgs: 一组图片,每张为 784 维向量\n",
293 "    :return:\n",
294 " \"\"\"\n",
295 " sample_num = min(9, len(imgs))\n",
296 " img_figure = plt.figure(1)\n",
297 " img_figure.set_figwidth(5)\n",
298 " img_figure.set_figheight(5)\n",
299 " for index in range(0, sample_num):\n",
300 " ax = plt.subplot(3, 3, index + 1)\n",
301 " ax.imshow(imgs[index].reshape(28, 28), cmap='gray')\n",
302 " ax.grid(False)\n",
303 " plt.margins(0, 0)\n",
304 " plt.show()\n",
305 "\n",
306 "\n",
307 "plot_images(train_images)"
308 ]
309 },
310 {
311 "cell_type": "markdown",
312 "metadata": {},
313 "source": [
314 "3.根据 **MNIST** 训练集训练朴素贝叶斯分类器"
315 ]
316 },
317 {
318 "cell_type": "code",
319 "execution_count": null,
320 "metadata": {},
321 "outputs": [],
322 "source": [
323 "print(\"初始化并训练贝叶斯模型...\")\n",
324 "classifier_BNB = BernoulliNB()\n",
325 "classifier_BNB.fit(train_images,train_labels)\n",
326 "print('训练完成!')"
327 ]
328 },
329 {
330 "cell_type": "markdown",
331 "metadata": {},
332 "source": [
333 "4.根据训练出的分类器对 **MNIST** 测试集中的图片进行识别,得到预测值。\n"
334 ]
335 },
336 {
337 "cell_type": "code",
338 "execution_count": null,
339 "metadata": {},
340 "outputs": [],
341 "source": [
342 "print(\"测试训练好的贝叶斯模型...\")\n",
343 "test_predict_BNB = classifier_BNB.predict(test_images)\n",
344 "print(\"测试完成!\")"
345 ]
346 },
347 {
348 "cell_type": "markdown",
349 "metadata": {},
350 "source": [
351 "5.将测试图片的预测值与实际值相比较,计算并输出分类器的正确率。"
352 ]
353 },
354 {
355 "cell_type": "code",
356 "execution_count": null,
357 "metadata": {},
358 "outputs": [],
359 "source": [
360 "accuracy = sum(test_predict_BNB==test_labels)/len(test_labels)\n",
361 "print('贝叶斯分类模型在测试集上的准确率为 :',accuracy)"
362 ]
363 },
364 {
365 "cell_type": "markdown",
366 "metadata": {},
367 "source": [
368 "6.对实验结果进行分析比较,列出 **0-9** 不同数字识别的准确率,比较其差异。"
369 ]
370 },
371 {
372 "cell_type": "code",
373 "execution_count": null,
374 "metadata": {},
375 "outputs": [],
376 "source": [
377 "# 记录每个类别的样本的个数,例如 {0:100} 即 数字为 0 的图片有 100 张 \n",
378 "class_num = {}\n",
379 "# 每个类别预测为 0-9 类别的个数,\n",
380 "predict_num = []\n",
381 "# 每个类别预测的准确率\n",
382 "class_accuracy = {}\n",
383 "\n",
384 "for i in range(10):\n",
385 " # 找到类别是 i 的下标\n",
386 " class_is_i_index = np.where(test_labels == i)[0]\n",
387 " # 统计类别是 i 的个数\n",
388 " class_num[i] = len(class_is_i_index)\n",
389 "\n",
390 " # 统计类别 i 预测为 0-9 各个类别的个数\n",
391 " predict_num.append(\n",
392 " [sum(test_predict_BNB[class_is_i_index] == e) for e in range(10)])\n",
393 "\n",
394 " # 统计类别 i 预测的准确率\n",
395 " class_accuracy[i] = round(predict_num[i][i] / class_num[i], 3) * 100\n",
396 "\n",
397 " print(\"数字 %s 的样本个数:%4s,预测正确的个数:%4s,准确率:%.4s%%\" % (\n",
398 " i, class_num[i], predict_num[i][i], class_accuracy[i]))"
399 ]
400 },
401 {
402 "cell_type": "code",
403 "execution_count": null,
404 "metadata": {},
405 "outputs": [],
406 "source": [
407 "import numpy as np\n",
408 "import seaborn as sns\n",
409 "import matplotlib.pyplot as plt\n",
410 "\n",
411 "sns.set(rc={'figure.figsize': (12, 8)})\n",
412 "data = predict_num\n",
413 "ax = sns.heatmap(data, cmap='YlGnBu', vmin=0, vmax=150)\n",
414 "ax.set_xlabel('预测值')\n",
415 "ax.set_ylabel('真实值')\n",
417 "plt.show()"
418 ]
419 },
420 {
421 "cell_type": "markdown",
422 "metadata": {},
423 "source": [
424 "通过热力图,我们看到 3 经常被错认为 5 和 8,4 和 9 则经常互相错认。"
425 ]
426 },
427 {
428 "cell_type": "markdown",
429 "metadata": {},
430 "source": [
431 "我们看看真实标签为 9、但被预测为 4 的图片。\n"
432 ]
433 },
434 {
435 "cell_type": "code",
436 "execution_count": null,
437 "metadata": {},
438 "outputs": [],
439 "source": [
440 "def get_imgs(images, true_labels, predict_labels, true_label,\n",
441 " predict_label):\n",
442 " \"\"\"\n",
443 " 从全部图片中按真实标签和预测标签筛选出图片\n",
444 " :param images: 一组图片\n",
445 " :param true_labels: 每张图片的标签\n",
446 " :param predict_labels: 模型预测的每张图片的标签\n",
447 " :param true_label: 希望取得的图片的真实标签\n",
448 " :param predict_label: 希望取得的图片的预测标签\n",
449 " :return: \n",
450 " \"\"\"\n",
451 " # 所有类别为 true_label 的样本的 index 值\n",
452 " true_label_index = set(np.where(true_labels == true_label)[0])\n",
453 " # 所有预测类别为 predict_label 的样本的 index 值\n",
454 " predict_label_index = set(np.where(predict_labels == predict_label)[0])\n",
455 " # 取交集,即为真实类别为 true_label, 预测结果为 predict_label 的样本的 index 值\n",
456 " res = list(true_label_index & predict_label_index)\n",
457 " return images[res]\n"
458 ]
459 },
460 {
461 "cell_type": "code",
462 "execution_count": null,
463 "metadata": {},
464 "outputs": [],
465 "source": [
466 "imgs = get_imgs(test_images, test_labels, test_predict_BNB, 9, 4)\n",
467 "plot_images(imgs)"
468 ]
469 },
470 {
471 "cell_type": "markdown",
472 "metadata": {},
473 "source": [
474 "你在上面的实验中观察到了什么?在下方列出模型对 0-9 不同数字识别的准确率,并比较其差异。"
475 ]
476 },
477 {
478 "cell_type": "markdown",
479 "metadata": {},
480 "source": [
481 "答案:(在此处填写你的答案。)"
482 ]
483 }
484 ],
485 "metadata": {
486 "kernelspec": {
487 "display_name": "Python 3",
488 "language": "python",
489 "name": "python3"
490 },
491 "language_info": {
492 "codemirror_mode": {
493 "name": "ipython",
494 "version": 3
495 },
496 "file_extension": ".py",
497 "mimetype": "text/x-python",
498 "name": "python",
499 "nbconvert_exporter": "python",
500 "pygments_lexer": "ipython3",
501 "version": "3.5.2"
502 }
503 },
504 "nbformat": 4,
505 "nbformat_minor": 2
506 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.5 神经网络学习"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "神经网络模拟人脑神经元的连接来实现学习功能,通过逐层抽象,将输入数据映射为概念等高层语义。"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "## 5.5.1 人脑神经机制"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "人眼在辨识图片时,会先提取边缘特征,再识别部件,最后再得到最高层的模式。也就是说,高层的特征是低层特征的组合,从低层到高层的特征表示越来越抽象,越来越能表现语义。"
28 ]
29 },
30 {
31 "cell_type": "markdown",
32 "metadata": {},
33 "source": [
34 "<img src=\"http://imgbed.momodel.cn//20200103102429.png\" width=300>"
35 ]
36 },
37 {
38 "cell_type": "markdown",
39 "metadata": {},
40 "source": [
41 "## 5.5.2 感知机模型"
42 ]
43 },
44 {
45 "cell_type": "markdown",
46 "metadata": {},
47 "source": [
48 "**感知机模型**:\n",
49 "\n",
50 "<img src=\"http://imgbed.momodel.cn/感知器模型.png\" width=300/>"
51 ]
52 },
53 {
54 "cell_type": "markdown",
55 "metadata": {},
56 "source": [
57 "**输入项**:3个,$x_1,x_2,x_3$ \n",
58 "**神经元**:1个,用圆圈表示 \n",
59 "**权重**:每个输入项均通过权重与神经元相连(比如 $w_i$ 是 $x_i$ 与神经元相连的权重) \n",
60 "**输出**:1个\n",
61 "\n",
62 "\n",
63 "**工作方法**:\n",
64 "+ 计算输入项传递给神经元的信息加权总和,即:$y_{sum} = w_1x_1+w_2x_2+w_3x_3$\n",
65 "+ 如果 $y_{sum}$ 大于某个预定阈值(比如 0.5),则输出为 1,否则为 0。\n"
66 ]
67 },
68 {
69 "cell_type": "markdown",
70 "metadata": {},
71 "source": [
72 "在输出的判断上,不仅可以简单地按照阈值来判断,还可以通过一个函数来计算,这个函数称为激活函数。常见的激活函数有 sigmoid、tanh、relu 等。下面我们看看这些激活函数的曲线图。"
73 ]
74 },
75 {
76 "cell_type": "code",
77 "execution_count": null,
78 "metadata": {},
79 "outputs": [],
80 "source": [
81 "import numpy as np\n",
82 "import matplotlib.pyplot as plt\n",
83 "import warnings\n",
84 "warnings.filterwarnings(\"ignore\")\n",
85 "\n",
86 "\n",
87 "def plot_activation_function(activation_function):\n",
88 " \"\"\"\n",
89 " 绘制激活函数\n",
90 " :param activation_function: 激活函数名\n",
91 " :return:\n",
92 " \"\"\"\n",
93 " x = np.arange(-10, 10, 0.1)\n",
94 " y_activation_function = activation_function(x)\n",
95 "\n",
96 " # 绘制坐标轴\n",
97 " ax = plt.gca()\n",
98 " ax.spines['right'].set_color('none')\n",
99 " ax.spines['top'].set_color('none')\n",
100 " ax.xaxis.set_ticks_position('bottom')\n",
101 " ax.yaxis.set_ticks_position('left')\n",
102 " ax.spines['bottom'].set_position(('data', 0))\n",
103 " ax.spines['left'].set_position(('data', 0))\n",
104 "\n",
105 " # 绘制曲线图\n",
106 " plt.plot(x, y_activation_function)\n",
107 " \n",
108 " # 展示函数图像\n",
109 " plt.show()"
110 ]
111 },
112 {
113 "cell_type": "code",
114 "execution_count": null,
115 "metadata": {},
116 "outputs": [],
117 "source": [
118 "def sigmoid(x):\n",
119 " \"\"\"\n",
120 " sigmoid函数\n",
121 " :param x: np.array 格式数据\n",
122 " :return: sigmoid 函数\n",
123 " \"\"\"\n",
124 " return 1 / (1 + np.exp(-x))\n",
125 "\n",
126 "# 绘制 sigmoid 函数图像\n",
127 "plot_activation_function(sigmoid)"
128 ]
129 },
130 {
131 "cell_type": "code",
132 "execution_count": null,
133 "metadata": {},
134 "outputs": [],
135 "source": [
136 "def tanh(x):\n",
137 " \"\"\"\n",
138 " tanh函数\n",
139 " :param x: np.array 格式数据\n",
140 " :return: tanh 函数\n",
141 " \"\"\"\n",
142 " return ((np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)))\n",
143 "\n",
144 "# 绘制 tanh 函数图像\n",
145 "plot_activation_function(tanh)"
146 ]
147 },
148 {
149 "cell_type": "code",
150 "execution_count": null,
151 "metadata": {},
152 "outputs": [],
153 "source": [
154 "def relu(x):\n",
155 " \"\"\"\n",
156 " relu 函数\n",
157 " :param x: np.array 格式数据\n",
158 " :return: relu 函数\n",
159 " \"\"\"\n",
160 "    return np.maximum(x, 0)\n",
163 "\n",
164 "# 绘制 relu 函数\n",
165 "plot_activation_function(relu)"
166 ]
167 },
168 {
169 "cell_type": "markdown",
170 "metadata": {},
171 "source": [
172 "我们根据上面的定义可以编写一个简单的感知机模型。"
173 ]
174 },
175 {
176 "cell_type": "code",
177 "execution_count": null,
178 "metadata": {},
179 "outputs": [],
180 "source": [
181 "def perceptron(x, w, threshold):\n",
182 " \"\"\"\n",
183 " 感知机模型\n",
184 " :param x: 输入数据 np.array 格式\n",
185 " :param w: 权重 np.array 格式,需要与 x 一一对应\n",
186 "    :param threshold: 阈值\n",
187 " :return: 0或者1\n",
188 " \"\"\"\n",
189 " x = np.array(x)\n",
190 " w = np.array(w)\n",
191 " y_sum = np.sum(w * x)\n",
192 "    # 大于阈值返回 1,否则返回 0\n",
193 " return 1 if y_sum > threshold else 0\n",
194 "\n",
195 "\n",
196 "# 输入数据\n",
197 "x = np.array([1, 1, 4])\n",
198 "# 输入权重\n",
199 "w = np.array([0.5, 0.2, 0.3])\n",
200 "# 返回结果\n",
201 "perceptron(x, w, 0.8)"
202 ]
203 },
204 {
205 "cell_type": "markdown",
206 "metadata": {},
207 "source": [
208 "## 5.5.3 神经网络"
209 ]
210 },
211 {
212 "cell_type": "markdown",
213 "metadata": {},
214 "source": [
215 "<img src=\"http://imgbed.momodel.cn//20200103111837.png\" width=400>\n",
216 "\n",
217 "与感知机不同,神经网络:\n",
218 "+ 输入层和输出层之间存在若干隐藏层。\n",
219 "+ 每个隐藏层中包含若干神经元。\n"
220 ]
221 },
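为了直观理解隐藏层的计算,下面用 numpy 写出一个极小的两层网络的前向传播示意(权重为随机的假设值,只演示逐层计算,不涉及训练):

```python
import numpy as np

def sigmoid(x):
    """sigmoid 激活函数"""
    return 1 / (1 + np.exp(-x))

rng = np.random.RandomState(0)
x = rng.rand(3)          # 输入层:3 个输入项
W1 = rng.rand(4, 3)      # 输入层 -> 隐藏层的权重(隐藏层 4 个神经元)
W2 = rng.rand(2, 4)      # 隐藏层 -> 输出层的权重(输出层 2 个神经元)

h = sigmoid(W1 @ x)      # 隐藏层输出:加权求和后经过激活函数
y = sigmoid(W2 @ h)      # 输出层输出
print("网络输出:", y)
```

可以看到,每一层都是“加权求和 + 激活函数”的感知机计算,逐层叠加就构成了神经网络。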
222 {
223 "cell_type": "markdown",
224 "metadata": {},
225 "source": [
226 "## 5.5.4 搭建神经网络"
227 ]
228 },
229 {
230 "cell_type": "markdown",
231 "metadata": {},
232 "source": [
233 "采用 **keras** 框架搭建一个神经网络,解决手写体数字识别问题。 \n",
234 "1. 导入相关包"
235 ]
236 },
237 {
238 "cell_type": "code",
239 "execution_count": null,
240 "metadata": {},
241 "outputs": [],
242 "source": [
243 "import keras\n",
244 "from keras.datasets import mnist\n",
245 "from keras.models import Sequential\n",
246 "from keras.layers import Dense, Activation, Dropout\n",
247 "from keras.utils import np_utils\n",
248 "import warnings\n",
249 "warnings.filterwarnings(\"ignore\")\n",
250 "!mkdir -p ~/.keras/datasets\n",
251 "!cp ./mnist.npz ~/.keras/datasets/mnist.npz"
252 ]
253 },
254 {
255 "cell_type": "markdown",
256 "metadata": {},
257 "source": [
258 "2. 下载 **MNIST** 数据集并将它们转换为模型所能使用的格式。"
259 ]
260 },
261 {
262 "cell_type": "code",
263 "execution_count": null,
264 "metadata": {},
265 "outputs": [],
266 "source": [
267 "# 获取数据\n",
268 "(X_train, y_train),(X_test,y_test) = mnist.load_data()\n",
269 "\n",
270 "# 将训练集和测试集的图像形状从 (N,28,28) 修改为 (N,784)\n",
271 "X_train = X_train.reshape(len(X_train),-1)\n",
272 "X_test = X_test.reshape(len(X_test),-1)\n",
273 "\n",
274 "# 将数据集图像像素点的数据类型从 uint8 修改为 float32\n",
275 "X_train = X_train.astype('float32')\n",
276 "X_test = X_test.astype('float32')\n",
277 "\n",
278 "# 把数据集图像的像素值从 0-255 缩放到 [-1,1] 之间\n",
279 "X_train = (X_train - 127.5)/127.5\n",
280 "X_test = (X_test - 127.5)/127.5\n",
281 "\n",
282 "# 数据集类别个数\n",
283 "nb_classes = 10\n",
284 "\n",
285 "# 把 y_train 和 y_test 变成了 one-hot 的形式,即之前是 0-9 的一个数值, \n",
286 "# 现在是一个大小为 10 的向量,它属于哪个数字,就在哪个位置为 1,其他位置都是 0。\n",
287 "y_train = np_utils.to_categorical(y_train,nb_classes)\n",
288 "y_test = np_utils.to_categorical(y_test,nb_classes)"
289 ]
290 },
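与 one-hot 编码相对应,模型的输出也是一个长度为 10 的概率向量,取最大分量的下标(argmax)即可还原出预测的数字。下面用假设的概率向量演示这一解码过程:

```python
import numpy as np

# 假设的模型输出:2 张图片在 0-9 共 10 个类别上的概率
probs = np.array([
    [0.01, 0.02, 0.05, 0.02, 0.01, 0.03, 0.01, 0.80, 0.03, 0.02],
    [0.10, 0.05, 0.60, 0.05, 0.02, 0.05, 0.03, 0.04, 0.04, 0.02],
])

# 每行取最大概率对应的下标,即预测出的数字
pred_labels = probs.argmax(axis=1)
print(pred_labels)  # 输出 [7 2]
```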
291 {
292 "cell_type": "markdown",
293 "metadata": {},
294 "source": [
295 "3. 搭建神经网络模型"
296 ]
297 },
298 {
299 "cell_type": "code",
300 "execution_count": null,
301 "metadata": {},
302 "outputs": [],
303 "source": [
304 "def create_model():\n",
305 " \"\"\"\n",
306 " 采用 keras 搭建神经网络模型\n",
307 " :return: 神经网络模型\n",
308 " \"\"\"\n",
309 " # 选择模型,选择序贯模型(Sequential())\n",
310 " model = Sequential()\n",
311 " \n",
312 " # 添加全连接层,共 512 个神经元\n",
313 " model.add(Dense(512,input_shape=(784,),kernel_initializer='he_normal'))\n",
314 " \n",
315 " # 添加激活层,激活函数选择 relu \n",
316 " model.add(Activation('relu'))\n",
317 " \n",
318 " # 添加全连接层,共 512 个神经元\n",
319 " model.add(Dense(512,kernel_initializer='he_normal'))\n",
320 " \n",
321 " # 添加激活层,激活函数选择 relu \n",
322 " model.add(Activation('relu'))\n",
323 " \n",
324 " # 添加全连接层,共 10 个神经元\n",
325 " model.add(Dense(nb_classes))\n",
326 " \n",
327 " # 添加激活层,激活函数选择 softmax\n",
328 " model.add(Activation('softmax'))\n",
329 " \n",
330 " return model\n",
331 "\n",
332 "model = create_model()"
333 ]
334 },
335 {
336 "cell_type": "markdown",
337 "metadata": {},
338 "source": [
339 "4. 训练和测试神经网络模型"
340 ]
341 },
342 {
343 "cell_type": "code",
344 "execution_count": null,
345 "metadata": {},
346 "outputs": [],
347 "source": [
348 "def fit_and_predict(model, model_path):\n",
349 " \"\"\"\n",
350 " 训练模型、模型评估、保存模型\n",
351 " :param model: 搭建好的模型\n",
352 " :param model_path:保存模型路径\n",
353 " :return:\n",
354 " \"\"\"\n",
355 " # 编译模型\n",
356 " model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
357 " \n",
358 " # 模型训练\n",
359 " model.fit(X_train, y_train, epochs=5, batch_size=64, verbose=1, validation_split=0.05)\n",
360 " \n",
361 " # 保存模型\n",
362 " model.save(model_path)\n",
363 " \n",
364 " # 模型评估,获取测试集的损失值和准确率\n",
365 " loss, accuracy = model.evaluate(X_test, y_test)\n",
366 "\n",
367 " # 打印结果\n",
368 " print('Test loss:', loss)\n",
369 " print(\"Accuracy:\", accuracy)\n",
370 "\n",
371 "# 训练模型和评估模型\n",
372 "fit_and_predict(model, model_path='./model.h5')"
373 ]
374 },
375 {
376 "cell_type": "markdown",
377 "metadata": {},
378 "source": [
379 "### 实践与体验\n",
380 "#### 调节神经网络结构和参数\n",
381 "\n",
382 "1. 将两层隐藏层改为一层,训练模型并在测试集上测试,得出准确率。\n"
383 ]
384 },
385 {
386 "cell_type": "code",
387 "execution_count": null,
388 "metadata": {},
389 "outputs": [],
390 "source": [
391 "def create_model1():\n",
392 " \"\"\"\n",
393 " 搭建神经网络模型 model1,比 model 少一层隐藏层\n",
394 " :return: 模型 model1\n",
395 " \"\"\"\n",
396 " # todo 参考上面的教程搭建只有一个隐藏层的神经网络模型\n",
397 "\n",
398 " return model\n",
399 "\n",
400 "# 搭建神经网络\n",
401 "model1 = create_model1()\n",
402 "\n",
403 "# 训练神经网络模型,保存模型和评估模型\n",
404 "fit_and_predict(model1, model_path='./model1.h5')"
405 ]
406 },
407 {
408 "cell_type": "markdown",
409 "metadata": {},
410 "source": [
411 "2. 修改两层隐藏层神经元的数量,然后训练模型得出准确率。"
412 ]
413 },
414 {
415 "cell_type": "code",
416 "execution_count": null,
417 "metadata": {},
418 "outputs": [],
419 "source": [
420 "def create_model2():\n",
421 " \"\"\"\n",
422 " 搭建神经网络模型 model2,隐藏层的神经元数目比 model 少一半\n",
423 " :return: 神经网络模型 model2\n",
424 " \"\"\"\n",
425 " # todo 参考上面的教程搭建神经元数目比原始模型少一半的神经网络模型\n",
426 "\n",
427 " return model\n",
428 "\n",
429 "# 搭建神经网络模型\n",
430 "model2 = create_model2()\n",
431 "\n",
432 "# 训练神经网络模型,保存模型并评估模型\n",
433 "fit_and_predict(model2,model_path='./model2.h5')\n",
434 "\n"
435 ]
436 },
437 {
438 "cell_type": "markdown",
439 "metadata": {},
440 "source": [
441 "3. 输入一个手写数字,比较三种模型输出结果的差异,并对其进行分析解释。"
442 ]
443 },
444 {
445 "cell_type": "code",
446 "execution_count": null,
447 "metadata": {},
448 "outputs": [],
449 "source": [
450 "import numpy as np\n",
451 "np.set_printoptions(suppress=True)\n",
452 "from keras.models import load_model\n",
453 "\n",
454 "# 加载模型\n",
455 "model = load_model('./model.h5')\n",
456 "model1 = load_model('./model1.h5')\n",
457 "model2 = load_model('./model2.h5') "
458 ]
459 },
460 {
461 "cell_type": "code",
462 "execution_count": null,
463 "metadata": {},
464 "outputs": [],
465 "source": [
466 "# 预测结果\n",
467 "predict_results = np.round(model.predict(X_test)[0],3)\n",
468 "predict_results1 = np.round(model1.predict(X_test)[0],3)\n",
469 "predict_results2 = np.round(model2.predict(X_test)[0],3)\n",
470 "\n",
471 "# 打印预测结果\n",
472 "print('原始模型\\n其各类别预测概率:%s,预测值: %s,真实值:%s\\n' % (predict_results,np.argmax(predict_results),np.argmax(y_test[0])))\n",
473 "print('只有一个隐藏层的模型\\n其各类别预测概率:%s,预测值: %s,真实值:%s\\n' % (predict_results1,np.argmax(predict_results1),np.argmax(y_test[0])))\n",
474 "print('隐藏神经元数量更改后的模型\\n其各类别预测概率:%s,预测值: %s,真实值:%s' % (predict_results2,np.argmax(predict_results2),np.argmax(y_test[0])))"
475 ]
476 },
477 {
478 "cell_type": "markdown",
479 "metadata": {},
480 "source": [
481 "在下方写出你观察到的结果,并进行分析。"
482 ]
483 },
484 {
485 "cell_type": "markdown",
486 "metadata": {},
487 "source": [
488 "答案:(在此处填写你的答案。)"
489 ]
490 }
491 ],
492 "metadata": {
493 "kernelspec": {
494 "display_name": "Python 3",
495 "language": "python",
496 "name": "python3"
497 },
498 "language_info": {
499 "codemirror_mode": {
500 "name": "ipython",
501 "version": 3
502 },
503 "file_extension": ".py",
504 "mimetype": "text/x-python",
505 "name": "python",
506 "nbconvert_exporter": "python",
507 "pygments_lexer": "ipython3",
508 "version": "3.5.2"
509 }
510 },
511 "nbformat": 4,
512 "nbformat_minor": 2
513 }
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "# 5.5 神经网络学习"
7 ]
8 },
9 {
10 "cell_type": "markdown",
11 "metadata": {},
12 "source": [
13 "神经网络模拟人脑神经元的连接来实现学习功能,通过逐层抽象将输入数据映射为概念等高层语义。"
14 ]
15 },
16 {
17 "cell_type": "markdown",
18 "metadata": {},
19 "source": [
20 "## 5.5.1 人脑神经机制"
21 ]
22 },
23 {
24 "cell_type": "markdown",
25 "metadata": {},
26 "source": [
27 "人眼在辨识图片时,会先提取边缘特征,再识别部件,最后得到最高层的模式。也就是说,高层的特征是低层特征的组合,从低层到高层的特征表示越来越抽象,也越来越能表现语义。"
28 ]
29 },
30 {
31 "cell_type": "markdown",
32 "metadata": {},
33 "source": [
34 "<img src=\"http://imgbed.momodel.cn//20200103102429.png\" width=300>"
35 ]
36 },
37 {
38 "cell_type": "markdown",
39 "metadata": {},
40 "source": [
41 "## 5.5.2 感知机模型"
42 ]
43 },
44 {
45 "cell_type": "markdown",
46 "metadata": {},
47 "source": [
48 "**感知机模型**:\n",
49 "\n",
50 "<img src=\"http://imgbed.momodel.cn/感知器模型.png\" width=300/>"
51 ]
52 },
53 {
54 "cell_type": "markdown",
55 "metadata": {},
56 "source": [
57 "**输入项**:3个,$x_1,x_2,x_3$ \n",
58 "**神经元**:1个,用圆圈表示 \n",
59 "**权重**:每个输入项均通过权重与神经元相连(比如 $w_i$ 是 $x_i$ 与神经元相连的权重) \n",
60 "**输出**:1个\n",
61 "\n",
62 "\n",
63 "**工作方法**:\n",
64 "+ 计算输入项传递给神经元的信息加权总和,即:$y_{sum} = w_1x_1+w_2x_2+w_3x_3$\n",
65 "+ 如果 $y_{sum}$ 大于某个预定阈值(比如 0.5),则输出为 1,否则为 0。\n"
66 ]
67 },
68 {
69 "cell_type": "markdown",
70 "metadata": {},
71 "source": [
72 "在输出的判断上,其实不仅可以简单地按照阈值来判断,还可以通过一个函数来计算,这个函数称为激活函数。常见的激活函数有 sigmoid、tanh、relu 等。下面我们看看这些激活函数的曲线图。"
73 ]
74 },
75 {
76 "cell_type": "code",
77 "execution_count": null,
78 "metadata": {},
79 "outputs": [],
80 "source": [
81 "import numpy as np\n",
82 "import matplotlib.pyplot as plt\n",
83 "import warnings\n",
84 "warnings.filterwarnings(\"ignore\")\n",
85 "\n",
86 "\n",
87 "def plot_activation_function(activation_function):\n",
88 " \"\"\"\n",
89 " 绘制激活函数\n",
90 " :param activation_function: 激活函数名\n",
91 " :return:\n",
92 " \"\"\"\n",
93 " x = np.arange(-10, 10, 0.1)\n",
94 " y_activation_function = activation_function(x)\n",
95 "\n",
96 " # 绘制坐标轴\n",
97 " ax = plt.gca()\n",
98 " ax.spines['right'].set_color('none')\n",
99 " ax.spines['top'].set_color('none')\n",
100 " ax.xaxis.set_ticks_position('bottom')\n",
101 " ax.yaxis.set_ticks_position('left')\n",
102 " ax.spines['bottom'].set_position(('data', 0))\n",
103 " ax.spines['left'].set_position(('data', 0))\n",
104 "\n",
105 " # 绘制曲线图\n",
106 " plt.plot(x, y_activation_function)\n",
107 " \n",
108 " # 展示函数图像\n",
109 " plt.show()"
110 ]
111 },
112 {
113 "cell_type": "code",
114 "execution_count": null,
115 "metadata": {},
116 "outputs": [],
117 "source": [
118 "def sigmoid(x):\n",
119 " \"\"\"\n",
120 " sigmoid函数\n",
121 " :param x: np.array 格式数据\n",
122 " :return: 对 x 逐元素计算的 sigmoid 函数值\n",
123 " \"\"\"\n",
124 " return 1 / (1 + np.exp(-x))\n",
125 "\n",
126 "# 绘制 sigmoid 函数图像\n",
127 "plot_activation_function(sigmoid)"
128 ]
129 },
130 {
131 "cell_type": "code",
132 "execution_count": null,
133 "metadata": {},
134 "outputs": [],
135 "source": [
136 "def tanh(x):\n",
137 " \"\"\"\n",
138 " tanh函数\n",
139 " :param x: np.array 格式数据\n",
140 " :return: 对 x 逐元素计算的 tanh 函数值\n",
141 " \"\"\"\n",
142 " return ((np.exp(x) - np.exp(-x)) / (np.exp(x) + np.exp(-x)))\n",
143 "\n",
144 "# 绘制 tanh 函数图像\n",
145 "plot_activation_function(tanh)"
146 ]
147 },
148 {
149 "cell_type": "code",
150 "execution_count": null,
151 "metadata": {},
152 "outputs": [],
153 "source": [
154 "def relu(x):\n",
155 " \"\"\"\n",
156 " relu 函数\n",
157 " :param x: np.array 格式数据\n",
158 " :return: 对 x 逐元素计算的 relu 函数值\n",
159 " \"\"\"\n",
160 " temp = np.zeros_like(x)\n",
161 " if_bigger_zero = (x > temp)\n",
162 " return x * if_bigger_zero\n",
163 "\n",
164 "# 绘制 relu 函数\n",
165 "plot_activation_function(relu)"
166 ]
167 },
168 {
169 "cell_type": "markdown",
170 "metadata": {},
171 "source": [
172 "我们根据上面的定义可以编写一个简单的感知机模型。"
173 ]
174 },
175 {
176 "cell_type": "code",
177 "execution_count": null,
178 "metadata": {},
179 "outputs": [],
180 "source": [
181 "def perceptron(x, w, threshold):\n",
182 " \"\"\"\n",
183 " 感知机模型\n",
184 " :param x: 输入数据 np.array 格式\n",
185 " :param w: 权重 np.array 格式,需要与 x 一一对应\n",
186 " :param threshold: 阈值\n",
187 " :return: 0或者1\n",
188 " \"\"\"\n",
189 " x = np.array(x)\n",
190 " w = np.array(w)\n",
191 " y_sum = np.sum(w * x)\n",
192 " # 大于阈值返回 1,否则返回 0\n",
193 " return 1 if y_sum > threshold else 0\n",
194 "\n",
195 "\n",
196 "# 输入数据\n",
197 "x = np.array([1, 1, 4])\n",
198 "# 输入权重\n",
199 "w = np.array([0.5, 0.2, 0.3])\n",
200 "# 返回结果\n",
201 "perceptron(x, w, 0.8)"
202 ]
203 },
204 {
205 "cell_type": "markdown",
206 "metadata": {},
207 "source": [
208 "## 5.5.3 神经网络"
209 ]
210 },
211 {
212 "cell_type": "markdown",
213 "metadata": {},
214 "source": [
215 "<img src=\"http://imgbed.momodel.cn//20200103111837.png\" width=400>\n",
216 "\n",
217 "与感知机不同,神经网络:\n",
218 "+ 输入层和输出层之间存在若干隐藏层。\n",
219 "+ 每个隐藏层中包含若干神经元。\n"
220 ]
221 },
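上面描述的"输入层—隐藏层—输出层"结构,可以用 NumPy 写成一个最小的前向传播示意(仅作演示:权重为随机初始化,`relu`、`softmax` 是自行实现的辅助函数,并非任何框架的接口):

```python
import numpy as np

def relu(x):
    # relu 激活:逐元素取 max(0, x)
    return np.maximum(0, x)

def softmax(x):
    # softmax:先减去最大值保证数值稳定,再归一化为概率
    e = np.exp(x - np.max(x))
    return e / e.sum()

rng = np.random.default_rng(0)
x = rng.random(784)                    # 模拟一张展平后的输入图像
W1 = rng.normal(0, 0.05, (784, 512))   # 输入层 -> 隐藏层 的权重
W2 = rng.normal(0, 0.05, (512, 10))    # 隐藏层 -> 输出层 的权重

h = relu(x @ W1)                       # 隐藏层输出
y = softmax(h @ W2)                    # 输出层:10 个类别的概率,和为 1

print(y.shape)
```

这里省略了偏置项和训练过程,只用于体会"逐层加权求和 + 激活"的计算方式。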
222 {
223 "cell_type": "markdown",
224 "metadata": {},
225 "source": [
226 "## 5.5.4 搭建神经网络"
227 ]
228 },
229 {
230 "cell_type": "markdown",
231 "metadata": {},
232 "source": [
233 "采用 **keras** 框架搭建一个神经网络,解决手写体数字识别问题。 \n",
234 "1. 导入相关包"
235 ]
236 },
237 {
238 "cell_type": "code",
239 "execution_count": null,
240 "metadata": {},
241 "outputs": [],
242 "source": [
243 "import keras\n",
244 "from keras.datasets import mnist\n",
245 "from keras.models import Sequential\n",
246 "from keras.layers.core import Dense,Activation,Dropout\n",
247 "from keras.utils import np_utils\n",
248 "import warnings\n",
249 "warnings.filterwarnings(\"ignore\")\n",
250 "!mkdir -p ~/.keras/datasets\n",
251 "!cp ./mnist.npz ~/.keras/datasets/mnist.npz"
252 ]
253 },
254 {
255 "cell_type": "markdown",
256 "metadata": {},
257 "source": [
258 "2. 下载 **MNIST** 数据集并将其转换为模型所能使用的格式。"
259 ]
260 },
261 {
262 "cell_type": "code",
263 "execution_count": null,
264 "metadata": {},
265 "outputs": [],
266 "source": [
267 "# 获取数据\n",
268 "(X_train, y_train),(X_test,y_test) = mnist.load_data()\n",
269 "\n",
270 "# 将训练集数据形状从(60000,28,28)修改为(60000,784)\n",
271 "X_train = X_train.reshape(len(X_train),-1)\n",
272 "X_test = X_test.reshape(len(X_test),-1)\n",
273 "\n",
274 "# 将数据集图像像素点的数据类型从 uint8 修改为 float32\n",
275 "X_train = X_train.astype('float32')\n",
276 "X_test = X_test.astype('float32')\n",
277 "\n",
278 "# 把数据集图像的像素值从 0-255 放缩到[-1,1]之间\n",
279 "X_train = (X_train - 127.5)/127.5\n",
280 "X_test = (X_test - 127.5)/127.5\n",
281 "\n",
282 "# 数据集类别个数\n",
283 "nb_classes = 10\n",
284 "\n",
285 "# 把 y_train 和 y_test 变成了 one-hot 的形式,即之前是 0-9 的一个数值, \n",
286 "# 现在是一个大小为 10 的向量,它属于哪个数字,就在哪个位置为 1,其他位置都是 0。\n",
287 "y_train = np_utils.to_categorical(y_train,nb_classes)\n",
288 "y_test = np_utils.to_categorical(y_test,nb_classes)"
289 ]
290 },
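`np_utils.to_categorical` 所做的 one-hot 编码,可以用纯 NumPy 写一个等价的小例子来体会(示意代码,不依赖 keras,`to_one_hot` 为自行命名的辅助函数):

```python
import numpy as np

def to_one_hot(labels, nb_classes):
    # 先构造全零矩阵,再把每个样本对应类别的位置置为 1
    one_hot = np.zeros((len(labels), nb_classes))
    one_hot[np.arange(len(labels)), labels] = 1
    return one_hot

labels = np.array([5, 0, 9])
print(to_one_hot(labels, 10))
```

每一行只有一个位置是 1,其余位置都是 0,这正是上面注释里描述的形式。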
291 {
292 "cell_type": "markdown",
293 "metadata": {},
294 "source": [
295 "3. 搭建神经网络模型"
296 ]
297 },
298 {
299 "cell_type": "code",
300 "execution_count": null,
301 "metadata": {},
302 "outputs": [],
303 "source": [
304 "def create_model():\n",
305 " \"\"\"\n",
306 " 采用 keras 搭建神经网络模型\n",
307 " :return: 神经网络模型\n",
308 " \"\"\"\n",
309 " # 选择模型,选择序贯模型(Sequential())\n",
310 " model = Sequential()\n",
311 " \n",
312 " # 添加全连接层,共 512 个神经元\n",
313 " model.add(Dense(512,input_shape=(784,),kernel_initializer='he_normal'))\n",
314 " \n",
315 " # 添加激活层,激活函数选择 relu \n",
316 " model.add(Activation('relu'))\n",
317 " \n",
318 " # 添加全连接层,共 512 个神经元\n",
319 " model.add(Dense(512,kernel_initializer='he_normal'))\n",
320 " \n",
321 " # 添加激活层,激活函数选择 relu \n",
322 " model.add(Activation('relu'))\n",
323 " \n",
324 " # 添加全连接层,共 10 个神经元\n",
325 " model.add(Dense(nb_classes))\n",
326 " \n",
327 " # 添加激活层,激活函数选择 softmax\n",
328 " model.add(Activation('softmax'))\n",
329 " \n",
330 " return model\n",
331 "\n",
332 "model = create_model()"
333 ]
334 },
335 {
336 "cell_type": "markdown",
337 "metadata": {},
338 "source": [
339 "4. 训练和测试神经网络模型"
340 ]
341 },
342 {
343 "cell_type": "code",
344 "execution_count": null,
345 "metadata": {},
346 "outputs": [],
347 "source": [
348 "def fit_and_predict(model, model_path):\n",
349 " \"\"\"\n",
350 " 训练模型、模型评估、保存模型\n",
351 " :param model: 搭建好的模型\n",
352 " :param model_path:保存模型路径\n",
353 " :return:\n",
354 " \"\"\"\n",
355 " # 编译模型\n",
356 " model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])\n",
357 " \n",
358 " # 模型训练\n",
359 " model.fit(X_train, y_train, epochs=5, batch_size=64, verbose=1, validation_split=0.05)\n",
360 " \n",
361 " # 保存模型\n",
362 " model.save(model_path)\n",
363 " \n",
364 " # 模型评估,获取测试集的损失值和准确率\n",
365 " loss, accuracy = model.evaluate(X_test, y_test)\n",
366 "\n",
367 " # 打印结果\n",
368 " print('Test loss:', loss)\n",
369 " print(\"Accuracy:\", accuracy)\n",
370 "\n",
371 "# 训练模型和评估模型\n",
372 "fit_and_predict(model, model_path='./model.h5')"
373 ]
374 },
375 {
376 "cell_type": "markdown",
377 "metadata": {},
378 "source": [
379 "### 实践与体验\n",
380 "#### 调节神经网络结构和参数\n",
381 "\n",
382 "1. 将两层隐藏层改为一层,训练模型并在测试集上测试,得出准确率。\n"
383 ]
384 },
385 {
386 "cell_type": "code",
387 "execution_count": null,
388 "metadata": {},
389 "outputs": [],
390 "source": [
391 "def create_model1():\n",
392 " \"\"\"\n",
393 " 搭建神经网络模型 model1,比 model 少一层隐藏层\n",
394 " :return: 模型 model1\n",
395 " \"\"\"\n",
396 " # 选择模型,选择序贯模型(Sequential())\n",
397 " model = Sequential()\n",
398 "\n",
399 " # 添加全连接层,共 512 个神经元\n",
400 " model.add(Dense(512, input_shape=(784,), kernel_initializer='he_normal'))\n",
401 "\n",
402 " # 添加激活层,激活函数选择 relu\n",
403 " model.add(Activation('relu'))\n",
404 "\n",
405 " # 添加全连接层,共 10 个神经元\n",
406 " model.add(Dense(nb_classes))\n",
407 "\n",
408 " # 添加激活层,激活函数选择 softmax\n",
409 " model.add(Activation('softmax'))\n",
410 "\n",
411 " return model\n",
412 "\n",
413 "# 搭建神经网络\n",
414 "model1 = create_model1()\n",
415 "\n",
416 "# 训练神经网络模型,保存模型和评估模型\n",
417 "fit_and_predict(model1, model_path='./model1.h5')"
418 ]
419 },
420 {
421 "cell_type": "markdown",
422 "metadata": {},
423 "source": [
424 "2. 修改两层隐藏层神经元的数量,然后训练模型得出准确率。"
425 ]
426 },
427 {
428 "cell_type": "code",
429 "execution_count": null,
430 "metadata": {},
431 "outputs": [],
432 "source": [
433 "def create_model2():\n",
434 " \"\"\"\n",
435 " 搭建神经网络模型 model2,隐藏层的神经元数目比 model 少一半\n",
436 " :return: 神经网络模型 model2\n",
437 " \"\"\"\n",
438 " # 选择模型,选择序贯模型(Sequential())\n",
439 " model = Sequential()\n",
440 "\n",
441 " # 添加全连接层,共 256 个神经元\n",
442 " model.add(Dense(256, input_shape=(784,), kernel_initializer='he_normal'))\n",
443 "\n",
444 " # 添加激活层,激活函数选择 relu\n",
445 " model.add(Activation('relu'))\n",
446 "\n",
447 " # 添加全连接层,共 256 个神经元\n",
448 " model.add(Dense(256, kernel_initializer='he_normal'))\n",
449 "\n",
450 " # 添加激活层,激活函数选择 relu\n",
451 " model.add(Activation('relu'))\n",
452 "\n",
453 " # 添加全连接层,共 10 个神经元\n",
454 " model.add(Dense(nb_classes))\n",
455 "\n",
456 " # 添加激活层,激活函数选择 softmax\n",
457 " model.add(Activation('softmax'))\n",
458 "\n",
459 " return model\n",
460 "\n",
461 "# 搭建神经网络模型\n",
462 "model2 = create_model2()\n",
463 "\n",
464 "# 训练神经网络模型,保存模型并评估模型\n",
465 "fit_and_predict(model2,model_path='./model2.h5')\n",
466 "\n"
467 ]
468 },
469 {
470 "cell_type": "markdown",
471 "metadata": {},
472 "source": [
473 "3. 输入一个手写数字,比较三种模型输出结果的差异,并对其进行分析解释。"
474 ]
475 },
476 {
477 "cell_type": "code",
478 "execution_count": null,
479 "metadata": {},
480 "outputs": [],
481 "source": [
482 "import numpy as np\n",
483 "np.set_printoptions(suppress=True)\n",
484 "from keras.models import load_model\n",
485 "\n",
486 "# 加载模型\n",
487 "model = load_model('./model.h5')\n",
488 "model1 = load_model('./model1.h5')\n",
489 "model2 = load_model('./model2.h5') "
490 ]
491 },
492 {
493 "cell_type": "code",
494 "execution_count": null,
495 "metadata": {},
496 "outputs": [],
497 "source": [
498 "# 预测结果\n",
499 "predict_results = np.round(model.predict(X_test)[0],3)\n",
500 "predict_results1 = np.round(model1.predict(X_test)[0],3)\n",
501 "predict_results2 = np.round(model2.predict(X_test)[0],3)\n",
502 "\n",
503 "# 打印预测结果\n",
504 "print('原始模型\\n其各类别预测概率:%s,预测值: %s,真实值:%s\\n' % (predict_results,np.argmax(predict_results),np.argmax(y_test[0])))\n",
505 "print('只有一个隐藏层的模型\\n其各类别预测概率:%s,预测值: %s,真实值:%s\\n' % (predict_results1,np.argmax(predict_results1),np.argmax(y_test[0])))\n",
506 "print('隐藏神经元数量更改后的模型\\n其各类别预测概率:%s,预测值: %s,真实值:%s' % (predict_results2,np.argmax(predict_results2),np.argmax(y_test[0])))"
507 ]
508 }
516 ],
517 "metadata": {
518 "kernelspec": {
519 "display_name": "Python 3",
520 "language": "python",
521 "name": "python3"
522 },
523 "language_info": {
524 "codemirror_mode": {
525 "name": "ipython",
526 "version": 3
527 },
528 "file_extension": ".py",
529 "mimetype": "text/x-python",
530 "name": "python",
531 "nbconvert_exporter": "python",
532 "pygments_lexer": "ipython3",
533 "version": "3.5.2"
534 }
535 },
536 "nbformat": 4,
537 "nbformat_minor": 2
538 }
Binary diff not shown
0 ## 介绍 (Introduction)
0 ## 介绍
11
2 添加该项目的功能、使用场景和输入输出参数等相关信息。
2 本课程主要教授 Python 基础知识、Python 进阶知识、机器学习常用的包以及人工智能相关的算法。
33
4 You can describe the function, usage and parameters of the project.
+0
-150
_README.ipynb less more
0 {
1 "cells": [
2 {
3 "cell_type": "markdown",
4 "metadata": {},
5 "source": [
6 "## 1. 项目介绍\n",
7 "\n",
8 " - 项目是由模块组成、有特定功能的程序。它能够满足用户的直接使用需求,例如[古诗词生成器](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1)、[风格迁移](https://momodel.cn/explore/5bfb634e1afd943c623dd9cf?type=app&tab=1)等。\n",
9 " - 开发项目过程中你可以导入数据集,也可以通过每个 cell 上方工具栏的`<+>`直接插入[模块](https://momodel.cn/modules)和代码块。\n",
10 " - 你可以将开发好的项目进行[部署](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89),项目部署成功并选择正式版本发布后会展示在“项目”页面,用户可以在线使用,也可以通过 API 调用。\n",
11 "\n",
12 " - 项目目录结构:\n",
13 "\n",
14 " - ```results```*-----结果的文件存放地(如果你运行 job,务必将运行结果指定在此目录)*\n",
15 " - ```_OVERVIEW.md``` *-----项目的相关介绍*\n",
16 " - ```_README.md```*-----说明文档*\n",
17 " - ```app_spec.yml```*-----定义项目的输入输出,为部署服务*\n",
18 " - ```coding_here.ipynb```*-----输入并运行代码*"
19 ]
20 },
21 {
22 "cell_type": "markdown",
23 "metadata": {},
24 "source": [
25 "\n",
26 "## 2. 开发环境简介\n",
27 "\n",
28 "你当前所在的页面 Notebook 是一个内嵌 JupyterLab 的在线类 IDE 编程环境,开发过程中可以使用页面右侧的 API 文档进行快速查询。Notebook 有以下主要功能:\n",
29 "\n",
30 "- [调用数据集、模块和代码块资源](https://momodel.cn/docs/#/zh-cn/%E5%A6%82%E4%BD%95%E5%AF%BC%E5%85%A5%E5%B9%B6%E4%BD%BF%E7%94%A8%E6%A8%A1%E5%9D%97%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86)\n",
31 "- [多人代码协作](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8Mo%E8%BF%90%E8%A1%8C%E4%BD%A0%E7%9A%84%E7%AC%AC%E4%B8%80%E6%AE%B5%E4%BB%A3%E7%A0%81?id=_7-%e4%bd%a0%e5%8f%af%e4%bb%a5%e9%82%80%e8%af%b7%e5%a5%bd%e5%8f%8b%e8%bf%9b%e8%a1%8c%e5%8d%8f%e4%bd%9c)\n",
32 "- [在 GPU 资源上训练机器学习模型](https://momodel.cn/docs/#/zh-cn/%E5%9C%A8GPU%E6%88%96CPU%E8%B5%84%E6%BA%90%E4%B8%8A%E8%AE%AD%E7%BB%83%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0%E6%A8%A1%E5%9E%8B)\n",
33 "- [简单部署](https://momodel.cn/docs/#/zh-cn/%E5%BC%80%E5%8F%91%E5%92%8C%E9%83%A8%E7%BD%B2%E4%B8%80%E4%B8%AA%E5%BA%94%E7%94%A8%EF%BC%88app%EF%BC%89)\n",
34 "\n",
35 "快来动手试试吧!点击左侧工具栏的新建文件图标即可选择你需要的文件类型。\n",
36 "\n",
37 "<img src='https://imgbed.momodel.cn/006tNc79gy1g61agfcv23j31c30u0789.jpg' width=100% height=100%>\n",
38 "\n",
39 "\n",
40 "\n",
41 "左侧和右侧工具栏都可根据使用需要进行收合。\n",
42 "<img src='https://imgbed.momodel.cn/collapse_tab.2019-09-06 11_07_44.gif' width=100% height=100%>"
43 ]
44 },
45 {
46 "cell_type": "markdown",
47 "metadata": {},
48 "source": [
49 "## 3. 快捷键与代码补全\n",
50 "Mo Notebook 已完全采用 Jupyter Notebook 的原生快捷键,并且支持 `tab` 代码补全。\n",
51 "\n",
52 "运行代码:`shift` + `enter` 或者 `shift` + `return`"
53 ]
54 },
55 {
56 "cell_type": "markdown",
57 "metadata": {},
58 "source": [
59 "## 4. 常用指令介绍\n",
60 "\n",
61 "- 解压上传后的文件\n",
62 "\n",
63 "在 cell 中输入并运行以下命令:\n",
64 "```!unzip -o file_name.zip```\n",
65 "\n",
66 "- 查看所有包(package)\n",
67 "\n",
68 "`!pip list --format=columns`\n",
69 "\n",
70 "- 检查是否已有某个包\n",
71 "\n",
72 "`!pip show package_name`\n",
73 "\n",
74 "- 安装缺失的包\n",
75 "\n",
76 "`!pip install package_name`\n",
77 "\n",
78 "- 更新已有的包\n",
79 "\n",
80 "`!pip install package_name --upgrade`\n",
81 "\n",
82 "\n",
83 "- 使用包\n",
84 "\n",
85 "`import package_name`\n",
86 "\n",
87 "- 显示当前目录下的档案及目录\n",
88 "\n",
89 "`ls`\n",
90 "\n",
91 "- 使用引入的数据集\n",
92 "\n",
93 "数据集被引入后存放在 datasets 文件夹下,注意,这个文件夹是只读的,不可修改。如果需要修改,可在 Notebook 中使用\n",
94 "\n",
95 "`!cp -R ./datasets/<imported_dataset_dir> ./<your_folder>`\n",
96 "\n",
97 "指令将其复制到其他文件夹后再编辑,对于引入的数据集中的 zip 文件,可使用\n",
98 "\n",
99 "`!unzip ./datasets/<imported_dataset_dir>/<XXX.zip> -d ./<your_folder>`\n",
100 "\n",
101 "指令解压缩到其他文件夹后使用"
102 ]
103 },
104 {
105 "cell_type": "markdown",
106 "metadata": {},
107 "source": [
108 "## 5. 其他可参考资源\n",
109 "- [帮助文档](https://momodel.cn/docs/#/):基本页面介绍和常见问题都可以在里面找到\n",
110 "- [平台功能教程](https://momodel.cn/classroom/class?id=5c5696cd1afd9458d456bf54&type=doc):通过图文结合的 Notebook 详细介绍开发环境基本功能和操作\n",
111 "- [吴恩达机器学习](https://momodel.cn/classroom/class?id=5c5696191afd94720cc94533&type=video):机器学习经典课程\n",
112 "- [李宏毅机器学习](https://s.momodel.cn/classroom/class?id=5d40fdafb5113408a8dbb4a1&type=video):中文世界最好的机器学习课程\n",
113 "- [机器学习实战](https://momodel.cn/classroom/class?id=5c680b311afd943a9f70901b&type=practice):通过实操指引完成独立的模型,掌握相应的机器学习知识\n",
114 "- [Python 教程](https://momodel.cn/classroom/class?id=5d1f3ab81afd940ab7d298bf&type=notebook):简单易懂的 Python 新手教程\n",
115 "- [模块开发](https://momodel.cn/modules):关于模型训练、开发与部署的高阶教程"
116 ]
117 }
118 ],
119 "metadata": {
120 "kernelspec": {
121 "display_name": "Python 3",
122 "language": "python",
123 "name": "python3"
124 },
125 "language_info": {
126 "codemirror_mode": {
127 "name": "ipython",
128 "version": 3
129 },
130 "file_extension": ".py",
131 "mimetype": "text/x-python",
132 "name": "python",
133 "nbconvert_exporter": "python",
134 "pygments_lexer": "ipython3",
135 "version": "3.5.2"
136 },
137 "pycharm": {
138 "stem_cell": {
139 "cell_type": "raw",
140 "source": [],
141 "metadata": {
142 "collapsed": false
143 }
144 }
145 }
146 },
147 "nbformat": 4,
148 "nbformat_minor": 2
149 }
+0
-36
coding_here.ipynb less more
0
1 {
2 "cells": [
3 {
4 "cell_type": "code",
5 "execution_count": null,
6 "metadata": {},
7 "outputs": [],
8 "source": [
9 "print('Hello Mo!')"
10 ]
11 }
12 ],
13 "metadata": {
14 "kernelspec": {
15 "display_name": "Python 3",
16 "language": "python",
17 "name": "python3"
18 },
19 "language_info": {
20 "codemirror_mode": {
21 "name": "ipython",
22 "version": 3
23 },
24 "file_extension": ".py",
25 "mimetype": "text/x-python",
26 "name": "python",
27 "nbconvert_exporter": "python",
28 "pygments_lexer": "ipython3",
29 "version": "3.5.2"
30 }
31 },
32 "nbformat": 4,
33 "nbformat_minor": 2
34 }
35
Binary diff not shown
0 为什么鲸鱼的体型如此庞大?有一个答案很简单,它们就是能长到这么大。
1 陆地动物体型受限的部分原因是因为它们需要依靠自身对抗地心引力。
2 海洋生物可以通过它们生活的环境为媒介免费获得这种支持。
3 即便如此,可能的事情并不总是明智的。投入生长的资源无法再生产。
4 鉴于鲸鱼可以且确实庞大,第二个问题就来了:如果有的话,有什么可以阻止它们变得更大?
5 斯坦福大学的Jeremy Goldbogen和他的同事怀疑两个问题的答案和动物的食物供应有关。
6 正如他们在《Science》杂志中发表的一篇论文所述,他们收集了数据,阐明了可能的运作方式。
7 总的来说,大鲸鱼有两种。齿鲸,如抹香鲸,捕食单个猎物。
8 须鲸,用纤维状的口腔过滤器吸入大量的水,并从磷虾等小生物中吸取营养。
9 所有鲸鱼中最大的鲸类(蓝鲸、座头鲸等)是须鲸。
10 这或许被认为是很矛盾的,因为在陆地上,捕食者越大,它们的猎物也越大。
11 齿鲸和须鲸通常都潜入深海捕食,因为深海里的猎物更多。
12 为了去深海捕食,它们会屏住呼吸,这就限制了它们在深水中停留的时间。
13 鲸鱼巨大体型的解释之一是因为体型越大,屏息时间越长,这样它们就能花更长时间在深海捕食。
0 Why are whales so big? One answer is simply that they can be.
1 The size of land animals is constrained in part by their need to support themselves against the force of gravity.
2 Marine creatures have that support provided free, by the medium they live in.
3 Even so, what is possible is not always sensible. Resources put into growth are unavailable for reproduction.
4 Given that whales can and do become big, however, a second question arises: what, if anything, stops them being even bigger?
5 Jeremy Goldbogen of Stanford University and his colleagues suspect that the answers to both questions are related to the animals' food supply.
6 And, as they describe in a paper in Science, they have gathered data that illuminate how this might work.
7 Broadly, big whales come in two varieties. Toothed whales, such as sperm whales, hunt individual prey.
8 Baleen whales suck in mouthfuls of water and extract small organisms such as krill, using fibrous buccal filters.
9 The biggest whales of all (blue, humpback and so on) are baleen whales.
10 This might be viewed as paradoxical, because on land, as predators get bigger, so do their individual prey.
11 Both toothed and baleen whales often hunt by diving deep—prey being more abundant at depth.
12 To do this they have to hold their breath, which limits how long they can stay underwater.
13 One explanation of giantism in whales is that because bigger whales can hold their breath longer, they can spend more time hunting.
0 人类所到之处都是一片狼藉,太空也不例外。
1 现在人类要为超过500000个绕地球高速飞行的垃圾负责,如果我们再不积极行动起来清理大型垃圾,碰撞的危险只会越来越大。
2 欧洲航天局局长Jan Wörner说:“想象一下如果曾经失踪的所有船只都漂浮在水面上,在公海航行会有多危险。”
3 “目前轨道上就是这种情况,不能再任其继续发展了。”
4 就好像我们需要一辆拖车把成千上万颗失灵的卫星拖离轨道。顺便说一句,这正是欧洲航天局正在研究的。
5 该机构计划到2025年发射世界上第一个轨道垃圾收集器,这个4条手臂的机器人就像迷宫里的吃豆子的人一样追踪太空垃圾。
6 第一次这样的任务被称为清理太空-1,先执行小任务,只收集单件的太空垃圾来证明这一概念的可行性。这一任务的目标是Vespa,是欧洲航天局2013年发射织女星火箭留下的残骸。这件垃圾的重量几乎相当于一颗小型卫星,形状简单,机器人的四只手臂应该能很容易地抓取。一旦被垃圾收集器安全抓取,就会被拖出轨道,在大气层中烧毁。
0 Wherever we humans go, we leave behind a mess. That goes for space, too.
1 Today, our species is responsible for more than 500,000 pieces of junk hurtling around Earth at phenomenal speeds, and if we don't start actively removing the largest pieces, the risk of collisions will only grow worse.
2 "Imagine how dangerous sailing the high seas would be if all the ships ever lost in history were still drifting on top of the water," says Jan Wörner, European Space Agency (ESA) director general.
3 "That is the current situation in orbit, and it cannot be allowed to continue."
4 It's almost as if we need a tow truck to remove all the thousands of failed satellites from our orbit; incidentally, that's exactly what the ESA is working on.
5 By 2025, the agency plans on launching the world's first orbiting junk collector, a four-armed robot that tracks down space waste like Pac-Man in a maze.
6 The first-of-its-kind mission, known as ClearSpace-1, will start out small, collecting only a single piece of space junk to prove the concept works. The target in this case is called Vespa, a leftover remnant from ESA's Vega rocket launch in 2013.
7 This piece of junk weighs roughly the same as a small satellite and has a simple shape that should make it easy to grab with four robotic arms. Once it's safely in the arms of the garbage collector, it will then be dragged out of orbit and allowed to burn up in the atmosphere.
0 澳大利亚当局表示,势无可挡的森林大火已经将一些房屋夷为平地,并迅速蔓延到悉尼郊区。新南威尔士州仍有数十处大火在燃烧,气温达到35摄氏度,风速达到每小时80公里。风向确如人们担心的那样转为了南风,不过大家所惧怕的“灾难性”一天基本上得以避免。
1 本周二(11月12日)没有人员死亡报告,但消防官员警告称,目前的情况意味着该州面临的危险远未结束。
2 大约有600万人居住在新南威尔士州。
3 据澳大利亚媒体报道,新南威尔士州目前仍有100至300起火灾。
4 消防人员一直在新南威尔士州北部绵延1000公里的前线作战,有关官员表示,几场大火“单是着火面积就超过10万公顷”。
5 一些森林大火蔓延至距市中心15公里以内地区,消防人员不得不在悉尼北部郊区投放阻燃剂。
6 大火距离布莱克·海门和肖恩·墨菲家所在的郊区富人区只有几米远。
7 大火从居民家门前马路对面的茂密丛林蔓延开来,居民们被迫用游泳池的水灭火。
8 海门告诉《悉尼先驱晨报》说:“实际情况是,我们的水压不够了,所以不得不用游泳池的水。肖恩用其中一个水桶接水扑灭了一处着火点。”
9 当地政府表示,一名消防员手臂骨折,肋骨可能也出现骨折。
10 农业消防局局长谢恩·菲茨西蒙斯说:“我们还有很长的路要走。在下一轮恶劣天气来临之前,我们肯定无法扑灭所有这些火灾。”
11 “不幸的是,情况没有任何缓解的迹象。没有降雨,在未来几天至几周内,天气仍将炎热干燥。”
12 菲茨西蒙斯说,本周二(11月12日)有多达12座房屋被损坏或摧毁。
13 新南威尔士州受影响社区的人们被敦促远离丛林地带。全州有600多所学校停课。
14 澳大利亚保守派政府拒绝透露气候变化是否可能导致了火灾,这一回应招致了批评。
15 自上周五(11月8日)新南威尔士州的火灾紧急情况加剧以来,已有3人死亡,170多座房屋被毁。
16 当地政府表示,他们面临的可能是“本国有史以来最危险的森林大火周”。
17 菲茨西蒙斯表示,前线有3000名消防员,还有来自其他州、新西兰以及澳大利亚国防军的支援。
18 专家们称这一情况堪比2009年维多利亚州“黑色星期六”的森林大火,那次火灾造成173人死亡。
19 有报告称,新南威尔士州猎人区北罗斯伯里的一场火灾以及至少另外两场火灾可能是蓄意纵火,警方正在对此进行调查。
20 消防部门表示,自9月份火灾季节开始以来,新南威尔士州已有100万公顷土地遭遇火情。
21 北部的昆士兰州也宣布进入紧急状态,该州有55起森林大火。
22 尽管本周二(11月12日)昆州没有遭遇如此恶劣的天气,但官员们警告称,本周晚些时候情况可能会恶化。
23 南澳大利亚州的消防人员正在扑灭十几起大火,而西澳大利亚州的森林大火也引发了紧急警报。
24 科学家和专家警告说,由于气候变化,澳大利亚的火灾季节变得更长、更严重。
25 官员们已经证实,2018年和2017年分别是澳大利亚有记录以来第三和第四最热的年份,去年澳大利亚经历了有记录以来最热的夏天。
26 气象局发布的《2018年气候状况报告》称,气候变化导致极端高温事件增加,干旱等其他自然灾害的严重程度也随之提高。
27 2015年,188个国家签署了具有里程碑意义的《巴黎协定》,规定与前工业化时期相比,全球气温升幅不得超过2摄氏度。即便如此,科学家们仍认为,澳大利亚正面临危险的新常态。
28 去年,一份联合国报告称,澳大利亚在减少二氧化碳排放方面做得不够。
0 Raging bushfires have razed properties in Australia and briefly spread to suburbs of Sydney, officials say.
1
2 Scores of fires are still burning in New South Wales amid temperatures of 35C and winds of 80km/h.
3 A feared southerly wind change has now occurred but the "catastrophic" day feared has largely been avoided.
4 No deaths were reported on Tuesday but fire chiefs warned that conditions meant the dangers facing the state were far from over.
5 About six million people live in New South Wales (NSW) state.
6 The number of fires still afflicting NSW ranged from 100 to 300 in Australian media reports.
7 Crews have been battling a front spanning 1,000km along the north coast of NSW, with several blazes "exceeding 100,000 hectares alone", officials have said.
8 Flame retardant had to be dropped in Sydney's northern suburbs as some bushfires approached within 15km of the city centre.
9 Flames came within metres of engulfing the homes of Blake Haymen and Sean Murphy in the affluent suburb.
10 The residents were forced to use pool water to tackle a blaze that had spread from dense bushland across the road from their homes.
11 "We actually ran out of water pressure, so we had to go to the pool. Sean put out a spot-fire with one of these buckets," Mr Haymen told the Sydney Morning Herald.
12 Authorities said one firefighter had suffered a broken arm and suspected fractured ribs.
13 Rural Fire Service Commissioner Shane Fitzsimmons said: "We've really got a long way to go. You can guarantee we're not going to be able to get around all of these fires before the next wave of bad weather.
14 "Unfortunately there's no meaningful reprieve. There's no rainfall in this change and we're going to continue to have warm dry conditions dominating in the days and weeks ahead."
15 Commissioner Fitzsimmons said up to a dozen homes were believed to have been damaged or destroyed on Tuesday.
16 People in vulnerable NSW communities have been urged to stay away from bushland. More than 600 schools are closed across the state.
17 Australia's conservative government has refused to be drawn on whether climate change could have contributed to the fires, in a response that has drawn criticism.
18
19
20 Three people have died and more than 170 properties have been destroyed since the fire emergency intensified in NSW on Friday.
21
22 Authorities had said they were facing what could be "the most dangerous bushfire week this nation has ever seen".
23 Mr Fitzsimmons said 3,000 firefighters were on the front lines, boosted by crews from other states and New Zealand, as well as the Australian Defence Force.
24 Experts have compared the situation to the 2009 Black Saturday bushfires in Victoria, when 173 people died.
25 Police are investigating reports that a fire in North Rothbury, in the Hunter Region of NSW, and at least two others, may have been started deliberately.
26 Fire authorities say a million hectares of land have burned in NSW since the fire season began in September.
27 To the north, Queensland has also declared a state of emergency as 55 bushfires rage in the state.
28 Though it was not facing such severe weather on Tuesday, officials warned conditions could deteriorate later in the week.
29 Fire crews in South Australia were tacking about a dozen blazes, while bushfires in Western Australia also sparked emergency warnings.
30 Scientists and experts warn that Australia's fire season has grown longer and more intense due to climate change.
31 Officials have confirmed that 2018 and 2017 were Australia's third and fourth-hottest years on record respectively, and last year the nation experienced its warmest summer on record.
32 The Bureau of Meteorology's State of the Climate 2018 report said climate change had led to an increase in extreme heat events and raised the severity of other natural disasters, such as drought.
33 Even if global temperatures are contained to a 2C rise above pre-industrial levels - a limit set out in the landmark Paris accord, agreed by 188 nations in 2015 - scientists believe the country is facing a dangerous new normal.
34 Last year, a UN report said Australia was falling short in efforts to cut its CO2 emissions.
Binary diff not shown
+0
-2
results/README.md less more
0 Please store your training checkpoints or results here
1 Please store checkpoints and result files here
+0
-2
results/tb_results/README.md less more
0 Please store your tensorboard results here
1 Please store tensorboard results here
0 import matplotlib.pyplot as plt
1 import collections
2 from IPython import display
3 import networkx as nx
4 import numpy as np
5 import time
6
7
8 class SearchGraph():
9 def __init__(self,
10 node_list,
11 weighted_edges_list,
12 start_node,
13 target_node,
14 nodes_pos=None,
15 help_info=None,):
16 self.node_list = node_list
17 self.weighted_edges_list = weighted_edges_list
18 self.start_node = start_node
19 self.target_node = target_node
20 self.nodes_pos = nodes_pos
21 self.max_depth=len(node_list)
22 self.temp_best_path = None
23
24 self.weighted_edges_dic = {frozenset([e[0],e[1]]):e[2] for e in weighted_edges_list}
25 self.help_info = help_info
26 self.path_score={self.start_node:0}
27
28 self.animation_type = 'dfs'
29
30 self.basic_node_color = '#6CB6FF'
31 self.start_node_color = 'y'
32 self.target_node_color = 'r'
33 self.visited_node_color = 'g'
34
35 self.basic_edge_color = 'b'
36 self.visited_edge_color = 'g'
37
38 self.success_color = 'r'
39
40 self.correct_paths={}
41 self.show_correct_path = []
42 self.build_graph()
43 self.get_search_tree_node_position()
44 self.bfs_search()
45
46
47
48 def build_graph(self):
49 self.G = nx.Graph()
50 self.G.add_nodes_from(self.node_list)
51 self.G.add_weighted_edges_from(self.weighted_edges_list)
52
53 def get_search_tree_node_position(self):
54 """Compute the plotting coordinates of the search-tree nodes.
55 """
56 self.dfs_search()
57 # Get the DFS search paths
58 paths = self.dfs_path
59 # Collect the child paths of each path
60 path_children = {}
61 for path in paths:
62 father = path[:-1]
63 if father in paths:
64 if father in path_children:
65 path_children[father].append(path)
66 else:
67 path_children[father] = [path]
68 # Sort the child paths of each path
69 o_path_children = collections.OrderedDict(sorted(path_children.items()))
70 # Compute the position of every node in the tree plot
71 tree_node_position = {self.start_node:(1, 0, 2)}
72 for path, sub_paths in o_path_children.items():
73 y_pos = -0.5 * len(path)
74 dx = tree_node_position[path][2]/len(sub_paths)
75 sub_paths.sort()
76 for index, e_s in enumerate(sub_paths):
77 x_pos = tree_node_position[path][0] - tree_node_position[path][2]/2 + dx/2 + dx*index
78 tree_node_position[e_s]=(x_pos,y_pos, dx)
79 self.tree_node_position = tree_node_position
80
81 def show_edge_labels(self, ax, pos1, pos2, label):
82 (x1, y1) = pos1
83 (x2, y2) = pos2
84 (x, y) = (x1*0.5 + x2*0.5, y1*0.5 + y2*0.5)
85
86 angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
87 if angle > 90:
88 angle -= 180
89 if angle < - 90:
90 angle += 180
91 xy = np.array((x, y))
92 trans_angle = ax.transData.transform_angles(np.array((angle,)),
93 xy.reshape((1, 2)))[0]
94 bbox = dict(boxstyle='round',
95 ec=(1.0, 1.0, 1.0),
96 fc=(1.0, 1.0, 1.0),
97 )
98 label = str(label)
99 ax.text(x, y,
100 label,
101 size=16,
102 color='k',
103 alpha=1,
104 horizontalalignment='center',
105 verticalalignment='center',
106 rotation=trans_angle,
107 transform=ax.transData,
108 bbox=bbox,
109 zorder=1,
110 clip_on=True,
111 )
112
113 def show_search_tree(self,
114 this_path=None,
115 show_success_color=False,
116 best_path=None
117 ):
118 """Display the search tree.
119 """
120 # Draw the tree plot
121 fig, ax = plt.subplots()
122 fig.set_figwidth(15)
123 fig.set_figheight(10)
124 plt.axis('off')
125
126 for path, pos in self.tree_node_position.items():
127 if path[-1] == self.start_node:
128 node_color = self.start_node_color
129 edge_color = self.basic_edge_color
130 elif this_path and path in this_path:
131 if show_success_color:
132 node_color = self.success_color
133 edge_color = self.success_color
134 else:
135 node_color = self.visited_node_color
136 edge_color = self.visited_edge_color
137 elif path[-1] == self.target_node:
138 node_color = self.target_node_color
139 edge_color = self.basic_edge_color
140 else:
141 node_color = self.basic_node_color
142 edge_color = self.basic_edge_color
143 ax.scatter(pos[0], pos[1], c=node_color, s=1000,zorder=1)
144 plt.annotate(
145 path[-1],
146 xy=(pos[0], pos[1]),
147 xytext=(0, 0),
148 textcoords='offset points',
149 ha='center',
150 va='center',
151 size=15,)
152 if len(path)>1:
153 plt.plot([self.tree_node_position[path[:-1]][0],pos[0]],
154 [self.tree_node_position[path[:-1]][1],pos[1]],
155 color=edge_color,
156 zorder=0)
157 if len(path)>1:
158 label = self.weighted_edges_dic[frozenset([path[-2],path[-1]])]
159 if self.animation_type in ['greedy','a_star']:
160 label = self.help_info_weight*self.help_info[path[-1]] + self.origin_info_weight*label
161 self.show_edge_labels(ax, self.tree_node_position[path[:-1]][0:2], pos[0:2], label)
162 display.clear_output(wait=True)
163
164 show_res_text = ""
165 for e_c in self.show_correct_path:
166 show_res_text += 'Found a path: %-7s' % e_c + '. Distance: ' + str(self.correct_paths[e_c]) + '\n'
167 plt.text(0, -3.3, show_res_text, fontsize=18,horizontalalignment='left', verticalalignment='top',)
168
169 if best_path:
170 top_text = 'Final shortest path: %-7s' % this_path + '. Distance: ' + str(self.correct_paths[this_path]) + '\n'
171 elif this_path and self.animation_type in ['dfs','bfs']:
172 top_text = 'Current path: %-7s' % this_path + '. Distance: ' + str(self.path_score[this_path]) + '\n'
173 if self.temp_best_path:
174 top_text += 'Current shortest path: %-7s' % self.temp_best_path + '. Distance: ' + str(self.correct_paths[self.temp_best_path]) + '\n'
175 else:
176 top_text = ''
177
178 plt.text(0, 0,
179 top_text,
180 fontsize=18,
181 horizontalalignment='left',
182 verticalalignment='top',)
183
184 if self.animation_type in ['greedy','a_star']:
185 show_greedy_text = self.generate_greedy_help_text(this_path)
186 plt.text(0, 0, show_greedy_text, fontsize=18, horizontalalignment='left', verticalalignment='top',)
187 plt.show()
188
189 def animation_search_tree(self,search_method='dfs', help_info_weight=1, origin_info_weight=1):
190 """Animate the search process on the search tree.
191 """
192 self.animation_type = search_method
193 self.show_correct_path = []
194 self.temp_best_path = None
195 if search_method == 'bfs':
196 paths = self.bfs_path
197 elif search_method == 'dfs':
198 paths = self.dfs_path
199 elif search_method == 'greedy':
200 self.greedy_search()
201 paths = self.greedy_search_path
202 elif search_method == 'a_star':
203 self.a_star_search(help_info_weight=help_info_weight, origin_info_weight=origin_info_weight)
204 paths = self.greedy_search_path
205 else:
206 paths = []
207 for e_path in paths:
208 self.show_search_tree(e_path)
209 if e_path in self.correct_paths:
210 if not self.temp_best_path:
211 self.temp_best_path = e_path
212 elif self.path_score[e_path] < self.path_score[self.temp_best_path]:
213 self.temp_best_path = e_path
214 self.show_correct_path.append(e_path)
215 self.show_search_tree(e_path, True)
216 if search_method in ['greedy', 'a_star']:
217 time.sleep(5)
218 if search_method in ['bfs', 'dfs']:
219 best_path = min(self.correct_paths, key=self.correct_paths.get)
220 self.show_search_tree(best_path, True, True)
221
222 def animation_graph(self, search_method='bfs', help_info_weight=1, origin_info_weight=1):
223
224 """Animate the search process on the graph layout.
225 """
226 self.animation_type = search_method
227 self.show_correct_path = []
228 if search_method == 'bfs':
229 paths = self.bfs_path
230 elif search_method == 'dfs':
231 paths = self.dfs_path
232 elif search_method == 'greedy':
233 self.greedy_search()
234 paths = self.greedy_search_path
235 elif search_method == 'a_star':
236 self.a_star_search(help_info_weight=help_info_weight, origin_info_weight=origin_info_weight)
237 paths = self.greedy_search_path
238 else:
239 paths = []
240 for e_path in paths:
241 self.show_graph(e_path)
242 if e_path in self.correct_paths:
243 self.show_correct_path.append(e_path)
244 self.show_graph(e_path, True)
245 time.sleep(5)
246 if search_method in ['bfs', 'dfs']:
247 best_path = min(self.correct_paths, key=self.correct_paths.get)
248 self.show_graph(best_path, True, True)
249
250 def show_graph(self, this_path='',
251 show_success_color=False,
252 best_path=None):
253 """
254 Draw the graph.
255 :return:
256 """
257 fig, ax = plt.subplots()
258 fig.set_figwidth(6)
259 fig.set_figheight(8)
260 plt.axis('off')
261
262 # Determine node and edge colors
263 visited_edges = []
264 if not this_path:
265 this_path = self.start_node
266 path_node_list = list(this_path)
267 for i in range(1,len(path_node_list)):
268 visited_edges.append(frozenset([path_node_list[i],path_node_list[i-1]]))
269
270 # Nodes and their labels
271 nlabels = dict(zip(self.node_list, self.node_list))
272 edge_labels = dict([((u, v,), d['weight']) for u, v, d in self.G.edges(data=True)])
273
274 # Node color mapping
275 val_map = {self.target_node: self.target_node_color}
276 if path_node_list:
277 for i in path_node_list:
278 if show_success_color:
279 val_map[i] = self.success_color
280 else:
281 val_map[i] = self.visited_node_color
282 val_map[self.start_node] = self.start_node_color
283 values = [val_map.get(node, self.basic_node_color) for node in self.G.nodes()]
284
285 # Determine edge colors
286 edge_colors = []
287 for edge in self.G.edges():
288 # Edges on the current path are drawn in the success color once the
289 # target has been reached, and in the visited color otherwise;
290 # all other edges keep the basic edge color.
291 if frozenset(edge) in visited_edges:
292 if show_success_color:
293 edge_colors.append(self.success_color)
294 else:
295 edge_colors.append(self.visited_edge_color)
296 else:
297 edge_colors.append(self.basic_edge_color)
298
299 # Draw nodes and their labels
300 nx.draw_networkx_nodes(self.G, self.nodes_pos, node_size=800, node_color=values)
301 nx.draw_networkx_labels(self.G, self.nodes_pos, nlabels, font_size=20)
302 # Draw edges and their labels
303 nx.draw_networkx_edges(self.G, self.nodes_pos, edge_color=edge_colors, width=2.0, alpha=1.0)
304 nx.draw_networkx_edge_labels(self.G, self.nodes_pos, edge_labels=edge_labels, font_size=18)
305
306 display.clear_output(wait=True)
323 plt.show()
324
325 def _dfs_helper(self, G, node, father, target_node,level, res, path):
326 path+=str(node)
327 if len(path)>1:
328 self.path_score[path] = self.path_score[path[:-1]] + self.weighted_edges_dic[frozenset([path[-2],path[-1]])]
329 res.append(path)
330 # Stop searching once the target is found
331 if node==target_node:
332 return
333 if level< self.max_depth:
334 for neighbor in sorted(G[node]):
335 if str(neighbor) not in path:
336 self._dfs_helper(G, neighbor, node, target_node, level+1, res, path)
337
338 def dfs_search(self):
339 dfs_path=[]
340 this_path=''
341 if self.start_node:
342 self._dfs_helper(self.G, self.start_node, None, self.target_node, 0, dfs_path, this_path)
343 self.dfs_path = dfs_path
344 for p in dfs_path:
345 if p[-1]==self.target_node and p not in self.correct_paths:
346 self.correct_paths[p] = self.cal_dis(p)
347
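The recursive expansion in `_dfs_helper` / `dfs_search` can be sketched standalone in plain Python. This is a minimal illustration, not the original implementation: the `dfs_paths` name and the adjacency map are assumed for the example, and the original's depth limit and per-path scoring are omitted.

```python
# Minimal standalone sketch of the string-path DFS in _dfs_helper /
# dfs_search: every visited path is a string of node labels, neighbors
# are expanded in sorted order, and a branch stops once it reaches the
# target. (Depth limiting and path scoring are omitted for brevity.)

def dfs_paths(adj, start, target):
    visited_paths = []

    def helper(node, path):
        path += node
        visited_paths.append(path)
        if node == target:             # found the target; stop this branch
            return
        for neighbor in sorted(adj[node]):
            if neighbor not in path:   # never revisit a node on this path
                helper(neighbor, path)

    helper(start, "")
    return visited_paths

# Example: a small 4-node graph, searching from A to D.
adj = {"A": "BC", "B": "AD", "C": "AD", "D": "BC"}
print(dfs_paths(adj, "A", "D"))  # ['A', 'AB', 'ABD', 'AC', 'ACD']
```

Note that reaching the target only terminates the current branch: sibling branches (here, the subtree under `AC`) are still explored, which is what lets the class compare all complete paths afterwards.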
348 def bfs_search(self):
349 to_search=[self.start_node]
350 bfs_path = []
351 bfs_correct_path = []
352 depth = 0
353 while to_search:
354 this_search = to_search.pop(0)
355 if len(this_search)>self.max_depth+1 :
356 break
357 bfs_path.append(this_search)
358 if this_search[-1]==self.target_node:
359 bfs_correct_path.append(this_search)
360 continue
361 for ne in sorted(self.G[this_search[-1]]):
362 if ne not in this_search:
363 to_search.append(this_search+ne)
364 self.bfs_path = bfs_path
365 for p in bfs_path:
366 if p[-1]==self.target_node and p not in self.correct_paths:
367 self.correct_paths[p] = self.cal_dis(p)
368
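The queue-based loop in `bfs_search` can be sketched the same way. A minimal sketch, with an assumed adjacency map; the original's depth cutoff is omitted.

```python
# Minimal standalone sketch of bfs_search: a FIFO queue of paths (each a
# string of node labels), popped front-first so paths are visited in
# breadth-first order; a path that reaches the target is recorded but
# not extended further.
from collections import deque

def bfs_paths(adj, start, target):
    queue = deque([start])
    visited_paths = []
    while queue:
        path = queue.popleft()
        visited_paths.append(path)
        if path[-1] == target:        # complete path; do not extend it
            continue
        for neighbor in sorted(adj[path[-1]]):
            if neighbor not in path:  # no repeated nodes within one path
                queue.append(path + neighbor)
    return visited_paths

adj = {"A": "BC", "B": "AD", "C": "AD", "D": "BC"}
print(bfs_paths(adj, "A", "D"))  # ['A', 'AB', 'AC', 'ABD', 'ACD']
```

Compared with the DFS output, the same five paths appear, but ordered by length rather than by branch.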
369 def greedy_search(self, help_info_weight=1, origin_info_weight=0):
370 self.help_info_weight = help_info_weight
371 self.origin_info_weight = origin_info_weight
372 search_path = self.start_node
373 # Record each step's candidate nodes and scores for display during the animation
374 search_scores = {}
375 while len(search_path) <= self.max_depth:
376 this_node = search_path[-1]
377 neighbour_nodes = [e_n for e_n in sorted(self.G[this_node]) if e_n not in search_path]
378 if len(neighbour_nodes) == 0:
379 search_scores[search_path]={}
380 break
381 if self.help_info:
382 scores = {e_n:help_info_weight*self.help_info[e_n]+origin_info_weight*self.weighted_edges_dic[frozenset([this_node,e_n])] for e_n in neighbour_nodes }
383 else:
384 scores = {e_n:self.weighted_edges_dic[frozenset([this_node,e_n])]
385 for e_n in neighbour_nodes }
386 search_scores[search_path]=scores
387 nearest_node = min(scores, key=scores.get)
388 search_path += nearest_node
389 if nearest_node == self.target_node:
390 break
391 self.greedy_search_path = [search_path[0:index+1] for index in range(len(search_path))]
392 self.search_scores = search_scores
393
394 def a_star_search(self, help_info_weight=1, origin_info_weight=1):
395 self.greedy_search(help_info_weight, origin_info_weight)
396
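`greedy_search` and `a_star_search` share one loop and differ only in the weights: the pure greedy default scores a candidate by its heuristic value alone (`origin_info_weight=0`), while the A*-style variant also adds the edge weight. A minimal sketch of that scoring rule, with an assumed graph, weight table, and heuristic:

```python
# Sketch of the scoring in greedy_search: at each step pick the unvisited
# neighbor minimizing h_w*heuristic + g_w*edge_weight. h_w=1, g_w=0
# matches the pure greedy default; h_w=g_w=1 matches the one-step
# A*-style score used by a_star_search. All data below is assumed.

def greedy_path(adj, edge_w, heuristic, start, target, h_w=1, g_w=0):
    path = start
    while path[-1] != target:
        node = path[-1]
        candidates = [n for n in sorted(adj[node]) if n not in path]
        if not candidates:            # dead end: no unvisited neighbors
            break
        scores = {n: h_w * heuristic[n] + g_w * edge_w[frozenset((node, n))]
                  for n in candidates}
        path += min(scores, key=scores.get)
    return path

adj = {"A": "BC", "B": "AD", "C": "AD", "D": "BC"}
edge_w = {frozenset("AB"): 1, frozenset("AC"): 4,
          frozenset("BD"): 5, frozenset("CD"): 1}
heuristic = {"A": 3, "B": 2, "C": 1, "D": 0}
print(greedy_path(adj, edge_w, heuristic, "A", "D"))         # 'ACD'
print(greedy_path(adj, edge_w, heuristic, "A", "D", g_w=1))  # 'ABD'
```

With the edge weight included, the cheap first edge to B outweighs C's lower heuristic, which is why the two calls above choose different paths.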
397
398 def generate_greedy_help_text(self,path):
399 if path[-1] == self.target_node:
400 return 'Reached target node ' + str(self.target_node)
401
402 base_text = 'Candidate child nodes and their scores:\n' + \
403 str(self.search_scores[path]) + '\n'
404 if self.target_node in self.search_scores[path]:
405 return base_text + 'The candidates include the target node,\nso the target node is chosen'
406 elif len(self.search_scores[path]) == 1:
407 return base_text + 'There is only one child node, so it is chosen'
408 else:
409 return base_text + 'Node ' + \
410 str(min(self.search_scores[path], key=self.search_scores[path].get)) + \
411 ' has the lowest score, so it is chosen'
412
413 def cal_dis(self,path):
414 dis = 0
415 if len(path) > 1:
416 for i in range(len(path)-1):
417 dis += self.weighted_edges_dic[frozenset([path[i],path[i+1]])]
418 return dis
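`cal_dis` sums the edge weights along a path, with each edge keyed by a `frozenset` of its endpoints so the lookup works in either direction. A standalone sketch, with an assumed weight table:

```python
# Sketch of the path-distance computation in cal_dis: edges are keyed by
# frozenset pairs, so {A,B} and {B,A} resolve to the same weight.
# The weight table below is an illustrative assumption.

def path_distance(edge_w, path):
    return sum(edge_w[frozenset((path[i], path[i + 1]))]
               for i in range(len(path) - 1))

edge_w = {frozenset("AB"): 1, frozenset("BD"): 5,
          frozenset("AC"): 4, frozenset("CD"): 1}
print(path_distance(edge_w, "ABD"))  # 1 + 5 = 6
print(path_distance(edge_w, "ACD"))  # 4 + 1 = 5
```

A single-node path has no edges and therefore distance 0, matching `cal_dis` returning its initial `dis = 0`.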
419