{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import networkx as nx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bfs_search(G, max_depth, start_node, target_node):\n",
    "    \"\"\"Enumerate simple paths from start_node in breadth-first order.\n",
    "\n",
    "    A path is a string made by concatenating node names, and its last node\n",
    "    is read back as ``path[-1]``, so node names are assumed to be\n",
    "    single-character strings.\n",
    "\n",
    "    Args:\n",
    "        G: graph where ``G[u]`` iterates the neighbours of ``u`` and\n",
    "            ``G[u][v]['weight']`` is the weight of the edge ``u-v``.\n",
    "        max_depth: largest number of nodes a visited path may contain.\n",
    "        start_node: node every path starts from.\n",
    "        target_node: node that completes a path; it is never expanded.\n",
    "\n",
    "    Returns:\n",
    "        A pair ``(bfs_path, bfs_correct_path)`` of lists of\n",
    "        ``(path, distance)`` tuples: every visited path in visiting order,\n",
    "        and the visited paths that end at target_node.\n",
    "    \"\"\"\n",
    "    # Imported here so this cell does not depend on imports made in later cells\n",
    "    from collections import deque\n",
    "\n",
    "    # FIFO queue of (path, distance); deque.popleft() is O(1) whereas\n",
    "    # list.pop(0) shifts every remaining element\n",
    "    to_search = deque([(start_node, 0)])\n",
    "    # Every visited path with its distance\n",
    "    bfs_path = []\n",
    "    # Visited paths that end at the target, with their distance\n",
    "    bfs_correct_path = []\n",
    "    while to_search:\n",
    "        this_path, this_path_dis = to_search.popleft()\n",
    "        # The queue is ordered by path length, so once one path is too long\n",
    "        # every remaining one is too\n",
    "        if len(this_path) > max_depth:\n",
    "            break\n",
    "        bfs_path.append((this_path, this_path_dis))\n",
    "        last_node = this_path[-1]\n",
    "        if last_node == target_node:\n",
    "            # A complete path: record it and do not explore beyond the target\n",
    "            bfs_correct_path.append((this_path, this_path_dis))\n",
    "            continue\n",
    "        neighbours = G[last_node]\n",
    "        for ne in sorted(neighbours):\n",
    "            # Skip nodes already on the path so that no cycle is created\n",
    "            if ne not in this_path:\n",
    "                to_search.append((this_path + ne,\n",
    "                                  this_path_dis + neighbours[ne]['weight']))\n",
    "    return bfs_path, bfs_correct_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Node names of the example graph\n",
    "node_list = list('ABCDEFG')\n",
    "\n",
    "# Undirected edges as (node, node, weight)\n",
    "weighted_edges_list = [\n",
    "    ('A', 'B', 8), ('A', 'C', 20),\n",
    "    ('B', 'F', 40), ('B', 'E', 30),\n",
    "    ('B', 'D', 20), ('C', 'D', 10),\n",
    "    ('D', 'G', 10), ('D', 'E', 10),\n",
    "    ('E', 'F', 30), ('F', 'G', 30),\n",
    "]\n",
    "\n",
    "# Drawing coordinates of each node\n",
    "nodes_pos = {\n",
    "    'A': (1, 1), 'B': (3, 3), 'C': (5, 0), 'D': (9, 2),\n",
    "    'E': (7, 4), 'F': (6, 6), 'G': (11, 5),\n",
    "}\n",
    "\n",
    "# Build the undirected weighted graph\n",
    "G = nx.Graph()\n",
    "G.add_nodes_from(node_list)\n",
    "G.add_weighted_edges_from(weighted_edges_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Breadth-first search from 'A' towards 'G', visiting paths of at most 3 nodes\n",
    "bfs_path, bfs_correct_path = bfs_search(G, 3, 'A', 'G')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the path strings, dropping the distance stored next to each one\n",
    "paths = [e[0] for e in bfs_path]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "def get_search_tree_node_position(paths, max_depth=3):\n",
    "    \"\"\"Compute drawing coordinates for every path of a search tree.\n",
    "\n",
    "    Each path is a string of node names; the parent of a path is the path\n",
    "    without its last character.\n",
    "\n",
    "    Args:\n",
    "        paths: sequence of path strings. The first character of the first\n",
    "            path is used as the root of the tree.\n",
    "        max_depth: vertical scale. The children of a parent path with ``k``\n",
    "            nodes are placed at ``y = -k / max_depth``. Defaults to 3, the\n",
    "            value that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        dict mapping a path to ``(x, y, width)``, where ``width`` is the\n",
    "        horizontal span divided among the children of that path. An empty\n",
    "        dict is returned when ``paths`` is empty.\n",
    "    \"\"\"\n",
    "    if not paths:\n",
    "        return {}\n",
    "    # Set membership keeps the parent lookup O(1) per path\n",
    "    known_paths = set(paths)\n",
    "    # Group every path under its parent; paths whose parent is not listed\n",
    "    # (such as the root) are not attached to anything\n",
    "    children_of = {}\n",
    "    for path in paths:\n",
    "        parent = path[:-1]\n",
    "        if parent in known_paths:\n",
    "            children_of.setdefault(parent, []).append(path)\n",
    "    # The root is centred on x = 1 and owns a span of width 2\n",
    "    tree_node_position = {paths[0][0]: (1, 0, 2)}\n",
    "    # In sorted order a path comes before the paths that extend it, so a\n",
    "    # parent is always positioned before its own children are laid out\n",
    "    for parent, children in sorted(children_of.items()):\n",
    "        y_pos = -1.0 / max_depth * len(parent)\n",
    "        parent_x, _, parent_width = tree_node_position[parent]\n",
    "        dx = parent_width / len(children)\n",
    "        for index, child in enumerate(sorted(children)):\n",
    "            # Children split the parent's span into equal slots, left to right\n",
    "            x_pos = parent_x - parent_width / 2 + dx / 2 + dx * index\n",
    "            tree_node_position[child] = (x_pos, y_pos, dx)\n",
    "    return tree_node_position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'A': (1, 0, 2), 'ACD': (1.5, -0.6666666666666666, 1.0), 'AC': (1.5, -0.3333333333333333, 1.0), 'ABD': (0.16666666666666666, -0.6666666666666666, 0.3333333333333333), 'AB': (0.5, -0.3333333333333333, 1.0), 'ABF': (0.8333333333333333, -0.6666666666666666, 0.3333333333333333), 'ABE': (0.5, -0.6666666666666666, 0.3333333333333333)}\n"
     ]
    }
   ],
   "source": [
    "# Compute the search-tree layout for the paths visited by the BFS above\n",
    "get_search_tree_node_position(paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import collections\n",
    "from IPython import display\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import time\n",
    "\n",
    "\n",
    "class SearchGraph():\n",
    "    def __init__(self, node_list, weighted_edges_list, start_node,\n",
    "                 target_node, max_depth=1000, nodes_pos=None, help_info=None):\n",
    "        \"\"\"Store the problem definition and precompute the search results.\n",
    "\n",
    "        Args:\n",
    "            node_list: names of the graph nodes.\n",
    "            weighted_edges_list: edges whose first three items are\n",
    "                ``(node, node, weight)``.\n",
    "            start_node: node every search starts from.\n",
    "            target_node: node the searches try to reach.\n",
    "            max_depth: requested depth limit; capped at ``len(node_list)``.\n",
    "            nodes_pos: node coordinates passed to the networkx drawing calls.\n",
    "            help_info: optional per-node values used to score candidates in\n",
    "                the greedy / A* searches.\n",
    "        \"\"\"\n",
    "        # Problem definition\n",
    "        self.node_list = node_list\n",
    "        self.weighted_edges_list = weighted_edges_list\n",
    "        self.start_node = start_node\n",
    "        self.target_node = target_node\n",
    "        self.nodes_pos = nodes_pos\n",
    "        self.help_info = help_info\n",
    "        self.max_depth = min(max_depth, len(node_list))\n",
    "\n",
    "        # Edge weight keyed by the unordered pair of its end points\n",
    "        self.weighted_edges_dic = {\n",
    "            frozenset([edge[0], edge[1]]): edge[2]\n",
    "            for edge in weighted_edges_list\n",
    "        }\n",
    "\n",
    "        # Search bookkeeping: accumulated distance per path, paths that reach\n",
    "        # the target, and the state of the running animation\n",
    "        self.path_score = {self.start_node: 0}\n",
    "        self.correct_paths = {}\n",
    "        self.show_correct_path = []\n",
    "        self.temp_best_path = None\n",
    "        self.animation_type = 'dfs'\n",
    "\n",
    "        # Colours used by the drawings\n",
    "        self.basic_node_color = '#6CB6FF'\n",
    "        self.start_node_color = 'y'\n",
    "        self.target_node_color = 'r'\n",
    "        self.visited_node_color = 'g'\n",
    "        self.basic_edge_color = 'b'\n",
    "        self.visited_edge_color = 'g'\n",
    "        self.success_color = 'r'\n",
    "\n",
    "        # Order matters: the graph must exist before the searches run, and\n",
    "        # the tree layout runs the DFS itself before the BFS is computed\n",
    "        self.build_graph()\n",
    "        self.get_search_tree_node_position()\n",
    "        self.bfs_search()\n",
    "        \n",
    "        \n",
    "\n",
    "    def build_graph(self):\n",
    "        \"\"\"Create ``self.G``, an undirected networkx graph of the nodes and weighted edges.\"\"\"\n",
    "        graph = nx.Graph()\n",
    "        graph.add_nodes_from(self.node_list)\n",
    "        graph.add_weighted_edges_from(self.weighted_edges_list)\n",
    "        self.G = graph\n",
    "        \n",
    "    def get_search_tree_node_position(self):\n",
    "        \"\"\"Run the depth-first search and lay its paths out as a tree.\n",
    "\n",
    "        Stores the result in ``self.tree_node_position``: a dict mapping each\n",
    "        path string to ``(x, y, width)``, where ``width`` is the horizontal\n",
    "        span divided among the children of that path.\n",
    "        \"\"\"\n",
    "        self.dfs_search()\n",
    "        paths = self.dfs_path\n",
    "        # Group every path under its parent (the path minus its last node)\n",
    "        children_of = {}\n",
    "        for path in paths:\n",
    "            parent = path[:-1]\n",
    "            if parent not in paths:\n",
    "                continue\n",
    "            children_of.setdefault(parent, []).append(path)\n",
    "        # The root is centred on x = 1 and owns a span of width 2\n",
    "        layout = {self.start_node: (1, 0, 2)}\n",
    "        # Sorted order places a path before the paths that extend it\n",
    "        for parent in sorted(children_of):\n",
    "            children = sorted(children_of[parent])\n",
    "            y_pos = -1.0/self.max_depth * len(parent)\n",
    "            parent_x, _, parent_width = layout[parent]\n",
    "            dx = parent_width/len(children)\n",
    "            for index, child in enumerate(children):\n",
    "                x_pos = parent_x - parent_width/2 + dx/2 + dx*index\n",
    "                layout[child] = (x_pos, y_pos, dx)\n",
    "        self.tree_node_position = layout\n",
    "        \n",
    "    def show_edge_labels(self, ax, pos1, pos2, label):\n",
    "        \"\"\"Write an edge label at the midpoint of a segment, rotated along it.\n",
    "\n",
    "        Args:\n",
    "            ax: axes object to draw on; its ``transData`` transform is used.\n",
    "            pos1: ``(x, y)`` of one end of the segment.\n",
    "            pos2: ``(x, y)`` of the other end.\n",
    "            label: value to display, converted with ``str``.\n",
    "        \"\"\"\n",
    "        x1, y1 = pos1\n",
    "        x2, y2 = pos2\n",
    "        mid_x = x1*0.5 + x2*0.5\n",
    "        mid_y = y1*0.5 + y2*0.5\n",
    "\n",
    "        # Direction of the segment in degrees, folded into [-90, 90] so the\n",
    "        # text never ends up upside down\n",
    "        angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360\n",
    "        if angle > 90:\n",
    "            angle -= 180\n",
    "        elif angle < - 90:\n",
    "            angle += 180\n",
    "        midpoint = np.array((mid_x, mid_y))\n",
    "        trans_angle = ax.transData.transform_angles(\n",
    "            np.array((angle,)), midpoint.reshape((1, 2)))[0]\n",
    "        # White rounded box so the label hides the edge line underneath it\n",
    "        white = (1.0, 1.0, 1.0)\n",
    "        ax.text(mid_x, mid_y, str(label),\n",
    "                size=16,\n",
    "                color='k',\n",
    "                alpha=1,\n",
    "                horizontalalignment='center',\n",
    "                verticalalignment='center',\n",
    "                rotation=trans_angle,\n",
    "                transform=ax.transData,\n",
    "                bbox=dict(boxstyle='round', ec=white, fc=white),\n",
    "                zorder=1,\n",
    "                clip_on=True)\n",
    "        \n",
    "    def show_search_tree(self, \n",
    "                         this_path=None, \n",
    "                         show_success_color=False,\n",
    "                         best_path=None\n",
    "                        ):\n",
    "        \"\"\"Draw the search tree and highlight the path being visited.\n",
    "\n",
    "        Args:\n",
    "            this_path: path string to highlight; every tree node whose path\n",
    "                string is contained in it is drawn in the visited (or\n",
    "                success) colour.\n",
    "            show_success_color: highlight with ``success_color`` instead of\n",
    "                the visited colours.\n",
    "            best_path: when truthy, the caption presents ``this_path`` as\n",
    "                the final shortest path.\n",
    "        \"\"\"\n",
    "        # One new figure per frame; its height grows with the depth limit\n",
    "        fig, ax = plt.subplots()\n",
    "        fig.set_figwidth(15)\n",
    "        fig.set_figheight(self.max_depth*1.5)\n",
    "        plt.axis('off')\n",
    "        \n",
    "        # Each key is a path string; its last character is the node drawn at that position\n",
    "        for path, pos in self.tree_node_position.items():\n",
    "            if path[-1] == self.start_node:\n",
    "                node_color = self.start_node_color\n",
    "                edge_color = self.basic_edge_color\n",
    "            elif this_path and path in this_path:\n",
    "                if show_success_color:\n",
    "                    node_color = self.success_color\n",
    "                    edge_color = self.success_color\n",
    "                else:\n",
    "                    node_color = self.visited_node_color\n",
    "                    edge_color = self.visited_edge_color\n",
    "            elif path[-1] == self.target_node:\n",
    "                node_color = self.target_node_color\n",
    "                edge_color = self.basic_edge_color\n",
    "            else:\n",
    "                node_color = self.basic_node_color\n",
    "                edge_color = self.basic_edge_color\n",
    "            ax.scatter(pos[0], pos[1], c=node_color, s=1000,zorder=1)\n",
    "            plt.annotate(\n",
    "                path[-1],\n",
    "                xy=(pos[0], pos[1]),\n",
    "                xytext=(0, 0),\n",
    "                textcoords='offset points',\n",
    "                ha='center',\n",
    "                va='center',\n",
    "                size=15,)\n",
    "            # Every node except the root is linked to its parent (the path minus its last node)\n",
    "            if len(path)>1:\n",
    "                plt.plot([self.tree_node_position[path[:-1]][0],pos[0]], \n",
    "                         [self.tree_node_position[path[:-1]][1],pos[1]], \n",
    "                         color=edge_color,\n",
    "                         zorder=0)\n",
    "                # Always true here: the enclosing check already guarantees len(path) > 1\n",
    "                if len(path)>1:\n",
    "                    label = self.weighted_edges_dic[frozenset([path[-2],path[-1]])]\n",
    "                    # Heuristic searches label the edge with the weighted score of the child node\n",
    "                    if self.animation_type in ['greedy','a_star']:\n",
    "                            label = self.help_info_weight*self.help_info[path[-1]] + self.origin_info_weight*label\n",
    "                    self.show_edge_labels(ax, self.tree_node_position[path[:-1]][0:2], pos[0:2], label)\n",
    "        # Clear the previously displayed frame before this one is shown\n",
    "        display.clear_output(wait=True)\n",
    "        \n",
    "        # Text anchored at (0, -1.1): one line per target-reaching path shown so far\n",
    "        show_res_text = \"\"\n",
    "        for e_c in self.show_correct_path:\n",
    "            show_res_text += '找到一条路径: %-7s' % e_c + '。距离为:' +str(self.correct_paths[e_c]) + '\\n'\n",
    "        plt.text(0, -1.1, show_res_text, fontsize=18,horizontalalignment='left', verticalalignment='top',)\n",
    "        \n",
    "        # Text anchored at (0, 0): the final result, or the current path plus the best path so far\n",
    "        if best_path:\n",
    "            top_text = '最终最短路径为: %-7s' % this_path + '。距离为:' +str(self.correct_paths[this_path]) + '\\n'\n",
    "        elif this_path and  self.animation_type in ['dfs','bfs']:\n",
    "            top_text = '当前路径: %-7s' % this_path + '。距离为:' +str(self.path_score[this_path]) + '\\n' \n",
    "            if self.temp_best_path:\n",
    "                top_text += '当前最短路径为: %-7s' % self.temp_best_path + '。距离为:' +str(self.correct_paths[self.temp_best_path]) + '\\n'\n",
    "        else:\n",
    "            top_text = ''\n",
    "\n",
    "        plt.text(0, 0, \n",
    "                 top_text, \n",
    "                 fontsize=18,\n",
    "                 horizontalalignment='left', \n",
    "                 verticalalignment='top',)\n",
    "        \n",
    "        # Heuristic searches show an explanation of the choice made at this step\n",
    "        if self.animation_type in ['greedy','a_star']:\n",
    "            show_greedy_text = self.generate_greedy_help_text(this_path)\n",
    "            plt.text(0, 0, show_greedy_text, fontsize=18, horizontalalignment='left', verticalalignment='top',)\n",
    "        plt.show()\n",
    "        \n",
    "    def animation_search_tree(self,search_method='dfs', help_info_weight=1, origin_info_weight=1):\n",
    "        \"\"\"Animate a search on the search tree, one visited path per frame.\n",
    "\n",
    "        Args:\n",
    "            search_method: 'bfs', 'dfs', 'greedy' or 'a_star'; any other value\n",
    "                animates nothing.\n",
    "            help_info_weight: weight of the per-node help value, forwarded to\n",
    "                the A* search only.\n",
    "            origin_info_weight: weight of the edge weight, forwarded to the\n",
    "                A* search only.\n",
    "        \"\"\"\n",
    "        self.animation_type = search_method\n",
    "        self.show_correct_path = []\n",
    "        self.temp_best_path = None\n",
    "        exhaustive = search_method in ('bfs', 'dfs')\n",
    "        heuristic = search_method in ('greedy', 'a_star')\n",
    "\n",
    "        # Pick the list of paths to replay; heuristic searches are run on demand\n",
    "        if search_method == 'bfs':\n",
    "            paths = self.bfs_path\n",
    "        elif search_method == 'dfs':\n",
    "            paths = self.dfs_path\n",
    "        elif heuristic:\n",
    "            if search_method == 'greedy':\n",
    "                self.greedy_search()\n",
    "            else:\n",
    "                self.a_star_search(help_info_weight=help_info_weight, origin_info_weight=origin_info_weight)\n",
    "            paths = self.greedy_search_path\n",
    "        else:\n",
    "            paths = []\n",
    "\n",
    "        for e_path in paths:\n",
    "            self.show_search_tree(e_path)\n",
    "            if e_path in self.correct_paths:\n",
    "                # Track the shortest target-reaching path seen so far\n",
    "                current_best = self.temp_best_path\n",
    "                if not current_best or self.path_score[e_path] < self.path_score[current_best]:\n",
    "                    self.temp_best_path = e_path\n",
    "                self.show_correct_path.append(e_path)\n",
    "                self.show_search_tree(e_path, True)\n",
    "            if heuristic:\n",
    "                time.sleep(5)\n",
    "        # Exhaustive searches finish by highlighting the overall shortest path\n",
    "        if exhaustive and self.correct_paths:\n",
    "            best_path = min(self.correct_paths, key=self.correct_paths.get)\n",
    "            self.show_search_tree(best_path, True, True)\n",
    "    \n",
    "    def animation_graph(self, search_method='bfs', help_info_weight=1, origin_info_weight=1, interval=5):\n",
    "        \"\"\"Animate a search on the graph drawing, one visited path per frame.\n",
    "\n",
    "        Args:\n",
    "            search_method: 'bfs', 'dfs', 'greedy' or 'a_star'; any other value\n",
    "                animates nothing.\n",
    "            help_info_weight: weight of the per-node help value, forwarded to\n",
    "                the A* search only.\n",
    "            origin_info_weight: weight of the edge weight, forwarded to the\n",
    "                A* search only.\n",
    "            interval: seconds to wait after each visited path. Defaults to 5,\n",
    "                the delay that used to be hard-coded.\n",
    "        \"\"\"\n",
    "        self.animation_type = search_method\n",
    "        self.show_correct_path = []\n",
    "        if search_method == 'bfs':\n",
    "            paths = self.bfs_path\n",
    "        elif search_method == 'dfs':\n",
    "            paths = self.dfs_path\n",
    "        elif search_method == 'greedy':\n",
    "            self.greedy_search()\n",
    "            paths = self.greedy_search_path\n",
    "        elif search_method == 'a_star':\n",
    "            self.a_star_search(help_info_weight=help_info_weight, origin_info_weight=origin_info_weight)\n",
    "            paths = self.greedy_search_path\n",
    "        else:\n",
    "            paths = []\n",
    "        for e_path in paths:\n",
    "            self.show_graph(e_path)\n",
    "            if e_path in self.correct_paths:\n",
    "                # Redraw a target-reaching path in the success colour\n",
    "                self.show_correct_path.append(e_path)\n",
    "                self.show_graph(e_path, True)\n",
    "            time.sleep(interval)\n",
    "        # min() raises ValueError on an empty dict, so the final highlight is\n",
    "        # skipped when no path reached the target within the depth limit\n",
    "        if search_method in ['bfs', 'dfs'] and self.correct_paths:\n",
    "            best_path = min(self.correct_paths, key=self.correct_paths.get)\n",
    "            self.show_graph(best_path, True, True)\n",
    "    \n",
    "    def show_graph(self, this_path='',\n",
    "                         show_success_color=False,\n",
    "                         best_path=None):\n",
    "        \"\"\"Draw the graph and highlight the nodes and edges of a path.\n",
    "\n",
    "        Args:\n",
    "            this_path: path string to highlight; when empty, only the start\n",
    "                node is treated as the path.\n",
    "            show_success_color: highlight with ``success_color`` instead of\n",
    "                the visited colours.\n",
    "            best_path: accepted for symmetry with ``show_search_tree``; it is\n",
    "                not used by this drawing.\n",
    "        \"\"\"\n",
    "        fig, ax = plt.subplots()\n",
    "        fig.set_figwidth(6)\n",
    "        fig.set_figheight(8)\n",
    "        plt.axis('off')\n",
    "\n",
    "        if not this_path:\n",
    "            this_path = self.start_node\n",
    "        path_node_list = list(this_path)\n",
    "        # Consecutive nodes of the path form its edges; frozenset ignores direction\n",
    "        visited_edges = [frozenset(pair)\n",
    "                         for pair in zip(path_node_list, path_node_list[1:])]\n",
    "\n",
    "        # Node labels are the node names; edge labels are the edge weights\n",
    "        nlabels = dict(zip(self.node_list, self.node_list))\n",
    "        edge_labels = dict([((u, v,), d['weight']) for u, v, d in self.G.edges(data=True)])\n",
    "\n",
    "        # Node colours: path nodes override the target colour, and the start\n",
    "        # node colour overrides everything\n",
    "        val_map = {self.target_node: self.target_node_color}\n",
    "        for node in path_node_list:\n",
    "            if show_success_color:\n",
    "                val_map[node] = self.success_color\n",
    "            else:\n",
    "                val_map[node] = self.visited_node_color\n",
    "        val_map[self.start_node] = self.start_node_color\n",
    "        values = [val_map.get(node, self.basic_node_color) for node in self.G.nodes()]\n",
    "\n",
    "        # Edge colours: edges on the path get the success or visited colour,\n",
    "        # all other edges keep the basic colour\n",
    "        edge_colors = []\n",
    "        for edge in self.G.edges():\n",
    "            if frozenset(edge) in visited_edges:\n",
    "                if show_success_color:\n",
    "                    edge_colors.append(self.success_color)\n",
    "                else:\n",
    "                    edge_colors.append(self.visited_edge_color)\n",
    "            else:\n",
    "                edge_colors.append(self.basic_edge_color)\n",
    "\n",
    "        # draw_networkx_nodes has no ``width`` parameter: releases that still\n",
    "        # accepted extra keywords ignored it, and releases without ``**kwds``\n",
    "        # reject it, so it is no longer passed\n",
    "        nx.draw_networkx_nodes(self.G, self.nodes_pos, node_size=800, node_color=values)\n",
    "        nx.draw_networkx_labels(self.G, self.nodes_pos, nlabels, font_size=20)\n",
    "        nx.draw_networkx_edges(self.G, self.nodes_pos, edge_color=edge_colors, width=2.0, alpha=1.0)\n",
    "        nx.draw_networkx_edge_labels(self.G, self.nodes_pos, edge_labels=edge_labels, font_size=18)\n",
    "\n",
    "        # Clear the previously displayed frame before this one is shown\n",
    "        display.clear_output(wait=True)\n",
    "        plt.show()\n",
    "        \n",
    "    def _dfs_helper(self, G, node,  father, target_node,level, res, path):\n",
    "        \"\"\"Visit ``node`` depth-first, recording every path that is reached.\n",
    "\n",
    "        Args:\n",
    "            G: graph where ``G[u]`` iterates the neighbours of ``u``.\n",
    "            node: node being visited.\n",
    "            father: node the visit came from (not used by the traversal).\n",
    "            target_node: node at which a branch stops.\n",
    "            level: number of edges between the start node and ``node``.\n",
    "            res: list that receives each visited path, in visiting order.\n",
    "            path: path string leading to ``node``, without ``node`` itself.\n",
    "        \"\"\"\n",
    "        path = path + str(node)\n",
    "        if len(path) > 1:\n",
    "            # Distance of this path = distance of its parent + weight of the last edge\n",
    "            last_edge = frozenset([path[-2], path[-1]])\n",
    "            self.path_score[path] = self.path_score[path[:-1]] + self.weighted_edges_dic[last_edge]\n",
    "        res.append(path)\n",
    "        # Stop at the target, or when the depth limit is reached\n",
    "        if node == target_node or level >= self.max_depth:\n",
    "            return\n",
    "        for neighbor in sorted(G[node]):\n",
    "            # Never revisit a node that is already on the current path\n",
    "            if str(neighbor) in path:\n",
    "                continue\n",
    "            self._dfs_helper(G, neighbor, node, target_node, level+1, res, path)\n",
    "                    \n",
    "    def dfs_search(self):\n",
    "        \"\"\"Run the depth-first search from the start node.\n",
    "\n",
    "        Stores the visited paths in ``self.dfs_path`` and adds every path\n",
    "        ending at the target node to ``self.correct_paths`` together with its\n",
    "        distance; entries that already exist are left untouched.\n",
    "        \"\"\"\n",
    "        visited = []\n",
    "        if self.start_node:\n",
    "            self._dfs_helper(self.G, self.start_node, None, self.target_node, 0, visited, '')\n",
    "        self.dfs_path = visited\n",
    "        for path in visited:\n",
    "            reaches_target = path[-1] == self.target_node\n",
    "            if reaches_target and path not in self.correct_paths:\n",
    "                self.correct_paths[path] = self.cal_dis(path)\n",
    "        \n",
    "    def bfs_search(self):\n",
    "        \"\"\"Run the breadth-first search from the start node.\n",
    "\n",
    "        Stores the visited paths, in visiting order, in ``self.bfs_path`` and\n",
    "        adds every path ending at the target node to ``self.correct_paths``\n",
    "        together with its distance; entries that already exist are left\n",
    "        untouched. Paths with more than ``max_depth + 1`` nodes are not\n",
    "        visited.\n",
    "        \"\"\"\n",
    "        # FIFO queue of paths; deque.popleft() is O(1) whereas list.pop(0)\n",
    "        # shifts every remaining element\n",
    "        to_search = collections.deque([self.start_node])\n",
    "        bfs_path = []\n",
    "        while to_search:\n",
    "            this_search = to_search.popleft()\n",
    "            # The queue is ordered by path length, so once one path is too\n",
    "            # long every remaining one is too\n",
    "            if len(this_search) > self.max_depth + 1:\n",
    "                break\n",
    "            bfs_path.append(this_search)\n",
    "            last_node = this_search[-1]\n",
    "            # A path that reached the target is not extended any further\n",
    "            if last_node == self.target_node:\n",
    "                continue\n",
    "            for ne in sorted(self.G[last_node]):\n",
    "                # Skip nodes already on the path so that no cycle is created\n",
    "                if ne not in this_search:\n",
    "                    to_search.append(this_search + ne)\n",
    "        self.bfs_path = bfs_path\n",
    "        for p in bfs_path:\n",
    "            if p[-1] == self.target_node and p not in self.correct_paths:\n",
    "                self.correct_paths[p] = self.cal_dis(p)\n",
    "                \n",
    "    def greedy_search(self, help_info_weight=1, origin_info_weight=0):\n",
    "        \"\"\"Follow the lowest-scoring unvisited neighbour at every step.\n",
    "\n",
    "        When ``self.help_info`` is set, a neighbour scores\n",
    "        ``help_info_weight * help_info[neighbour] + origin_info_weight *\n",
    "        edge_weight``; otherwise it scores the edge weight alone. The walk\n",
    "        stops at the target node, at a dead end, or once the path holds more\n",
    "        than ``max_depth`` nodes.\n",
    "\n",
    "        Stores both weights on the instance, the successive prefixes of the\n",
    "        walked path in ``self.greedy_search_path`` and the candidate scores\n",
    "        of every step in ``self.search_scores``.\n",
    "        \"\"\"\n",
    "        self.help_info_weight = help_info_weight\n",
    "        self.origin_info_weight = origin_info_weight\n",
    "        search_path = self.start_node\n",
    "        # Candidates and their scores at every step, shown during the animation\n",
    "        search_scores = {}\n",
    "        while len(search_path) <= self.max_depth:\n",
    "            this_node = search_path[-1]\n",
    "            candidates = [e_n for e_n in sorted(self.G[this_node]) if e_n not in search_path]\n",
    "            if not candidates:\n",
    "                # Dead end: remember that nothing could be chosen from here\n",
    "                search_scores[search_path] = {}\n",
    "                break\n",
    "            scores = {}\n",
    "            for e_n in candidates:\n",
    "                if self.help_info:\n",
    "                    scores[e_n] = (help_info_weight*self.help_info[e_n]\n",
    "                                   + origin_info_weight*self.weighted_edges_dic[frozenset([this_node, e_n])])\n",
    "                else:\n",
    "                    scores[e_n] = self.weighted_edges_dic[frozenset([this_node, e_n])]\n",
    "            search_scores[search_path] = scores\n",
    "            nearest_node = min(scores, key=scores.get)\n",
    "            search_path += nearest_node\n",
    "            if nearest_node == self.target_node:\n",
    "                break\n",
    "        # Every prefix of the final path, from the start node alone to the full path\n",
    "        self.greedy_search_path = [search_path[:end] for end in range(1, len(search_path) + 1)]\n",
    "        self.search_scores = search_scores\n",
    "        \n",
    "    def a_star_search(self, help_info_weight=1, origin_info_weight=1):\n",
    "        \"\"\"Run ``greedy_search`` with both the help value and the edge weight weighted in.\n",
    "\n",
    "        This only forwards the two weights, so the results are stored by\n",
    "        ``greedy_search`` (``greedy_search_path`` and ``search_scores``).\n",
    "        \"\"\"\n",
    "        self.greedy_search(help_info_weight, origin_info_weight)\n",
    "        \n",
    "\n",
    "    def generate_greedy_help_text(self,path):\n",
    "        \"\"\"Explain, for display, the choice the greedy search made at ``path``.\n",
    "\n",
    "        Args:\n",
    "            path: path string of the current step; may be None or empty.\n",
    "\n",
    "        Returns:\n",
    "            The text to show: an empty string when no path is given, a\n",
    "            message when the target or the depth limit was reached or when\n",
    "            the path is a dead end, and otherwise the candidate scores\n",
    "            followed by the reason for the choice.\n",
    "        \"\"\"\n",
    "        # Nothing to explain when the tree is drawn without a current path\n",
    "        if not path:\n",
    "            return ''\n",
    "        if path[-1] == self.target_node:\n",
    "            return '抵达目标节点' + str(self.target_node)\n",
    "        if path not in self.search_scores:\n",
    "            return '抵达最大搜索深度，未找到目标节点'\n",
    "\n",
    "        options = self.search_scores[path]\n",
    "        # greedy_search stores an empty dict for a dead end; min() on it would\n",
    "        # raise ValueError, so report the dead end instead\n",
    "        if not options:\n",
    "            return '当前节点没有未访问的子节点，搜索结束'\n",
    "\n",
    "        base_text = '当前可选的子节点及其信息值为 \\n' + str(options) + '\\n'\n",
    "        if self.target_node in options:\n",
    "            reason = '当前可选的子节点包含了目标节点，\\n所以选择目标节点'\n",
    "        elif len(options) == 1:\n",
    "            reason = '因为只有一个子节点，所以选择此节点'\n",
    "        else:\n",
    "            reason = '因为' + str(min(options, key=options.get)) + '的值最小，所以选择此节点'\n",
    "        return base_text + reason\n",
    "          \n",
    "    def cal_dis(self,path):\n",
    "        \"\"\"Return the sum of the edge weights along ``path`` (0 when it has fewer than two nodes).\"\"\"\n",
    "        dis = 0\n",
    "        # Pair every node with its successor to walk the edges of the path\n",
    "        for node, next_node in zip(path, path[1:]):\n",
    "            dis += self.weighted_edges_dic[frozenset([node, next_node])]\n",
    "        return dis\n",
    "    "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
