diff --git a/_README.ipynb b/_README.ipynb index 75867ff..52c3206 100644 --- a/_README.ipynb +++ b/_README.ipynb @@ -17,6 +17,50 @@ " - ```_README.md```*-----说明文档*\n", " - ```app_spec.yml```*-----定义项目的输入输出,为部署服务*\n", " - ```coding_here.ipynb```*-----输入并运行代码*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def handle(conf):\n", + " \"\"\"\n", + " 该方法是部署之后,其他人调用你的服务时候的处理方法。\n", + " 请按规范填写参数结构,这样我们就能替你自动生成配置文件,方便其他人的调用。\n", + " 范例:\n", + " params['key'] = value # value_type: str # description: some description\n", + " value_type 可以选择:img, video, audio, str, int, float, [int], [str], [float]\n", + " 参数请放到params字典中,我们会自动解析该变量。\n", + " \"\"\"\n", + "\n", + " param1 = conf['param1'] # value_type: str # description: some description\n", + " # add your code\n", + " return {'ret1': 'cat'}\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def handle(conf):\n", + " \"\"\"\n", + " 该方法是部署之后,其他人调用你的服务时候的处理方法。\n", + " 请按规范填写参数结构,这样我们就能替你自动生成配置文件,方便其他人的调用。\n", + " 范例:\n", + " params['key'] = value # value_type: str # description: some description\n", + " value_type 可以选择:img, video, audio, str, int, float, [int], [str], [float]\n", + " 参数请放到params字典中,我们会自动解析该变量。\n", + " \"\"\"\n", + "\n", + " param1 = conf['param1'] # value_type: str # description: some description\n", + " # add your code\n", + " return {'ret1': 'cat'}\n", + " " ] }, { @@ -135,17 +179,17 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.5.2" - }, + "version": "3.7.5" + }, "pycharm": { "stem_cell": { "cell_type": "raw", - "source": [], "metadata": { "collapsed": false - } - } - } + }, + "source": [] + } + } }, "nbformat": 4, "nbformat_minor": 2 diff --git a/app_spec.yml b/app_spec.yml new file mode 100644 index 0000000..62866f8 --- /dev/null +++ b/app_spec.yml @@ -0,0 +1,18 @@ +input: + start_words: + name: start_words + value_type: str + description: 想要开始的第一个字 + prefix_words: + name: prefix_words + value_type: str + description: 想要的诗歌语境 + max_gen_len: + name: max_gen_len + value_type: int + description: 想要的诗歌长度 +output: + result: + name: result + value_type: str + description: 你的诗歌 diff --git a/checkpoints/tang_199.pth b/checkpoints/tang_199.pth new file mode 100644 index 0000000..29e23db Binary files /dev/null and b/checkpoints/tang_199.pth differ diff --git a/coding_here.ipynb b/coding_here.ipynb index 90d9432..bc37cec 100644 --- a/coding_here.ipynb +++ b/coding_here.ipynb @@ -1,36 +1,60 @@ - - { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print('Hello Mo!')" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.5.2" - } - }, - "nbformat": 4, - "nbformat_minor": 2 - } - \ No newline at end of file +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def handle(conf):\n", + " \"\"\"\n", + " 该方法是部署之后,其他人调用你的服务时候的处理方法。\n", + " 请按规范填写参数结构,这样我们就能替你自动生成配置文件,方便其他人的调用。\n", + " 范例:\n", + " params['key'] = value # value_type: str # description: some description\n", + " value_type 可以选择:img, video, audio, str, 
int, float, [int], [str], [float]\n", + " 参数请放到params字典中,我们会自动解析该变量。\n", + " \"\"\"\n", + "\n", + "# param1 = conf['param1'] # value_type: str # description: some description\n", + " start_words = conf['start_words'] #诗歌开始\n", + " prefix_words = conf['prefix_words'] #诗歌语境\n", + " max_gen_len = conf['max_gen_len'] #诗歌最大长度\n", + " # add your code\n", + " if __name__ == '__main__':\n", + " result = main.gen(start_words,prefix_words,max_gen_len)\n", + " result = ''.json(result)\n", + " return {'你的诗歌': 'result'}\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/data.py b/data.py new file mode 100644 index 0000000..42e902c --- /dev/null +++ b/data.py @@ -0,0 +1,194 @@ +# coding:utf-8 +import sys +import os +import json +import re +import numpy as np + + +def _parseRawData(author=None, constrain=None, src='./chinese-poetry/json/simplified', category="poet.tang"): + """ + code from https://github.com/justdark/pytorch-poetry-gen/blob/master/dataHandler.py + 处理json文件,返回诗歌内容 + @param: author: 作者名字 + @param: constrain: 长度限制 + @param: src: json 文件存放路径 + @param: category: 类别,有poet.song 和 poet.tang + + 返回 data:list + ['床前明月光,疑是地上霜,举头望明月,低头思故乡。', + '一去二三里,烟村四五家,亭台六七座,八九十支花。', + ......... + ] + """ + + def sentenceParse(para): + # para 形如 "-181-村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡, + # 生摘琵琶酸。(「琵琶」,嚴壽澄校《張祜詩集》云:疑「枇杷」之誤。) + # 好是去塵俗,煙花長一欄。" + result, number = re.subn(u"(.*)", "", para) + result, number = re.subn(u"{.*}", "", result) + result, number = re.subn(u"《.*》", "", result) + result, number = re.subn(u"《.*》", "", result) + result, number = re.subn(u"[\]\[]", "", result) + r = "" + for s in result: + if s not in set('0123456789-'): + r += s + r, number = re.subn(u"。。", u"。", r) + return r + + def handleJson(file): + # print file + rst = [] + data = json.loads(open(file).read()) + for poetry in data: + pdata = "" + if (author is not None and poetry.get("author") != author): + continue + p = poetry.get("paragraphs") + flag = False + for s in p: + sp = re.split(u"[,!。]", s) + for tr in sp: + if constrain is not None and len(tr) != constrain and len(tr) != 0: + flag = True + break + if flag: + break + if flag: + continue + for sentence in poetry.get("paragraphs"): + pdata += sentence + pdata = sentenceParse(pdata) + if pdata != "": + rst.append(pdata) + return rst + + data = [] + for filename in os.listdir(src): + if filename.startswith(category): + data.extend(handleJson(src + filename)) + return data + + +def pad_sequences(sequences, + maxlen=None, + dtype='int32', + padding='pre', + truncating='pre', + value=0.): + """ + code from keras + Pads each sequence to the same length (length of the longest sequence). + If maxlen is provided, any sequence longer + than maxlen is truncated to maxlen. + Truncation happens off either the beginning (default) or + the end of the sequence. + Supports post-padding and pre-padding (default). + Arguments: + sequences: list of lists where each element is a sequence + maxlen: int, maximum length + dtype: type to cast the resulting sequence. 
+ padding: 'pre' or 'post', pad either before or after each sequence. + truncating: 'pre' or 'post', remove values from sequences larger than + maxlen either in the beginning or in the end of the sequence + value: float, value to pad the sequences to the desired value. + Returns: + x: numpy array with dimensions (number_of_sequences, maxlen) + Raises: + ValueError: in case of invalid values for `truncating` or `padding`, + or in case of invalid shape for a `sequences` entry. + """ + if not hasattr(sequences, '__len__'): + raise ValueError('`sequences` must be iterable.') + lengths = [] + for x in sequences: + if not hasattr(x, '__len__'): + raise ValueError('`sequences` must be a list of iterables. ' + 'Found non-iterable: ' + str(x)) + lengths.append(len(x)) + + num_samples = len(sequences) + if maxlen is None: + maxlen = np.max(lengths) + + # take the sample shape from the first non empty sequence + # checking for consistency in the main loop below. + sample_shape = tuple() + for s in sequences: + if len(s) > 0: # pylint: disable=g-explicit-length-test + sample_shape = np.asarray(s).shape[1:] + break + + x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype) + for idx, s in enumerate(sequences): + if not len(s): # pylint: disable=g-explicit-length-test + continue # empty list/array was found + if truncating == 'pre': + trunc = s[-maxlen:] # pylint: disable=invalid-unary-operand-type + elif truncating == 'post': + trunc = s[:maxlen] + else: + raise ValueError('Truncating type "%s" not understood' % truncating) + + # check `trunc` has expected shape + trunc = np.asarray(trunc, dtype=dtype) + if trunc.shape[1:] != sample_shape: + raise ValueError( + 'Shape of sample %s of sequence at position %s is different from ' + 'expected shape %s' + % (trunc.shape[1:], idx, sample_shape)) + + if padding == 'post': + x[idx, :len(trunc)] = trunc + elif padding == 'pre': + x[idx, -len(trunc):] = trunc + else: + raise ValueError('Padding type "%s" not understood' % padding) + return x + + +def get_data(opt): + """ + @param opt 配置选项 Config对象 + @return word2ix: dict,每个字对应的序号,形如u'月'->100 + @return ix2word: dict,每个序号对应的字,形如'100'->u'月' + @return data: numpy数组,每一行是一首诗对应的字的下标 + """ + if os.path.exists(opt.pickle_path): + data = np.load(opt.pickle_path, allow_pickle=True) + data, word2ix, ix2word = data['data'], data['word2ix'].item(), data['ix2word'].item() + return data, word2ix, ix2word + + # 如果没有处理好的二进制文件,则处理原始的json文件 + data = _parseRawData(opt.author, opt.constrain, opt.data_path, opt.category) + words = {_word for _sentence in data for _word in _sentence} + word2ix = {_word: _ix for _ix, _word in enumerate(words)} + word2ix[''] = len(word2ix) # 终止标识符 + word2ix[''] = len(word2ix) # 起始标识符 + word2ix[''] = len(word2ix) # 空格 + ix2word = {_ix: _word for _word, _ix in list(word2ix.items())} + + # 为每首诗歌加上起始符和终止符 + for i in range(len(data)): + data[i] = [""] + list(data[i]) + [""] + + # 将每首诗歌保存的内容由‘字’变成‘数’ + # 形如[春,江,花,月,夜]变成[1,2,3,4,5] + new_data = [[word2ix[_word] for _word in _sentence] + for _sentence in data] + + # 诗歌长度不够opt.maxlen的在前面补空格,超过的,删除末尾的 + pad_data = pad_sequences(new_data, + maxlen=opt.maxlen, + padding='pre', + truncating='post', + value=len(word2ix) - 1) + + # 保存成二进制文件 + np.savez_compressed(opt.pickle_path, + data=pad_data, + word2ix=word2ix, + ix2word=ix2word) + return pad_data, word2ix, ix2word diff --git a/handler.py b/handler.py new file mode 100644 index 0000000..bbe29cd --- /dev/null +++ b/handler.py @@ -0,0 +1,20 @@ +def handle(conf): + """ + 该方法是部署之后,其他人调用你的服务时候的处理方法。 + 
请按规范填写参数结构,这样我们就能替你自动生成配置文件,方便其他人的调用。 + 范例: + params['key'] = value # value_type: str # description: some description + value_type 可以选择:img, video, audio, str, int, float, [int], [str], [float] + 参数请放到params字典中,我们会自动解析该变量。 + """ + +# param1 = conf['param1'] # value_type: str # description: some description + start_words = conf['start_words'] #诗歌开始 + prefix_words = conf['prefix_words'] #诗歌语境 + max_gen_len = conf['max_gen_len'] #诗歌最大长度 + # add your code + if __name__ == '__main__': + result = main.gen(start_words,prefix_words,max_gen_len) + result = ''.json(result) + return {'你的诗歌': 'result'} + \ No newline at end of file diff --git a/job_logs/job-cpu-6243b431b38562dfcd9e9e0e.log b/job_logs/job-cpu-6243b431b38562dfcd9e9e0e.log new file mode 100644 index 0000000..b5c4472 --- /dev/null +++ b/job_logs/job-cpu-6243b431b38562dfcd9e9e0e.log @@ -0,0 +1,7 @@ +2022-03-30T01:36:55.737591372Z SYSTEM: Preparing env... +2022-03-30T01:36:57.042653704Z SYSTEM: Running... +2022-03-30T01:37:00.202120599Z Traceback (most recent call last): +2022-03-30T01:37:00.202162227Z File "main.py", line 4, in +2022-03-30T01:37:00.202308379Z from data import get_data +2022-03-30T01:37:00.202332663Z ModuleNotFoundError: No module named 'data' +2022-03-30T01:37:00.318630589Z SYSTEM: Finishing... diff --git a/job_logs/job-gpu-6243b389c415503e59e51369.log b/job_logs/job-gpu-6243b389c415503e59e51369.log new file mode 100644 index 0000000..8098e91 --- /dev/null +++ b/job_logs/job-gpu-6243b389c415503e59e51369.log @@ -0,0 +1,7 @@ +2022-03-30T01:34:08.011406308Z SYSTEM: Preparing env... +2022-03-30T01:34:08.572530654Z SYSTEM: Running... +2022-03-30T01:34:09.509088239Z Traceback (most recent call last): +2022-03-30T01:34:09.509150878Z File "main.py", line 4, in +2022-03-30T01:34:09.514459087Z from data import get_data +2022-03-30T01:34:09.514480993Z ModuleNotFoundError: No module named 'data' +2022-03-30T01:34:09.612018351Z SYSTEM: Finishing... 
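The handle() added above in handler.py and coding_here.ipynb will not return a poem as written: the `if __name__ == '__main__':` guard sits inside the function, so when the platform imports the module and calls handle(conf) the guard is false and the function returns None; `''.json(result)` is not a str method (''.join(result) is presumably what was meant, since main.gen returns a list of characters); the dict returns the literal string 'result' under a key that does not match the `result` output declared in app_spec.yml; and `main` is never imported. Below is a minimal corrected sketch of the intended behaviour (an assumption, not the deployed code), assuming main.py, data.py, model.py, tang.npz and checkpoints/tang_199.pth sit next to the handler; the ModuleNotFoundError in the job logs above points to data.py not being importable from the job's working directory.

import main  # main.gen(start_words, prefix_words, max_gen_len) returns a list of characters


def handle(conf):
    start_words = conf['start_words']    # value_type: str # description: 想要开始的第一个字
    prefix_words = conf['prefix_words']  # value_type: str # description: 想要的诗歌语境
    max_gen_len = conf['max_gen_len']    # value_type: int # description: 想要的诗歌长度
    # no __main__ guard here: handle() is invoked after import, so such a guard would never fire
    poem = main.gen(start_words=start_words,
                    prefix_words=prefix_words,
                    max_gen_len=max_gen_len)
    # join the list of characters and return it under the key declared in app_spec.yml
    return {'result': ''.join(poem)}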
diff --git a/main.py b/main.py new file mode 100644 index 0000000..1a39930 --- /dev/null +++ b/main.py @@ -0,0 +1,234 @@ +# coding:utf8 +import sys, os +import torch as t +from data import get_data +from model import PoetryModel +from torch import nn +from utils import Visualizer +import tqdm +from torchnet import meter +import ipdb + + +class Config(object): + data_path = 'tang.npz' # 诗歌的文本文件存放路径 + pickle_path = 'tang.npz' # 预处理好的二进制文件 + author = None # 只学习某位作者的诗歌 + constrain = None # 长度限制 + category = 'poet.tang' # 类别,唐诗还是宋诗歌(poet.song) + lr = 1e-3 + weight_decay = 1e-4 + use_gpu = False + epoch = 20 + batch_size = 128 + maxlen = 125 # 超过这个长度的之后字被丢弃,小于这个长度的在前面补空格 + plot_every = 20 # 每20个batch 可视化一次 + # use_env = True # 是否使用visodm + env = 'poetry' # visdom env + max_gen_len = 200 # 生成诗歌最长长度 + debug_file = '/tmp/debugp' + model_path = 'checkpoints/tang_199.pth' # 预训练模型路径 + prefix_words = '美女' # 不是诗歌的组成部分,用来控制生成诗歌的意境 + start_words = '雨' # 诗歌开始 + acrostic = False # 是否是藏头诗 + model_prefix = 'checkpoints/tang' # 模型保存路径 + + +opt = Config() + + +def generate(model, start_words, ix2word, word2ix, prefix_words=None): + """ + 给定几个词,根据这几个词接着生成一首完整的诗歌 + start_words:u'春江潮水连海平' + 比如start_words 为 春江潮水连海平,可以生成: + + """ + + results = list(start_words) + start_word_len = len(start_words) + # 手动设置第一个词为 + input = t.Tensor([word2ix['']]).view(1, 1).long() + if opt.use_gpu: input = input.cuda() + hidden = None + + if prefix_words: + for word in prefix_words: + output, hidden = model(input, hidden) + # print(input.data) + # input = input.data.new([word2ix[word]]).view(1, 1) + + for i in range(opt.max_gen_len): + output, hidden = model(input, hidden) + + if i < start_word_len: + w = results[i] + input = input.data.new([word2ix[w]]).view(1, 1) + else: + top_index = output.data[0].topk(1)[1][0].item() + w = ix2word[top_index] + results.append(w) + input = input.data.new([top_index]).view(1, 1) + if w == '': + del results[-1] + break + return results + + +def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None): + """ + 生成藏头诗 + start_words : u'深度学习' + 生成: + 深木通中岳,青苔半日脂。 + 度山分地险,逆浪到南巴。 + 学道兵犹毒,当时燕不移。 + 习根通古岸,开镜出清羸。 + """ + results = [] + start_word_len = len(start_words) + input = (t.Tensor([word2ix['']]).view(1, 1).long()) + if opt.use_gpu: input = input.cuda() + hidden = None + + index = 0 # 用来指示已经生成了多少句藏头诗 + # 上一个词 + pre_word = '' + + if prefix_words: + for word in prefix_words: + output, hidden = model(input, hidden) + input = (input.data.new([word2ix[word]])).view(1, 1) + + for i in range(opt.max_gen_len): + output, hidden = model(input, hidden) + top_index = output.data[0].topk(1)[1][0].item() + w = ix2word[top_index] + + if (pre_word in {u'。', u'!', ''}): + # 如果遇到句号,藏头的词送进去生成 + + if index == start_word_len: + # 如果生成的诗歌已经包含全部藏头的词,则结束 + break + else: + # 把藏头的词作为输入送入模型 + w = start_words[index] + index += 1 + input = (input.data.new([word2ix[w]])).view(1, 1) + else: + # 否则的话,把上一次预测是词作为下一个词输入 + input = (input.data.new([word2ix[w]])).view(1, 1) + results.append(w) + pre_word = w + return results + + +def train(**kwargs): + for k, v in kwargs.items(): + setattr(opt, k, v) + + opt.device=t.device('cuda') if opt.use_gpu else t.device('cpu') + device = opt.device + vis = Visualizer(env=opt.env) + + # 获取数据 + data, word2ix, ix2word = get_data(opt) + data = t.from_numpy(data) + dataloader = t.utils.data.DataLoader(data, + batch_size=opt.batch_size, + shuffle=True, + num_workers=1) + + # 模型定义 + model = PoetryModel(len(word2ix), 128, 256) + optimizer = t.optim.Adam(model.parameters(), lr=opt.lr) + criterion 
= nn.CrossEntropyLoss() + if opt.model_path: + model.load_state_dict(t.load(opt.model_path)) + model.to(device) + + loss_meter = meter.AverageValueMeter() + for epoch in range(opt.epoch): + loss_meter.reset() + for ii, data_ in tqdm.tqdm(enumerate(dataloader)): + + # 训练 + data_ = data_.long().transpose(1, 0).contiguous() + data_ = data_.to(device) + optimizer.zero_grad() + input_, target = data_[:-1, :], data_[1:, :] + output, _ = model(input_) + loss = criterion(output, target.view(-1)) + loss.backward() + optimizer.step() + + loss_meter.add(loss.item()) + + # 可视化 + if (1 + ii) % opt.plot_every == 0: + + if os.path.exists(opt.debug_file): + ipdb.set_trace() + + vis.plot('loss', loss_meter.value()[0]) + + # 诗歌原文 + poetrys = [[ix2word[_word] for _word in data_[:, _iii].tolist()] + for _iii in range(data_.shape[1])][:16] + vis.text('
'.join([''.join(poetry) for poetry in poetrys]), win=u'origin_poem')
+
+                gen_poetries = []
+                # 分别以这几个字作为诗歌的第一个字,生成8首诗
+                for word in list(u'春江花月夜凉如水'):
+                    gen_poetry = ''.join(generate(model, word, ix2word, word2ix))
+                    gen_poetries.append(gen_poetry)
+                vis.text('
'.join([''.join(poetry) for poetry in gen_poetries]), win=u'gen_poem') + + t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch)) + + +def gen(**kwargs): + """ + 提供命令行接口,用以生成相应的诗 + """ + + for k, v in kwargs.items(): + setattr(opt, k, v) + data, word2ix, ix2word = get_data(opt) + model = PoetryModel(len(word2ix), 128, 256); + map_location = lambda s, l: s + state_dict = t.load(opt.model_path, map_location=map_location) + model.load_state_dict(state_dict) + + if opt.use_gpu: + model.cuda() + + # python2和python3 字符串兼容 + if sys.version_info.major == 3: + if opt.start_words.isprintable(): + start_words = opt.start_words + prefix_words = opt.prefix_words if opt.prefix_words else None + else: + start_words = opt.start_words.encode('ascii', 'surrogateescape').decode('utf8') + prefix_words = opt.prefix_words.encode('ascii', 'surrogateescape').decode( + 'utf8') if opt.prefix_words else None + else: + start_words = opt.start_words.decode('utf8') + prefix_words = opt.prefix_words.decode('utf8') if opt.prefix_words else None + + start_words = start_words.replace(',', u',') \ + .replace('.', u'。') \ + .replace('?', u'?') + + gen_poetry = gen_acrostic if opt.acrostic else generate + result = gen_poetry(model, start_words, ix2word, word2ix, prefix_words) + return result + # print(''.join(result)) + + +if __name__ == '__main__': + # import fire + # + # fire.Fire() + gen() diff --git a/model.py b/model.py new file mode 100644 index 0000000..74cbb71 --- /dev/null +++ b/model.py @@ -0,0 +1,31 @@ +# coding:utf8 +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class PoetryModel(nn.Module): + def __init__(self, vocab_size, embedding_dim, hidden_dim): + super(PoetryModel, self).__init__() + self.hidden_dim = hidden_dim + self.embeddings = nn.Embedding(vocab_size, embedding_dim) + self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2) + self.linear1 = nn.Linear(self.hidden_dim, vocab_size) + + def forward(self, input, hidden=None): + seq_len, batch_size = input.size() + if hidden is None: + # h_0 = 0.01*torch.Tensor(2, batch_size, self.hidden_dim).normal_().cuda() + # c_0 = 0.01*torch.Tensor(2, batch_size, self.hidden_dim).normal_().cuda() + h_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float() + c_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float() + else: + h_0, c_0 = hidden + # size: (seq_len,batch_size,embeding_dim) + embeds = self.embeddings(input) + # output size: (seq_len,batch_size,hidden_dim) + output, hidden = self.lstm(embeds, (h_0, c_0)) + + # size: (seq_len*batch_size,vocab_size) + output = self.linear1(output.view(seq_len * batch_size, -1)) + return output, hidden diff --git a/project_requirements.txt b/project_requirements.txt new file mode 100644 index 0000000..5f2496a --- /dev/null +++ b/project_requirements.txt @@ -0,0 +1,143 @@ +sentencepiece==0.1.91 +toolz==0.11.1 +en-core-web-sm==https://files.momodel.cn/en_core_web_sm-2.3.0.tar.gz +botocore==1.19.25 +openpyxl==2.6.4 +google-auth-oauthlib==0.4.3 +jsonpatch==1.32 +paddlepaddle==2.0.1 +tensorboard-plugin-wit==1.8.0 +gym==0.17.2 +gensim==3.8.3 +dm-tree==0.1.6 +tqdm==4.46.1 +pyOpenSSL==20.0.1 +google-auth==1.27.1 +pytorch-transformers==1.2.0 +Cython==0.29.20 +boto3==1.16.25 +plac==1.1.3 +backports.entry-points-selectable==1.1.0 +sympy==1.6.2 +Augmentor==0.2.8 +copulas==0.3.3 +multipledispatch==0.6.0 +visdom==0.1.8.9 +pyasn1==0.4.8 +sacremoses==0.0.45 +cmake==3.21.1 +torchfile==0.1.0 +argon2-cffi==20.1.0 +certipy==0.1.3 +configparser==5.0.2 +jmespath==0.10.0 
+unification==0.2.2 +plotly==4.8.1 +opt-einsum==3.3.0 +word2vec==0.11.1 +pycparser==2.20 +metakernel==0.27.5 +defusedxml==0.7.1 +xlrd==1.2.0 +func-timeout==4.3.5 +ipdb==0.13.2 +smart-open==5.1.0 +transformers==4.1.1 +kanren==0.2.3 +graphviz==0.14 +nest-asyncio==1.5.1 +PyAudio==0.2.11 +jieba==0.42.1 +astunparse==1.6.3 +CairoSVG==2.5.2 +XlsxWriter==1.4.3 +tensorflow-model-optimization==0.4.1 +pyasn1-modules==0.2.8 +tinycss2==1.1.0 +mindspore==https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.0.0/MindSpore/cpu/ubuntu_x86/mindspore-1.0.0-cp37-cp37m-linux_x86_64.whl +tokenizers==0.9.4 +yellowbrick==1.1 +matplotlib-inline==0.1.2 +joblib==1.0.1 +pyglet==1.5.0 +tensorflow-addons==0.11.2 +Shapely==1.7.0 +minepy==1.2.4 +PyWavelets==1.1.1 +networkx==2.6.2 +mpmath==1.2.1 +pydot==1.4.1 +semantic-version==2.8.5 +cloudpickle==1.3.0 +cffi==1.14.6 +imgaug==0.4.0 +google-pasta==0.2.0 +jupyterlab-server==0.2.0 +asttokens==2.0.5 +srsly==1.0.5 +svgwrite==1.4.1 +pyrsistent==0.18.0 +attrs==19.3.0 +debugpy==1.4.1 +websocket-client==1.3.1 +dlib==19.22.0 +baytune==0.3.12 +cryptography==3.4.7 +tdqm==0.0.1 +torchnet==0.0.4 +oauthlib==3.1.0 +et-xmlfile==1.1.0 +jsonpointer==2.2 +jupyterlab-pygments==0.1.1 +zipp==3.4.1 +portpicker==1.3.9 +typing-extensions==3.7.4.3 +fire==0.4.0 +scikit-image==0.15.0 +click==8.0.1 +spacy==2.3.2 +pytorch-pretrained-bert==0.6.2 +cssselect2==0.4.1 +imageio==2.8.0 +platformdirs==2.1.0 +retrying==1.3.3 +torchvision==0.5.0+cpu +preshed==3.0.5 +torch==1.4.0+cpu +requests-oauthlib==1.3.0 +easydict==1.9 +install==1.3.4 +blis==0.4.1 +torchtext==0.6.0 +tensorflow-privacy==0.5.2 +wasabi==0.8.2 +cachetools==3.1.1 +tensorboardX==2.0 +minio==5.0.10 +filelock==3.0.12 +nltk==3.5 +imbalanced-learn==0.6.2 +cymem==2.0.5 +async-generator==1.10 +distlib==0.3.2 +murmurhash==1.0.5 +jdcal==1.4.1 +typeguard==2.12.1 +thinc==7.4.1 +regex==2021.8.3 +tensorflow-federated==0.17.0 +nbclient==0.5.0 +catalogue==1.0.0 +packaging==21.0 +tf-slim==1.1.0 +tensorflow-estimator==2.3.0 +importlib-metadata==3.7.2 +pygame==2.0.1 +s3transfer==0.3.3 +cairocffi==1.2.0 +rouge==1.0.0 +numpyencoder==0.3.0 +greenlet==1.1.1 +calysto==1.0.6 +rsa==4.7.2 +wrapt==1.12.1 diff --git a/tang.npz b/tang.npz new file mode 100644 index 0000000..7fe5887 Binary files /dev/null and b/tang.npz differ diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..409a34f --- /dev/null +++ b/utils.py @@ -0,0 +1,90 @@ +# coding:utf8 +import visdom +import torch as t +import time +import torchvision as tv +import numpy as np + + +class Visualizer(): + """ + 封装了visdom的基本操作,但是你仍然可以通过`self.vis.function` + 调用原生的visdom接口 + """ + + def __init__(self, env='default', **kwargs): + import visdom + self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs) + + # 画的第几个数,相当于横座标 + # 保存(’loss',23) 即loss的第23个点 + self.index = {} + self.log_text = '' + + def reinit(self, env='default', **kwargs): + """ + 修改visdom的配置 + """ + self.vis = visdom.Visdom(env=env,use_incoming_socket=False, **kwargs) + return self + + def plot_many(self, d): + """ + 一次plot多个 + @params d: dict (name,value) i.e. 
('loss',0.11) + """ + for k, v in d.items(): + self.plot(k, v) + + def img_many(self, d): + for k, v in d.items(): + self.img(k, v) + + def plot(self, name, y): + """ + self.plot('loss',1.00) + """ + x = self.index.get(name, 0) + self.vis.line(Y=np.array([y]), X=np.array([x]), + win=name, + opts=dict(title=name), + update=None if x == 0 else 'append' + ) + self.index[name] = x + 1 + + def img(self, name, img_): + """ + self.img('input_img',t.Tensor(64,64)) + """ + + if len(img_.size()) < 3: + img_ = img_.cpu().unsqueeze(0) + self.vis.image(img_.cpu(), + win=name, + opts=dict(title=name) + ) + + def img_grid_many(self, d): + for k, v in d.items(): + self.img_grid(k, v) + + def img_grid(self, name, input_3d): + """ + 一个batch的图片转成一个网格图,i.e. input(36,64,64) + 会变成 6*6 的网格图,每个格子大小64*64 + """ + self.img(name, tv.utils.make_grid( + input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0))) + + def log(self, info, win='log_text'): + """ + self.log({'loss':1,'lr':0.0001}) + """ + + self.log_text += ('[{time}] {info}
'.format(
+            time=time.strftime('%m%d_%H%M%S'),
+            info=info))
+        self.vis.text(self.log_text, win=win)
+
+    def __getattr__(self, name):
+        return getattr(self.vis, name)
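As a quick local check of the modules this commit adds, the same generation path that handle() should call can be exercised directly from the repository root, so that data, model and utils are importable (the missing `data` module is exactly what the failed jobs above tripped over). A minimal sketch using the Config defaults from main.py; the script name is hypothetical:

# smoke_test.py (hypothetical file name), run from the repository root
import main

# parameter names match app_spec.yml; values are the defaults in main.Config
chars = main.gen(start_words='雨', prefix_words='美女', max_gen_len=200)
print(''.join(chars))  # gen() returns a list of single characters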