{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import networkx as nx"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"def bfs_search(G, max_depth, start_node, target_node):\n",
"    \"\"\"Breadth-first enumeration of simple paths starting at start_node.\n",
"\n",
"    A path is the string of its node names (e.g. \"AC\"), so node names are\n",
"    assumed to be single characters.\n",
"\n",
"    :param G: graph where G[u] iterates u's neighbours and G[u][v]['weight']\n",
"              is the edge weight (e.g. a networkx Graph)\n",
"    :param max_depth: maximum number of nodes in an explored path\n",
"    :param start_node: node the search starts from\n",
"    :param target_node: node being searched for; paths reaching it are not\n",
"                        expanded further\n",
"    :return: (bfs_path, bfs_correct_path), lists of (path, distance) tuples:\n",
"             every explored path in visiting order, and the subset of those\n",
"             paths that end at target_node\n",
"    \"\"\"\n",
"    from collections import deque\n",
"\n",
"    # Frontier of (path, accumulated distance); deque gives O(1) pops from the left.\n",
"    to_search = deque([(start_node, 0)])\n",
"    bfs_path = []\n",
"    bfs_correct_path = []\n",
"    while to_search:\n",
"        this_path, this_path_dis = to_search.popleft()\n",
"        # Paths leave the queue in non-decreasing length, so the first path that\n",
"        # is too long means every remaining one is too long as well: stop.\n",
"        if len(this_path) > max_depth:\n",
"            break\n",
"        bfs_path.append((this_path, this_path_dis))\n",
"        last_node = this_path[-1]\n",
"        if last_node == target_node:\n",
"            # A correct path: record it and do not expand it any further.\n",
"            bfs_correct_path.append((this_path, this_path_dis))\n",
"            continue\n",
"        for ne in sorted(G[last_node]):\n",
"            # Skip nodes already on the path so that no cycle is formed.\n",
"            if ne not in this_path:\n",
"                to_search.append((this_path + ne,\n",
"                                  this_path_dis + G[last_node][ne]['weight']))\n",
"    return bfs_path, bfs_correct_path"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Node names of the example graph\n",
"node_list = list('ABCDEFG')\n",
"\n",
"# Undirected edges as (u, v, weight) triples\n",
"weighted_edges_list = [\n",
"    ('A', 'B', 8), ('A', 'C', 20),\n",
"    ('B', 'F', 40), ('B', 'E', 30),\n",
"    ('B', 'D', 20), ('C', 'D', 10),\n",
"    ('D', 'G', 10), ('D', 'E', 10),\n",
"    ('E', 'F', 30), ('F', 'G', 30),\n",
"]\n",
"\n",
"# Drawing coordinates (x, y) of every node\n",
"nodes_pos = dict(A=(1, 1), B=(3, 3), C=(5, 0), D=(9, 2),\n",
"                 E=(7, 4), F=(6, 6), G=(11, 5))\n",
"\n",
"# Weighted, undirected networkx graph built from the lists above\n",
"G = nx.Graph()\n",
"G.add_nodes_from(node_list)\n",
"G.add_weighted_edges_from(weighted_edges_list)\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"bfs_path, bfs_correct_path = bfs_search(G, max_depth=3, start_node='A', target_node='G')"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"paths = [path for path, _distance in bfs_path]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"def get_search_tree_node_position(paths, max_depth=3):\n",
"    \"\"\"Compute the drawing coordinates of every node of a search tree.\n",
"\n",
"    Each path is a string of node names (e.g. \"ABD\"); its parent is the same\n",
"    string without its last node. Every child receives an equal share of its\n",
"    parent's horizontal width, and each level sits 1 / max_depth lower.\n",
"\n",
"    :param paths: explored paths; paths[0] must be the root (start node) path\n",
"    :param max_depth: number of levels used to scale the y coordinate\n",
"                      (default 3, the value that used to be hard-coded)\n",
"    :return: dict path -> (x, y, width); the dict is also printed\n",
"    \"\"\"\n",
"    if not paths:\n",
"        print({})\n",
"        return {}\n",
"    # Set for O(1) parent look-ups instead of scanning the list every time.\n",
"    known_paths = set(paths)\n",
"    # Group every path under its parent, when the parent was explored too.\n",
"    path_children = {}\n",
"    for path in paths:\n",
"        father = path[:-1]\n",
"        if father in known_paths:\n",
"            path_children.setdefault(father, []).append(path)\n",
"    # Sorting places every parent before its children (a prefix sorts first),\n",
"    # so a parent's position is known by the time its children are placed.\n",
"    o_path_children = collections.OrderedDict(sorted(path_children.items()))\n",
"    tree_node_position = {paths[0][0]: (1, 0, 2)}\n",
"    for path, sub_paths in o_path_children.items():\n",
"        y_pos = -1.0 / max_depth * len(path)\n",
"        parent_x, _, parent_width = tree_node_position[path]\n",
"        dx = parent_width / len(sub_paths)\n",
"        sub_paths.sort()\n",
"        for index, e_s in enumerate(sub_paths):\n",
"            x_pos = parent_x - parent_width / 2 + dx / 2 + dx * index\n",
"            tree_node_position[e_s] = (x_pos, y_pos, dx)\n",
"    print(tree_node_position)\n",
"    return tree_node_position"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'A': (1, 0, 2), 'ACD': (1.5, -0.6666666666666666, 1.0), 'AC': (1.5, -0.3333333333333333, 1.0), 'ABD': (0.16666666666666666, -0.6666666666666666, 0.3333333333333333), 'AB': (0.5, -0.3333333333333333, 1.0), 'ABF': (0.8333333333333333, -0.6666666666666666, 0.3333333333333333), 'ABE': (0.5, -0.6666666666666666, 0.3333333333333333)}\n"
]
}
],
"source": [
"get_search_tree_node_position(paths=paths);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import collections\n",
"from IPython import display\n",
"import networkx as nx\n",
"import numpy as np\n",
"import time\n",
"\n",
"\n",
"class SearchGraph():\n",
"    \"\"\"Search a small weighted, undirected graph and animate the search.\n",
"\n",
"    Node names are single characters and a path is the string of its node\n",
"    names (e.g. \"ABD\"). DFS and BFS enumerate simple paths of at most\n",
"    max_depth edges from start_node; the greedy / A*-style searches follow\n",
"    the best-scoring neighbour. Searches are drawn with matplotlib either as\n",
"    a search tree (show_search_tree) or on the graph itself (show_graph).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self,\n",
"                 node_list,\n",
"                 weighted_edges_list,\n",
"                 start_node,\n",
"                 target_node,\n",
"                 max_depth=1000,\n",
"                 nodes_pos=None,\n",
"                 help_info=None,):\n",
"        \"\"\"Build the graph and run the DFS and BFS searches up front.\n",
"\n",
"        :param node_list: names of the graph's nodes (single characters)\n",
"        :param weighted_edges_list: (u, v, weight) triples of undirected edges\n",
"        :param start_node: node every search starts from\n",
"        :param target_node: node the searches look for\n",
"        :param max_depth: maximum number of edges in a searched path; capped at\n",
"                          len(node_list)\n",
"        :param nodes_pos: {node: (x, y)} positions used by show_graph\n",
"        :param help_info: optional {node: value} heuristic used by the greedy\n",
"                          and A*-style searches\n",
"        \"\"\"\n",
"        self.node_list = node_list\n",
"        self.weighted_edges_list = weighted_edges_list\n",
"        self.start_node = start_node\n",
"        self.target_node = target_node\n",
"        self.nodes_pos = nodes_pos\n",
"        self.max_depth = min(max_depth, len(node_list))\n",
"        self.temp_best_path = None\n",
"\n",
"        # Edge weight lookup that ignores edge direction.\n",
"        self.weighted_edges_dic = {frozenset([e[0], e[1]]): e[2]\n",
"                                   for e in weighted_edges_list}\n",
"        self.help_info = help_info\n",
"        # Distance of every path visited by the DFS, keyed by path string.\n",
"        self.path_score = {self.start_node: 0}\n",
"\n",
"        self.animation_type = 'dfs'\n",
"\n",
"        self.basic_node_color = '#6CB6FF'\n",
"        self.start_node_color = 'y'\n",
"        self.target_node_color = 'r'\n",
"        self.visited_node_color = 'g'\n",
"\n",
"        self.basic_edge_color = 'b'\n",
"        self.visited_edge_color = 'g'\n",
"\n",
"        self.success_color = 'r'\n",
"\n",
"        # {path: distance} of every searched path that ends at target_node.\n",
"        self.correct_paths = {}\n",
"        self.show_correct_path = []\n",
"        self.build_graph()\n",
"        # Also runs the DFS (fills dfs_path, path_score and correct_paths).\n",
"        self.get_search_tree_node_position()\n",
"        self.bfs_search()\n",
"\n",
"    def build_graph(self):\n",
"        \"\"\"Create self.G, the weighted undirected networkx graph.\"\"\"\n",
"        self.G = nx.Graph()\n",
"        self.G.add_nodes_from(self.node_list)\n",
"        self.G.add_weighted_edges_from(self.weighted_edges_list)\n",
"\n",
"    def get_search_tree_node_position(self):\n",
"        \"\"\"Run the DFS and lay out its search tree for drawing.\n",
"\n",
"        Sets self.tree_node_position to {path: (x, y, width)}: every child gets\n",
"        an equal share of its parent's horizontal width, and each level sits\n",
"        1 / max_depth lower than the previous one.\n",
"        \"\"\"\n",
"        self.dfs_search()\n",
"        paths = self.dfs_path\n",
"        known_paths = set(paths)\n",
"        # Group every explored path under its parent (the path minus its last node).\n",
"        path_children = {}\n",
"        for path in paths:\n",
"            father = path[:-1]\n",
"            if father in known_paths:\n",
"                path_children.setdefault(father, []).append(path)\n",
"        # Sorting places every parent before its children (a prefix sorts first),\n",
"        # so a parent's position is known by the time its children are placed.\n",
"        o_path_children = collections.OrderedDict(sorted(path_children.items()))\n",
"        tree_node_position = {self.start_node: (1, 0, 2)}\n",
"        for path, sub_paths in o_path_children.items():\n",
"            y_pos = -1.0 / self.max_depth * len(path)\n",
"            parent_x, _, parent_width = tree_node_position[path]\n",
"            dx = parent_width / len(sub_paths)\n",
"            sub_paths.sort()\n",
"            for index, e_s in enumerate(sub_paths):\n",
"                x_pos = parent_x - parent_width / 2 + dx / 2 + dx * index\n",
"                tree_node_position[e_s] = (x_pos, y_pos, dx)\n",
"        self.tree_node_position = tree_node_position\n",
"\n",
"    def show_edge_labels(self, ax, pos1, pos2, label):\n",
"        \"\"\"Write ``label`` at the midpoint of pos1-pos2, rotated along the segment.\"\"\"\n",
"        (x1, y1) = pos1\n",
"        (x2, y2) = pos2\n",
"        (x, y) = (x1 * 0.5 + x2 * 0.5, y1 * 0.5 + y2 * 0.5)\n",
"\n",
"        angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360\n",
"        # Keep the text upright: fold the angle into [-90, 90].\n",
"        if angle > 90:\n",
"            angle -= 180\n",
"        if angle < - 90:\n",
"            angle += 180\n",
"        xy = np.array((x, y))\n",
"        # Convert the data-space angle into the on-screen text rotation.\n",
"        trans_angle = ax.transData.transform_angles(np.array((angle,)),\n",
"                                                    xy.reshape((1, 2)))[0]\n",
"        bbox = dict(boxstyle='round',\n",
"                    ec=(1.0, 1.0, 1.0),\n",
"                    fc=(1.0, 1.0, 1.0),\n",
"                    )\n",
"        label = str(label)\n",
"        ax.text(x, y,\n",
"                label,\n",
"                size=16,\n",
"                color='k',\n",
"                alpha=1,\n",
"                horizontalalignment='center',\n",
"                verticalalignment='center',\n",
"                rotation=trans_angle,\n",
"                transform=ax.transData,\n",
"                bbox=bbox,\n",
"                zorder=1,\n",
"                clip_on=True,\n",
"                )\n",
"\n",
"    def show_search_tree(self,\n",
"                         this_path=None,\n",
"                         show_success_color=False,\n",
"                         best_path=None\n",
"                         ):\n",
"        \"\"\"Draw the DFS search tree, highlighting ``this_path``.\n",
"\n",
"        :param this_path: path currently being visited; its tree nodes and\n",
"                          edges are highlighted\n",
"        :param show_success_color: highlight with success_color instead of the\n",
"                                   visited colours\n",
"        :param best_path: when truthy, report ``this_path`` as the final\n",
"                          shortest path\n",
"        \"\"\"\n",
"        fig, ax = plt.subplots()\n",
"        fig.set_figwidth(15)\n",
"        fig.set_figheight(self.max_depth * 1.5)\n",
"        plt.axis('off')\n",
"\n",
"        for path, pos in self.tree_node_position.items():\n",
"            if path[-1] == self.start_node:\n",
"                node_color = self.start_node_color\n",
"                edge_color = self.basic_edge_color\n",
"            elif this_path and path in this_path:\n",
"                # Tree paths all start at start_node and never repeat a node,\n",
"                # so a substring match means ``path`` is a prefix of this_path.\n",
"                if show_success_color:\n",
"                    node_color = self.success_color\n",
"                    edge_color = self.success_color\n",
"                else:\n",
"                    node_color = self.visited_node_color\n",
"                    edge_color = self.visited_edge_color\n",
"            elif path[-1] == self.target_node:\n",
"                node_color = self.target_node_color\n",
"                edge_color = self.basic_edge_color\n",
"            else:\n",
"                node_color = self.basic_node_color\n",
"                edge_color = self.basic_edge_color\n",
"            ax.scatter(pos[0], pos[1], c=node_color, s=1000, zorder=1)\n",
"            plt.annotate(\n",
"                path[-1],\n",
"                xy=(pos[0], pos[1]),\n",
"                xytext=(0, 0),\n",
"                textcoords='offset points',\n",
"                ha='center',\n",
"                va='center',\n",
"                size=15,)\n",
"            if len(path) > 1:\n",
"                father_pos = self.tree_node_position[path[:-1]]\n",
"                # Edge from the parent tree node to this one.\n",
"                plt.plot([father_pos[0], pos[0]],\n",
"                         [father_pos[1], pos[1]],\n",
"                         color=edge_color,\n",
"                         zorder=0)\n",
"                label = self.weighted_edges_dic[frozenset([path[-2], path[-1]])]\n",
"                # Greedy / A* edges show the score greedy_search ranks children\n",
"                # by; without help_info that score is the edge weight itself.\n",
"                if self.animation_type in ['greedy', 'a_star'] and self.help_info:\n",
"                    label = (self.help_info_weight * self.help_info[path[-1]]\n",
"                             + self.origin_info_weight * label)\n",
"                self.show_edge_labels(ax, father_pos[0:2], pos[0:2], label)\n",
"        display.clear_output(wait=True)\n",
"\n",
"        show_res_text = ''\n",
"        for e_c in self.show_correct_path:\n",
"            show_res_text += '找到一条路径: %-7s' % e_c + '。距离为:' + str(self.correct_paths[e_c]) + '\\n'\n",
"        plt.text(0, -1.1, show_res_text, fontsize=18, horizontalalignment='left', verticalalignment='top',)\n",
"\n",
"        if best_path:\n",
"            top_text = '最终最短路径为: %-7s' % this_path + '。距离为:' + str(self.correct_paths[this_path]) + '\\n'\n",
"        elif this_path and self.animation_type in ['dfs', 'bfs']:\n",
"            top_text = '当前路径: %-7s' % this_path + '。距离为:' + str(self.path_score[this_path]) + '\\n'\n",
"            if self.temp_best_path:\n",
"                top_text += ('当前最短路径为: %-7s' % self.temp_best_path + '。距离为:'\n",
"                             + str(self.correct_paths[self.temp_best_path]) + '\\n')\n",
"        else:\n",
"            top_text = ''\n",
"\n",
"        plt.text(0, 0,\n",
"                 top_text,\n",
"                 fontsize=18,\n",
"                 horizontalalignment='left',\n",
"                 verticalalignment='top',)\n",
"\n",
"        if self.animation_type in ['greedy', 'a_star']:\n",
"            show_greedy_text = self.generate_greedy_help_text(this_path)\n",
"            plt.text(0, 0, show_greedy_text, fontsize=18, horizontalalignment='left', verticalalignment='top',)\n",
"        plt.show()\n",
"\n",
"    def _get_animation_paths(self, search_method, help_info_weight, origin_info_weight):\n",
"        \"\"\"Return the paths to animate for ``search_method`` ([] if unknown).\n",
"\n",
"        'greedy' runs greedy_search with its defaults; 'a_star' runs\n",
"        a_star_search with the given weights.\n",
"        \"\"\"\n",
"        if search_method == 'bfs':\n",
"            return self.bfs_path\n",
"        if search_method == 'dfs':\n",
"            return self.dfs_path\n",
"        if search_method == 'greedy':\n",
"            self.greedy_search()\n",
"            return self.greedy_search_path\n",
"        if search_method == 'a_star':\n",
"            self.a_star_search(help_info_weight=help_info_weight, origin_info_weight=origin_info_weight)\n",
"            return self.greedy_search_path\n",
"        return []\n",
"\n",
"    def animation_search_tree(self, search_method='dfs', help_info_weight=1, origin_info_weight=1):\n",
"        \"\"\"Animate a search on the search-tree drawing, one frame per visited path.\n",
"\n",
"        :param search_method: 'bfs', 'dfs', 'greedy' or 'a_star'; anything\n",
"                              else animates nothing\n",
"        :param help_info_weight: weight of help_info in the score ('a_star' only)\n",
"        :param origin_info_weight: weight of the edge weight in the score\n",
"                                   ('a_star' only)\n",
"        \"\"\"\n",
"        self.animation_type = search_method\n",
"        self.show_correct_path = []\n",
"        self.temp_best_path = None\n",
"        paths = self._get_animation_paths(search_method, help_info_weight, origin_info_weight)\n",
"        for e_path in paths:\n",
"            self.show_search_tree(e_path)\n",
"            if e_path in self.correct_paths:\n",
"                # Track the shortest correct path seen so far.\n",
"                if (not self.temp_best_path\n",
"                        or self.path_score[e_path] < self.path_score[self.temp_best_path]):\n",
"                    self.temp_best_path = e_path\n",
"                self.show_correct_path.append(e_path)\n",
"                self.show_search_tree(e_path, True)\n",
"                if search_method in ['greedy', 'a_star']:\n",
"                    time.sleep(5)\n",
"        # No correct path means there is no best path to show.\n",
"        if search_method in ['bfs', 'dfs'] and self.correct_paths:\n",
"            best_path = min(self.correct_paths, key=self.correct_paths.get)\n",
"            self.show_search_tree(best_path, True, True)\n",
"\n",
"    def animation_graph(self, search_method='bfs', help_info_weight=1, origin_info_weight=1):\n",
"        \"\"\"Animate a search on the drawing of the graph itself.\n",
"\n",
"        :param search_method: 'bfs', 'dfs', 'greedy' or 'a_star'; anything\n",
"                              else animates nothing\n",
"        :param help_info_weight: weight of help_info in the score ('a_star' only)\n",
"        :param origin_info_weight: weight of the edge weight in the score\n",
"                                   ('a_star' only)\n",
"        \"\"\"\n",
"        self.animation_type = search_method\n",
"        self.show_correct_path = []\n",
"        paths = self._get_animation_paths(search_method, help_info_weight, origin_info_weight)\n",
"        for e_path in paths:\n",
"            self.show_graph(e_path)\n",
"            if e_path in self.correct_paths:\n",
"                self.show_correct_path.append(e_path)\n",
"                self.show_graph(e_path, True)\n",
"                time.sleep(5)\n",
"        # Guard: min() of an empty dict would raise when no path reached the target.\n",
"        if search_method in ['bfs', 'dfs'] and self.correct_paths:\n",
"            best_path = min(self.correct_paths, key=self.correct_paths.get)\n",
"            self.show_graph(best_path, True, True)\n",
"\n",
"    def show_graph(self, this_path='',\n",
"                   show_success_color=False,\n",
"                   best_path=None):\n",
"        \"\"\"Draw the graph, highlighting the nodes and edges of ``this_path``.\n",
"\n",
"        :param this_path: path to highlight; defaults to the start node alone\n",
"        :param show_success_color: highlight with success_color instead of the\n",
"                                   visited colours\n",
"        :param best_path: accepted for symmetry with show_search_tree; unused\n",
"        \"\"\"\n",
"        fig, _ = plt.subplots()\n",
"        fig.set_figwidth(6)\n",
"        fig.set_figheight(8)\n",
"        plt.axis('off')\n",
"\n",
"        if not this_path:\n",
"            this_path = self.start_node\n",
"        path_node_list = list(this_path)\n",
"        # Undirected edges travelled by the path.\n",
"        visited_edges = [frozenset([path_node_list[i], path_node_list[i - 1]])\n",
"                         for i in range(1, len(path_node_list))]\n",
"\n",
"        nlabels = dict(zip(self.node_list, self.node_list))\n",
"        edge_labels = {(u, v): d['weight'] for u, v, d in self.G.edges(data=True)}\n",
"\n",
"        # Node colours: target first, then path nodes (overriding the target),\n",
"        # then the start node; every other node gets the basic colour.\n",
"        val_map = {self.target_node: self.target_node_color}\n",
"        path_color = self.success_color if show_success_color else self.visited_node_color\n",
"        if path_node_list:\n",
"            for i in path_node_list:\n",
"                val_map[i] = path_color\n",
"            val_map[self.start_node] = self.start_node_color\n",
"        values = [val_map.get(node, self.basic_node_color) for node in self.G.nodes()]\n",
"\n",
"        # Edge colours: travelled edges use the visited / success colour.\n",
"        travelled_color = self.success_color if show_success_color else self.visited_edge_color\n",
"        edge_colors = [travelled_color if frozenset(edge) in visited_edges else self.basic_edge_color\n",
"                       for edge in self.G.edges()]\n",
"\n",
"        # draw_networkx_nodes has no ``width`` parameter (older networkx ignored\n",
"        # it, newer versions raise TypeError), so it is not passed here.\n",
"        nx.draw_networkx_nodes(self.G, self.nodes_pos, node_size=800, node_color=values)\n",
"        nx.draw_networkx_labels(self.G, self.nodes_pos, nlabels, font_size=20)\n",
"        nx.draw_networkx_edges(self.G, self.nodes_pos, edge_color=edge_colors, width=2.0, alpha=1.0)\n",
"        nx.draw_networkx_edge_labels(self.G, self.nodes_pos, edge_labels=edge_labels, font_size=18)\n",
"\n",
"        display.clear_output(wait=True)\n",
"        plt.show()\n",
"\n",
"    def _dfs_helper(self, G, node, father, target_node, level, res, path):\n",
"        \"\"\"Recursive step of the depth-first search.\n",
"\n",
"        :param G: graph being searched\n",
"        :param node: node being visited\n",
"        :param father: node visited just before ``node`` (not used)\n",
"        :param target_node: node at which a branch stops\n",
"        :param level: number of edges in the path leading to ``node``\n",
"        :param res: list collecting every visited path in visiting order\n",
"        :param path: path string leading up to, but not including, ``node``\n",
"        \"\"\"\n",
"        path += str(node)\n",
"        if len(path) > 1:\n",
"            # Distance of the extended path = parent's distance + last edge.\n",
"            self.path_score[path] = (self.path_score[path[:-1]]\n",
"                                     + self.weighted_edges_dic[frozenset([path[-2], path[-1]])])\n",
"        res.append(path)\n",
"        # Target reached: do not search below it.\n",
"        if node == target_node:\n",
"            return\n",
"        if level < self.max_depth:\n",
"            for neighbor in sorted(G[node]):\n",
"                # Only extend to nodes not already on the path (no cycles).\n",
"                if str(neighbor) not in path:\n",
"                    self._dfs_helper(G, neighbor, node, target_node, level + 1, res, path)\n",
"\n",
"    def dfs_search(self):\n",
"        \"\"\"Depth-first enumeration of simple paths from the start node.\n",
"\n",
"        Sets self.dfs_path to every visited path (at most max_depth edges) in\n",
"        visiting order, fills self.path_score, and records paths ending at the\n",
"        target node, with their distance, in self.correct_paths.\n",
"        \"\"\"\n",
"        dfs_path = []\n",
"        if self.start_node:\n",
"            self._dfs_helper(self.G, self.start_node, None, self.target_node, 0, dfs_path, '')\n",
"        self.dfs_path = dfs_path\n",
"        for p in dfs_path:\n",
"            if p[-1] == self.target_node and p not in self.correct_paths:\n",
"                self.correct_paths[p] = self.cal_dis(p)\n",
"\n",
"    def bfs_search(self):\n",
"        \"\"\"Breadth-first enumeration of simple paths from the start node.\n",
"\n",
"        Sets self.bfs_path to every visited path (at most max_depth edges) in\n",
"        visiting order, and records paths ending at the target node, with\n",
"        their distance, in self.correct_paths.\n",
"        \"\"\"\n",
"        # deque gives O(1) pops from the left (list.pop(0) is O(n)).\n",
"        to_search = collections.deque([self.start_node])\n",
"        bfs_path = []\n",
"        while to_search:\n",
"            this_search = to_search.popleft()\n",
"            # Paths leave the queue in non-decreasing length, so the first\n",
"            # over-long path means every remaining one is too long as well.\n",
"            if len(this_search) > self.max_depth + 1:\n",
"                break\n",
"            bfs_path.append(this_search)\n",
"            # Do not expand paths that already reach the target.\n",
"            if this_search[-1] == self.target_node:\n",
"                continue\n",
"            for ne in sorted(self.G[this_search[-1]]):\n",
"                if ne not in this_search:\n",
"                    to_search.append(this_search + ne)\n",
"        self.bfs_path = bfs_path\n",
"        for p in bfs_path:\n",
"            if p[-1] == self.target_node and p not in self.correct_paths:\n",
"                self.correct_paths[p] = self.cal_dis(p)\n",
"\n",
"    def greedy_search(self, help_info_weight=1, origin_info_weight=0):\n",
"        \"\"\"Repeatedly move to the lowest-scoring unvisited neighbour.\n",
"\n",
"        The score of neighbour n of the current node c is\n",
"        help_info_weight * help_info[n] + origin_info_weight * weight(c, n)\n",
"        when help_info is set, otherwise just weight(c, n). Ties go to the\n",
"        alphabetically first neighbour. The search stops at the target node,\n",
"        at a dead end, or once the path has max_depth + 1 nodes.\n",
"\n",
"        Sets self.greedy_search_path to every prefix of the chosen path and\n",
"        self.search_scores to {path: {candidate: score}} for each step.\n",
"        \"\"\"\n",
"        self.help_info_weight = help_info_weight\n",
"        self.origin_info_weight = origin_info_weight\n",
"        search_path = self.start_node\n",
"        # Candidates and scores of every step, shown during the animation.\n",
"        search_scores = {}\n",
"        while len(search_path) <= self.max_depth:\n",
"            this_node = search_path[-1]\n",
"            neighbour_nodes = [e_n for e_n in sorted(self.G[this_node]) if e_n not in search_path]\n",
"            if not neighbour_nodes:\n",
"                # Dead end: every neighbour is already on the path.\n",
"                search_scores[search_path] = {}\n",
"                break\n",
"            if self.help_info:\n",
"                scores = {e_n: help_info_weight * self.help_info[e_n]\n",
"                          + origin_info_weight * self.weighted_edges_dic[frozenset([this_node, e_n])]\n",
"                          for e_n in neighbour_nodes}\n",
"            else:\n",
"                scores = {e_n: self.weighted_edges_dic[frozenset([this_node, e_n])]\n",
"                          for e_n in neighbour_nodes}\n",
"            search_scores[search_path] = scores\n",
"            # Iterate the sorted list, not the dict, so ties break deterministically.\n",
"            nearest_node = min(neighbour_nodes, key=scores.get)\n",
"            search_path += nearest_node\n",
"            if nearest_node == self.target_node:\n",
"                break\n",
"        self.greedy_search_path = [search_path[:index + 1] for index in range(len(search_path))]\n",
"        self.search_scores = search_scores\n",
"\n",
"    def a_star_search(self, help_info_weight=1, origin_info_weight=1):\n",
"        \"\"\"Greedy search whose score also includes the edge weight.\n",
"\n",
"        Note: each step scores only the next edge, not the accumulated path\n",
"        cost used by classic A*.\n",
"        \"\"\"\n",
"        self.greedy_search(help_info_weight, origin_info_weight)\n",
"\n",
"    def generate_greedy_help_text(self, path):\n",
"        \"\"\"Explain, for the animation, why the greedy search moved on from ``path``.\n",
"\n",
"        Reads self.search_scores, so greedy_search must have run first.\n",
"        \"\"\"\n",
"        if path[-1] == self.target_node:\n",
"            return '抵达目标节点' + str(self.target_node)\n",
"        elif path not in self.search_scores:\n",
"            return '抵达最大搜索深度,未找到目标节点'\n",
"\n",
"        scores = self.search_scores[path]\n",
"        if not scores:\n",
"            # Dead end: there was no candidate to choose from.\n",
"            return '没有可选的子节点,未找到目标节点'\n",
"        base_text = '当前可选的子节点及其信息值为 \\n' + str(scores) + '\\n'\n",
"        # Same choice and tie-break (sorted order) as greedy_search.\n",
"        chosen = min(sorted(scores), key=scores.get)\n",
"        if chosen == self.target_node:\n",
"            return base_text + '当前可选的子节点包含了目标节点,\\n所以选择目标节点'\n",
"        elif len(scores) == 1:\n",
"            return base_text + '因为只有一个子节点,所以选择此节点'\n",
"        else:\n",
"            return base_text + '因为' + str(chosen) + '的值最小,所以选择此节点'\n",
"\n",
"    def cal_dis(self, path):\n",
"        \"\"\"Total edge weight along ``path`` (0 for a path of fewer than two nodes).\"\"\"\n",
"        return sum(self.weighted_edges_dic[frozenset([a, b])] for a, b in zip(path, path[1:]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}