import matplotlib.pyplot as plt
import collections
from IPython import display
import networkx as nx
import numpy as np
import time


class Color:
    # 绘图配色
    # 基本的节点配色
    basic_node_color = '#6CB6FF'
    # 初始节点配色
    start_node_color = 'y'
    # 目标节点配色
    target_node_color = 'r'
    # 已访问的节点配色
    visited_node_color = 'g'
    # 基本边的配色
    basic_edge_color = 'b'
    # 已访问的边的配色
    visited_edge_color = 'g'
    # 抵达目标节点时的配色
    success_color = 'r'


class Graph(nx.Graph):
    def __init__(self,
                 start_node=None,
                 target_node=None):

        # 初始化父类
        nx.Graph.__init__(self)
        # 绘制 graph 时各个节点的坐标值
        self.nodes_pos = {}
        # 搜索起点
        self.start_node = start_node
        # 搜搜目标点
        self.target_node = target_node
        # 搜索的最大深度
        self.max_depth = None
        # bfs 的历史查找路径，例如 ['A','AB', 'AD', 'ABC']
        self.bfs_paths = []
        # dfs 的历史查找路径，例如 ['A','AB', 'ABC', 'AD']
        self.dfs_paths = []
        # 每条历史路径的分数（历史距离）
        self.path_score = {}
        # 绘制搜索数时每个节点的坐标值
        self.tree_node_position = {}
        # 额外的辅助信息值，用于进行启发式搜索（贪婪或 A-star 搜索）
        self.help_info = {}
        # 额外辅助信息的权重
        self.help_info_weight = 1
        # 原始信息的权重
        self.origin_info_weight = 1
        # a_star 算法的历史查找路径
        self.a_star_search_paths = []
        # a_star 算法的中路径的得分
        self.a_star_search_scores = {}
        #
        self.changed = True

    def set_start_node(self, start_node):
        """
        设置起点
        :param start_node: 起点
        :return:
        """
        self.start_node = start_node
        self.path_score = {self.start_node: 0}
        self.changed = True

    def set_target_node(self, target_node):
        """
        设置目标点
        :param target_node: 目标点
        :return:
        """
        self.target_node = target_node
        self.path_score = {self.start_node: 0}
        self.changed = True

    def set_nodes_pos(self, nodes_pos):
        """
        设置 graph 内各个节点的坐标，绘图用，与搜索无关
        :param nodes_pos: 字典，节点的坐标，
                          例如 {"A": (1, 1), "B": (3, 3), "C": (5, 0)}
        :return:
        """
        self.nodes_pos = nodes_pos

    def set_max_depth(self, max_depth):
        """
        设置最大搜索深度
        :param max_depth: 整数，大于 0
        :return:
        """
        max_depth = min(max_depth, len(self.nodes))
        self.max_depth = max_depth
        self.changed = True

    def set_help_info(self, help_info):
        """
        设置辅助信息，用于进行启发式搜索（贪婪或 A-star 搜索）
        :param help_info: 字典，
                          例如 {'A': 30, 'B': 20, 'C': 19 }，为各个点到目标点的距离
        :return:
        """
        self.help_info = help_info

    def show_graph(self, this_path=''):
        """
        绘制 graph
        :param this_path: 设置一条图中的路径
        :return:
        """
        # 当不传入路径时，默认在初始节点
        this_path = this_path or self.start_node
        # 根据当前你路径，处理节点和边的颜色
        # 根据路径得到已访问的边，例如 'ABC' 得到 ['AB', 'BC']
        visited_edges = []
        for i in range(1, len(this_path)):
            visited_edges.append(
                frozenset([this_path[i], this_path[i - 1]]))
        # 节点和边以及其显示的标签
        node_labels = dict(zip(self.nodes(), self.nodes()))
        edge_labels = dict(
            [((u, v,), d['weight']) for u, v, d in self.edges(data=True)])

        # 处理节点的颜色
        node_color_map = {e_node: Color.visited_node_color for e_node in
                          this_path}
        node_color_map[self.start_node] = Color.start_node_color
        node_color_map[self.target_node] = Color.target_node_color
        node_color = [node_color_map.get(node, Color.basic_node_color)
                      for node in self.nodes()]

        # 处理每条边的颜色
        edge_color = []
        for edge in self.edges():
            if frozenset(edge) in visited_edges:
                edge_color.append(Color.visited_edge_color)
            else:
                edge_color.append(Color.basic_edge_color)

        # 创建绘图
        fig, ax = plt.subplots()
        # 定义绘图的宽和高，并关闭坐标轴的显示
        fig.set_figwidth(6)
        fig.set_figheight(8)
        plt.axis('off')

        # 绘制节点及其标签
        nx.draw_networkx_nodes(self, self.nodes_pos, node_size=800,
                               node_color=node_color, width=6.0)
        nx.draw_networkx_labels(self, self.nodes_pos, node_labels,
                                font_size=20)
        # 绘制边及其标签
        nx.draw_networkx_edges(self, self.nodes_pos, edge_color=edge_color,
                               width=2.0, alpha=1.0)
        nx.draw_networkx_edge_labels(self, self.nodes_pos,
                                     edge_labels=edge_labels, font_size=18)
        # 清除绘图区，显示新绘图
        display.clear_output(wait=True)
        plt.show()

    def bfs_search(self):
        """
        使用迭代法进行广度优先搜索
        :return:
        """
        # 待访问的路径
        to_search = [self.start_node]
        # 存储所有的已访问的路径
        bfs_paths = []
        # 当还有待访问的路径时
        while to_search:
            # 从待访问的路径中取第一个待访问路径
            this_search = to_search.pop(0)
            # 如果待访问的路径超过最大搜索深度，跳出循环
            if len(this_search) > self.max_depth + 1:
                break
            # 把刚取出的路径存入已访问的路径中
            bfs_paths.append(this_search)
            # 如果路径的最后一个节点是目标节点，路径 AC 的最后一个节点是 C
            if this_search[-1] == self.target_node:
                # 其为一条正确的路径，将存入正确的路径列表中，
                # 并不再继续往其子节点进行探索
                continue
            # 找到路径最后一个节点的相邻节点
            else:
                for ne in sorted(self.neighbors(this_search[-1])):
                    # 如果相邻节点不在路径中，即不存在回路
                    if ne not in this_search:
                        # 则加入到待访问的路径中
                        to_search.append(this_search + ne)
        self.bfs_paths = bfs_paths

    def _dfs_helper(self, node, target_node, level, dfs_paths, path):
        """
        深度优先搜索的辅助函数
        :param node: 当前节点
        :param target_node: 目标点
        :param level: 搜索深度
        :param dfs_paths: dfs 的历史搜索路径
        :param path: 从哪一个路径来到当前节点
        :return:
        """
        path += str(node)
        # 更新路径的分数（距离）
        if len(path) > 1:
            self.path_score[path] = self.path_score[path[:-1]] + \
                                    self.edges[path[-2], path[-1]]['weight']
        # 存储 dfs 的历史搜索路径
        dfs_paths.append(path)
        # 找到目标，停止搜索
        if node == target_node:
            return
        # 未达到最大搜索深度时，继续下一层搜索
        if level < self.max_depth:
            # 对当前节点的每一个相邻节点
            for neighbor in sorted(self.neighbors(node)):
                # 如果该相邻节点不在路径中，即没有出现回环，则递归调用，继续往下搜索
                if str(neighbor) not in path:
                    self._dfs_helper(neighbor, target_node, level + 1,
                                     dfs_paths, path)

    def dfs_search(self):
        """
        使用递归法进行深度优先搜索
        :return:
        """
        # dfs 的历史搜索路径
        dfs_paths = []
        this_path = ''
        if self.start_node and self.target_node:
            self._dfs_helper(self.start_node, self.target_node,
                             0, dfs_paths, this_path)
        else:
            print('请设置起点和目标点')
        self.dfs_paths = dfs_paths
        # 完成搜索后，可得到搜索树中各个节点的坐标
        self.get_search_tree_node_position()

    def get_search_tree_node_position(self):
        """得到绘图时各个节点的坐标
        """
        # 得到 dfs 的搜索路径图
        paths = self.dfs_paths
        # 得到每条路径的子路径
        path_children = {}
        for path in paths:
            father = path[:-1]
            if father in paths:
                if father in path_children:
                    path_children[father].append(path)
                else:
                    path_children[father] = [path]
        # 对每条子路径排序
        o_path_children = collections.OrderedDict(
            sorted(path_children.items()))
        # 计算每个树图中每个节点的位置
        tree_node_position = {self.start_node: (1, 0, 2)}
        for path, sub_paths in o_path_children.items():
            y_pos = -1.0 / self.max_depth * len(path)
            dx = tree_node_position[path][2] / len(sub_paths)
            sub_paths.sort()
            for index, e_s in enumerate(sub_paths):
                x_pos = tree_node_position[path][0] - tree_node_position[path][
                    2] / 2 + dx / 2 + dx * index
                tree_node_position[e_s] = (x_pos, y_pos, dx)
        self.tree_node_position = tree_node_position

    def a_star_search(self, help_info_weight=1, origin_info_weight=0):
        """
        a-star 搜索， 当 origin_info_weight 为 0 时，则退化为贪婪搜索
        :param help_info_weight: 辅助信息的比重
        :param origin_info_weight: 原始信息的比重
        :return:
        """
        # 存到类属性中，便于绘图时使用
        self.help_info_weight = help_info_weight
        self.origin_info_weight = origin_info_weight
        # 初始路径为起点
        search_path = self.start_node
        # 存储每一步的可选项及其分数，用来在动态演示时显示出来
        search_scores = {}
        # 当搜索路径未超过最大搜索深度
        while len(search_path) <= self.max_depth:
            # 当前的节点
            this_node = search_path[-1]
            # 当前节点的子节点
            neighbour_nodes = [e_n for e_n in sorted(self.neighbors(this_node))
                               if e_n not in search_path]
            # 如果没有子节点，则跳出循环，结束搜索
            if len(neighbour_nodes) == 0:
                search_scores[search_path] = {}
                break
            # 计算每个子节点的得分并存储
            scores = {e_n: help_info_weight * self.help_info[
                e_n] + origin_info_weight * self.edges[this_node, e_n][
                               'weight'] for e_n in
                      neighbour_nodes}
            search_scores[search_path] = scores
            # 挑选最佳的子节点，并添加到路径中
            nearest_node = min(scores, key=scores.get)
            search_path += nearest_node
            # 如果最佳的子节点是目标节点，跳出循环，结束搜索
            if nearest_node == self.target_node:
                break
        # 把最终路径切分为每一步，便于动态展示，例如 ABCD 变为 [A, AB, ABC, ABCD ]
        self.a_star_search_paths = [search_path[0:index + 1] for index in
                                    range(len(search_path))]
        self.a_star_search_scores = search_scores

    def greedy_search(self):
        """
        贪婪搜索，就是 origin_info_weight权重为 0 时的 a-star 搜索
        :return:
        """
        self.a_star_search(help_info_weight=1, origin_info_weight=0)

    @staticmethod
    def show_edge_labels(ax, pos1, pos2, label):
        """
        绘制搜索树的边
        :param ax: 子图
        :param pos1: 点 1 的坐标
        :param pos2: 点 2 的坐标
        :param label: 连接点 1 和点 2 的边上的文字
        :return:
        """
        # 点1
        (x1, y1) = pos1
        # 点2
        (x2, y2) = pos2
        # 文字的位置
        (x, y) = (x1 * 0.5 + x2 * 0.5, y1 * 0.5 + y2 * 0.5)
        # 文字的角度
        angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
        if angle > 90:
            angle -= 180
        if angle < - 90:
            angle += 180
        xy = np.array((x, y))
        trans_angle = ax.transData.transform_angles(np.array((angle,)),
                                                    xy.reshape((1, 2)))[0]
        # 绘制文字框和文字
        bbox = dict(boxstyle='round',
                    ec=(1.0, 1.0, 1.0),
                    fc=(1.0, 1.0, 1.0),
                    )
        label = str(label)
        ax.text(x, y,
                label,
                size=16,
                color='k',
                alpha=1,
                horizontalalignment='center',
                verticalalignment='center',
                rotation=trans_angle,
                transform=ax.transData,
                bbox=bbox,
                zorder=1,
                clip_on=True,
                )

    def show_search_tree(self,
                         animation_type='bfs',
                         top_text='',
                         bottom_text='',
                         this_path=None,
                         show_success_color=False,
                         ):
        """
        展示搜索树，动态展示搜索过程时，会调用此方法
        :param animation_type: 动态演示的类型，如果是启发式搜索，边的权重需要变化
        :param top_text: 上方的文字展示
        :param bottom_text: 下方的文字展示
        :param this_path: 当前路径
        :param show_success_color: 成功找到目标点后，路径颜色的变换
        :return:
        """
        # 如果对起始点，目标点或搜索深度进行了设置，需要重新绘制搜索树
        if self.changed is True:
            self.bfs_search()
            self.dfs_search()
            self.changed = False

        # 创建子图
        fig, ax = plt.subplots()
        # 定义绘图的宽度
        fig.set_figwidth(15)
        # 定义绘图的高度
        fig.set_figheight(self.max_depth * 1.5)
        # 关闭绘图中坐标轴的显示
        plt.axis('off')

        # 对每条路径
        for path, pos in self.tree_node_position.items():
            #  如果是初始点
            if path[-1] == self.start_node:
                node_color = Color.start_node_color
                edge_color = Color.basic_edge_color
            # 把当前路径的节点和边的颜色变为已访问的颜色
            elif this_path and path in this_path:
                # 是否显示成功找到目标点
                if show_success_color:
                    node_color = Color.success_color
                    edge_color = Color.success_color
                else:
                    node_color = Color.visited_node_color
                    edge_color = Color.visited_edge_color
            # 如果路径的终点是目标点，改变目标点的颜色
            elif path[-1] == self.target_node:
                node_color = Color.target_node_color
                edge_color = Color.basic_edge_color
            # 其他的情况下，节点和边的颜色是正常色
            else:
                node_color = Color.basic_node_color
                edge_color = Color.basic_edge_color
            # 绘制节点
            ax.scatter(pos[0], pos[1], c=node_color, s=1000, zorder=1)
            # 绘制节点的标注
            plt.annotate(
                path[-1],
                xy=(pos[0], pos[1]),
                xytext=(0, 0),
                textcoords='offset points',
                ha='center',
                va='center',
                size=15, )
            if len(path) > 1:
                # 绘制边
                plt.plot([self.tree_node_position[path[:-1]][0], pos[0]],
                         [self.tree_node_position[path[:-1]][1], pos[1]],
                         color=edge_color,
                         zorder=0)
                # 绘制边的标注
                label = self.edges[path[-2], path[-1]]['weight']

                if animation_type in ['greedy', 'a_star']:
                    label = self.help_info_weight * self.help_info[
                        path[-1]] + self.origin_info_weight * label
                self.show_edge_labels(ax,
                                      self.tree_node_position[path[:-1]][
                                      0:2], pos[0:2], label)

        # 绘制上方文字
        plt.text(0,
                 0,
                 top_text,
                 fontsize=18,
                 horizontalalignment='left',
                 verticalalignment='top', )
        # 绘制下方文字
        plt.text(0,
                 -1.1,
                 bottom_text,
                 fontsize=18,
                 horizontalalignment='left',
                 verticalalignment='top', )

        # 刷新绘图
        display.clear_output(wait=True)
        plt.show()

    def _generate_bottom_text(self, show_correct_path):
        """
        生成目前找到的最佳路径的信息的文字
        :param show_correct_path: 当前找到的正确的路径
        :return:
        """
        # 默认不展示文字
        bottom_text = ""
        # 对每一条找到的正确路径，增加一条展示文本s
        for path in show_correct_path:
            bottom_text += '找到一条路径: %-7s' % path + '。距离为:' + str(
                self.path_score[path]) + '\n'
        return bottom_text

    def _generate_a_star_help_text(self, path):
        """
        生成贪婪搜索和 a-star 动态展示时的文字
        :param path: 当前路径
        :return:
        """
        # 如果到达目标节点
        if path[-1] == self.target_node:
            return '抵达目标节点' + str(self.target_node)
        # 如果未抵达目标节点并且抵达了最大搜索深度
        elif path not in self.a_star_search_scores:
            return '未找到目标节点， 结束搜索'
        # 其他情况，展示当前节点可选节点的信息，以及挑选的原因
        else:
            base_text = '当前可选的子节点及其信息值为 \n' + \
                        str(self.a_star_search_scores[path]) + '\n'
            if self.target_node in self.a_star_search_scores[path]:
                return base_text + '当前可选的子节点包含了目标节点，\n所以选择目标节点'
            elif len(self.a_star_search_scores[path]) == 1:
                return base_text + '因为只有一个子节点，所以选择此节点'
            else:
                return base_text + '因为' + \
                       str(min(self.a_star_search_scores[path],
                               key=self.a_star_search_scores[path].get)) + \
                       '的值最小，所以选择此节点'

    def _generate_top_text(self,
                           animation_type,
                           this_path,
                           best_path=None,
                           finish=False):
        """
        生成展示当前路径的信息文字
        :param animation_type:
        :param this_path:
        :param best_path:
        :param finish:
        :return:
        """
        # 如果结束搜索，展示最终的最短路径
        if finish:
            top_text = '最终最短路径为: %-7s' % this_path + '。距离为:' + str(
                self.path_score[this_path]) + '\n'
        # 如果是其他路径，并且是展示 dfs 或 bfs 的搜索过程，展示当前路径的信息
        elif this_path and animation_type in ['dfs', 'bfs']:
            top_text = '当前路径: %-7s' % this_path + '。距离为:' + str(
                self.path_score[this_path]) + '\n'
            if best_path:
                top_text += '当前最短路径为: %-7s' % best_path + '。距离为:' + \
                            str(self.path_score[best_path]) + '\n'
        # 如果是其他路径，并且是展示 贪婪搜索 或 A-star 算法的搜索过程，展示当前路径的信息
        elif this_path and animation_type in ['greedy', 'a_star']:
            top_text = self._generate_a_star_help_text(this_path)
        # 其他情况，不展示文字
        else:
            top_text = ''
        return top_text

    def animate_search_tree(self,
                            animation_type='dfs',
                            help_info_weight=1,
                            origin_info_weight=1,
                            sleep_time=0):
        """
        动态演示搜索过程
        :param animation_type: 可选项为 ['bfs', 'dfs', 'greedy', 'a_star']
        :param help_info_weight: 附加信息的权重值
        :param origin_info_weight: 原始信息的权重值
        :param sleep_time: 设置每一步的等待时间
        :return:
        """
        # 根据展示的搜索方式，获取展示的路径列表
        if animation_type == 'bfs':
            paths = self.bfs_paths
        elif animation_type == 'dfs':
            paths = self.dfs_paths
        elif animation_type == 'greedy':
            self.greedy_search()
            paths = self.a_star_search_paths
        elif animation_type == 'a_star':
            self.a_star_search(help_info_weight=help_info_weight,
                               origin_info_weight=origin_info_weight)
            paths = self.a_star_search_paths
        else:
            print('animation_type 参数错误，请从 dfs、bfs、 greedy 或 a_star 中挑选一个')
            return

        if animation_type in ['bfs', 'dfs']:
            show_correct_path = []
            # 动态演示过程中找到的最佳路径
            best_path = None
            # 对路径列表中的每一个路径，绘图
            for e_path in paths:
                top_text = self._generate_top_text(animation_type,
                                                   e_path,
                                                   best_path=best_path,
                                                   finish=False)
                bottom_text = self._generate_bottom_text(show_correct_path)
                self.show_search_tree(top_text=top_text,
                                      bottom_text=bottom_text,
                                      this_path=e_path)
                # 设置等待时间，避免切换过快
                time.sleep(sleep_time)
                # 如果该路径是正确路径
                if e_path[-1] == self.target_node:
                    # 如果是第一个正确路径，则其为当前最佳路径
                    if not best_path:
                        best_path = e_path
                    # 如果不是，与当前最佳路径比较，
                    elif self.path_score[e_path] < self.path_score[best_path]:
                        best_path = e_path
                    # 增加一条最佳路径的展示
                    show_correct_path.append(e_path)
                    bottom_text = self._generate_bottom_text(show_correct_path)
                    self.show_search_tree(top_text=top_text,
                                          bottom_text=bottom_text,
                                          this_path=e_path,
                                          show_success_color=True)
                    # 设置等待时间，避免切换过快
                    time.sleep(sleep_time)
            # 搜索结束后，展示最佳路径
            top_text = self._generate_top_text(animation_type,
                                               best_path,
                                               best_path=best_path,
                                               finish=True
                                               )
            bottom_text = self._generate_bottom_text(show_correct_path)
            self.show_search_tree(top_text=top_text,
                                  bottom_text=bottom_text,
                                  this_path=best_path,
                                  show_success_color=True)

        else:
            # 对路径列表中的每一个路径，绘图
            for e_path in paths:
                top_text = self._generate_top_text(animation_type,
                                                   e_path,
                                                   best_path=False)
                self.show_search_tree(top_text=top_text,
                                      this_path=e_path)
                # 设置等待时间，避免切换过快
                time.sleep(sleep_time)
                # 如果抵达目标点
                if e_path[-1] == self.target_node:
                    top_text = self._generate_top_text(animation_type,
                                                       e_path,
                                                       best_path=True)
                    self.show_search_tree(top_text=top_text,
                                          this_path=e_path,
                                          show_success_color=True)
