SV08
4 years ago
| 16 | 16 | " - ```_README.md```*-----说明文档*\n", |
| 17 | 17 | " - ```app_spec.yml```*-----定义项目的输入输出,为部署服务*\n", |
| 18 | 18 | " - ```coding_here.ipynb```*-----输入并运行代码*" |
| 19 | ] | |
| 20 | }, | |
| 21 | { | |
| 22 | "cell_type": "code", | |
| 23 | "execution_count": null, | |
| 24 | "metadata": {}, | |
| 25 | "outputs": [], | |
| 26 | "source": [ | |
| 27 | "def handle(conf):\n", | |
| 28 | " \"\"\"\n", | |
| 29 | " 该方法是部署之后,其他人调用你的服务时候的处理方法。\n", | |
| 30 | " 请按规范填写参数结构,这样我们就能替你自动生成配置文件,方便其他人的调用。\n", | |
| 31 | " 范例:\n", | |
| 32 | " params['key'] = value # value_type: str # description: some description\n", | |
| 33 | " value_type 可以选择:img, video, audio, str, int, float, [int], [str], [float]\n", | |
| 34 | " 参数请放到params字典中,我们会自动解析该变量。\n", | |
| 35 | " \"\"\"\n", | |
| 36 | "\n", | |
| 37 | " param1 = conf['param1'] # value_type: str # description: some description\n", | |
| 38 | " # add your code\n", | |
| 39 | " return {'ret1': 'cat'}\n", | |
| 40 | " " | |
| 41 | ] | |
| 42 | }, | |
| 43 | { | |
| 44 | "cell_type": "code", | |
| 45 | "execution_count": null, | |
| 46 | "metadata": {}, | |
| 47 | "outputs": [], | |
| 48 | "source": [ | |
| 49 | "def handle(conf):\n", | |
| 50 | " \"\"\"\n", | |
| 51 | " 该方法是部署之后,其他人调用你的服务时候的处理方法。\n", | |
| 52 | " 请按规范填写参数结构,这样我们就能替你自动生成配置文件,方便其他人的调用。\n", | |
| 53 | " 范例:\n", | |
| 54 | " params['key'] = value # value_type: str # description: some description\n", | |
| 55 | " value_type 可以选择:img, video, audio, str, int, float, [int], [str], [float]\n", | |
| 56 | " 参数请放到params字典中,我们会自动解析该变量。\n", | |
| 57 | " \"\"\"\n", | |
| 58 | "\n", | |
| 59 | " param1 = conf['param1'] # value_type: str # description: some description\n", | |
| 60 | " # add your code\n", | |
| 61 | " return {'ret1': 'cat'}\n", | |
| 62 | " " | |
| 19 | 63 | ] |
| 20 | 64 | }, |
| 21 | 65 | { |
| 134 | 178 | "name": "python", |
| 135 | 179 | "nbconvert_exporter": "python", |
| 136 | 180 | "pygments_lexer": "ipython3", |
| 137 | "version": "3.5.2" | |
| 138 | }, | |
| 181 | "version": "3.7.5" | |
| 182 | }, | |
| 139 | 183 | "pycharm": { |
| 140 | 184 | "stem_cell": { |
| 141 | 185 | "cell_type": "raw", |
| 142 | "source": [], | |
| 143 | 186 | "metadata": { |
| 144 | 187 | "collapsed": false |
| 145 | } | |
| 146 | } | |
| 147 | } | |
| 188 | }, | |
| 189 | "source": [] | |
| 190 | } | |
| 191 | } | |
| 148 | 192 | }, |
| 149 | 193 | "nbformat": 4, |
| 150 | 194 | "nbformat_minor": 2 |
| 0 | input: | |
| 1 | start_words: | |
| 2 | name: start_words | |
| 3 | value_type: str | |
| 4 | description: 想要开始的第一个字 | |
| 5 | prefix_words: | |
| 6 | name: prefix_words | |
| 7 | value_type: str | |
| 8 | description: 想要的诗歌语境 | |
| 9 | max_gen_len: | |
| 10 | name: max_gen_len | |
| 11 | value_type: int | |
| 12 | description: 想要的诗歌长度 | |
| 13 | output: | |
| 14 | result: | |
| 15 | name: result | |
| 16 | value_type: str | |
| 17 | description: 你的诗歌 |
Binary diff not shown
| 0 | ||
| 1 | { | |
| 2 | "cells": [ | |
| 3 | { | |
| 4 | "cell_type": "code", | |
| 5 | "execution_count": null, | |
| 6 | "metadata": {}, | |
| 7 | "outputs": [], | |
| 8 | "source": [ | |
| 9 | "print('Hello Mo!')" | |
| 10 | ] | |
| 11 | } | |
| 12 | ], | |
| 13 | "metadata": { | |
| 14 | "kernelspec": { | |
| 15 | "display_name": "Python 3", | |
| 16 | "language": "python", | |
| 17 | "name": "python3" | |
| 18 | }, | |
| 19 | "language_info": { | |
| 20 | "codemirror_mode": { | |
| 21 | "name": "ipython", | |
| 22 | "version": 3 | |
| 23 | }, | |
| 24 | "file_extension": ".py", | |
| 25 | "mimetype": "text/x-python", | |
| 26 | "name": "python", | |
| 27 | "nbconvert_exporter": "python", | |
| 28 | "pygments_lexer": "ipython3", | |
| 29 | "version": "3.5.2" | |
| 30 | } | |
| 31 | }, | |
| 32 | "nbformat": 4, | |
| 33 | "nbformat_minor": 2 | |
| 34 | } | |
| 35 | ⏎ | |
| 0 | { | |
| 1 | "cells": [ | |
| 2 | { | |
| 3 | "cell_type": "code", | |
| 4 | "execution_count": null, | |
| 5 | "metadata": {}, | |
| 6 | "outputs": [], | |
| 7 | "source": [ | |
| 8 | "def handle(conf):\n", | |
| 9 | " \"\"\"\n", | |
| 10 | " 该方法是部署之后,其他人调用你的服务时候的处理方法。\n", | |
| 11 | " 请按规范填写参数结构,这样我们就能替你自动生成配置文件,方便其他人的调用。\n", | |
| 12 | " 范例:\n", | |
| 13 | " params['key'] = value # value_type: str # description: some description\n", | |
| 14 | " value_type 可以选择:img, video, audio, str, int, float, [int], [str], [float]\n", | |
| 15 | " 参数请放到params字典中,我们会自动解析该变量。\n", | |
| 16 | " \"\"\"\n", | |
| 17 | "\n", | |
| 18 | "# param1 = conf['param1'] # value_type: str # description: some description\n", | |
| 19 | " start_words = conf['start_words'] #诗歌开始\n", | |
| 20 | " prefix_words = conf['prefix_words'] #诗歌语境\n", | |
| 21 | " max_gen_len = conf['max_gen_len'] #诗歌最大长度\n", | |
| 22 | " # add your code\n", | |
| 23 | " if __name__ == '__main__':\n", | |
| 24 | " result = main.gen(start_words,prefix_words,max_gen_len)\n", | |
| 25 | " result = ''.json(result)\n", | |
| 26 | " return {'你的诗歌': 'result'}\n", | |
| 27 | " " | |
| 28 | ] | |
| 29 | }, | |
| 30 | { | |
| 31 | "cell_type": "code", | |
| 32 | "execution_count": null, | |
| 33 | "metadata": {}, | |
| 34 | "outputs": [], | |
| 35 | "source": [] | |
| 36 | } | |
| 37 | ], | |
| 38 | "metadata": { | |
| 39 | "kernelspec": { | |
| 40 | "display_name": "Python 3", | |
| 41 | "language": "python", | |
| 42 | "name": "python3" | |
| 43 | }, | |
| 44 | "language_info": { | |
| 45 | "codemirror_mode": { | |
| 46 | "name": "ipython", | |
| 47 | "version": 3 | |
| 48 | }, | |
| 49 | "file_extension": ".py", | |
| 50 | "mimetype": "text/x-python", | |
| 51 | "name": "python", | |
| 52 | "nbconvert_exporter": "python", | |
| 53 | "pygments_lexer": "ipython3", | |
| 54 | "version": "3.7.5" | |
| 55 | } | |
| 56 | }, | |
| 57 | "nbformat": 4, | |
| 58 | "nbformat_minor": 2 | |
| 59 | } |
| 0 | # coding:utf-8 | |
| 1 | import sys | |
| 2 | import os | |
| 3 | import json | |
| 4 | import re | |
| 5 | import numpy as np | |
| 6 | ||
| 7 | ||
def _parseRawData(author=None, constrain=None, src='./chinese-poetry/json/simplified', category="poet.tang"):
    """
    Parse the raw chinese-poetry json files and return the poem texts.

    Adapted from
    https://github.com/justdark/pytorch-poetry-gen/blob/master/dataHandler.py

    @param author: only keep poems by this author (None = keep all)
    @param constrain: only keep poems whose sentences all have this length
    @param src: directory holding the json files
    @param category: file-name prefix, "poet.tang" or "poet.song"

    @return data: list of poem strings, e.g.
        ['床前明月光,疑是地上霜,举头望明月,低头思故乡。',
         '一去二三里,烟村四五家,亭台六七座,八九十支花。',
         ...]
    """

    def sentenceParse(para):
        # Strip editorial annotations, e.g. para like
        # "-181-村橋路不端,數里就迴湍。積壤連涇脉,高林上笋竿。早嘗甘蔗淡,
        # 生摘琵琶酸。(「琵琶」,嚴壽澄校《張祜詩集》云:疑「枇杷」之誤。)
        # 好是去塵俗,煙花長一欄。"
        # NOTE(review): the first pattern uses full-width Chinese brackets;
        # an ASCII "(.*)" would be a plain group matching the whole poem and
        # delete everything, so the full-width form is assumed intentional.
        result, _ = re.subn(u"(.*)", "", para)
        result, _ = re.subn(u"{.*}", "", result)
        result, _ = re.subn(u"《.*》", "", result)
        result, _ = re.subn(u"[\]\[]", "", result)
        # Drop page numbers / digits left over from the source markup.
        r = "".join(ch for ch in result if ch not in set('0123456789-'))
        # Collapse doubled full stops produced by the deletions above.
        r, _ = re.subn(u"。。", u"。", r)
        return r

    def handleJson(file):
        # Parse one json file and return the cleaned poems it contains.
        rst = []
        with open(file, encoding='utf-8') as f:  # close the handle (was leaked)
            data = json.load(f)
        for poetry in data:
            if author is not None and poetry.get("author") != author:
                continue
            paragraphs = poetry.get("paragraphs")
            # Enforce the optional per-sentence length constraint.
            flag = False
            for s in paragraphs:
                for tr in re.split(u"[,!。]", s):
                    if constrain is not None and len(tr) != constrain and len(tr) != 0:
                        flag = True
                        break
                if flag:
                    break
            if flag:
                continue
            pdata = sentenceParse("".join(paragraphs))
            if pdata != "":
                rst.append(pdata)
        return rst

    data = []
    for filename in os.listdir(src):
        if filename.startswith(category):
            # os.path.join fixes the missing separator in the original
            # `src + filename` concatenation (src has no trailing slash).
            data.extend(handleJson(os.path.join(src, filename)))
    return data
| 72 | ||
| 73 | ||
def pad_sequences(sequences,
                  maxlen=None,
                  dtype='int32',
                  padding='pre',
                  truncating='post',
                  value=0.):
    """
    Pad each sequence to a common length (lifted from keras).

    Every sequence is padded out to the length of the longest one, or to
    `maxlen` when given; sequences longer than `maxlen` are truncated.

    Arguments:
        sequences: list of lists, each element a sequence
        maxlen: int, maximum length (None = length of the longest sequence)
        dtype: dtype of the resulting array
        padding: 'pre' or 'post' — pad before or after each sequence
        truncating: 'pre' or 'post' — drop the head or the tail of
            sequences longer than maxlen
        value: float, fill value used for padding
    Returns:
        numpy array of shape (number_of_sequences, maxlen)
    Raises:
        ValueError: for invalid `truncating`/`padding` values, or when a
            truncated sample's trailing shape disagrees with the others.
    """
    if not hasattr(sequences, '__len__'):
        raise ValueError('`sequences` must be iterable.')
    lengths = []
    for seq in sequences:
        if not hasattr(seq, '__len__'):
            raise ValueError('`sequences` must be a list of iterables. '
                             'Found non-iterable: ' + str(seq))
        lengths.append(len(seq))

    num_samples = len(sequences)
    if maxlen is None:
        maxlen = np.max(lengths)

    # Trailing shape is taken from the first non-empty sequence and checked
    # against every other sample in the main loop below.
    sample_shape = next(
        (np.asarray(seq).shape[1:] for seq in sequences if len(seq) > 0),
        tuple())

    out = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
    for row, seq in enumerate(sequences):
        if not len(seq):
            continue  # empty list/array: leave the row fully padded
        if truncating == 'pre':
            kept = seq[-maxlen:]
        elif truncating == 'post':
            kept = seq[:maxlen]
        else:
            raise ValueError('Truncating type "%s" not understood' % truncating)

        # Verify the truncated sample still matches the expected shape.
        kept = np.asarray(kept, dtype=dtype)
        if kept.shape[1:] != sample_shape:
            raise ValueError(
                'Shape of sample %s of sequence at position %s is different from '
                'expected shape %s'
                % (kept.shape[1:], row, sample_shape))

        if padding == 'post':
            out[row, :len(kept)] = kept
        elif padding == 'pre':
            out[row, -len(kept):] = kept
        else:
            raise ValueError('Padding type "%s" not understood' % padding)
    return out
| 149 | ||
| 150 | ||
def get_data(opt):
    """
    Load (or build and cache) the poem dataset.

    @param opt: Config object (uses pickle_path, data_path, author,
                constrain, category, maxlen)
    @return data: numpy array, one row of character indices per poem
    @return word2ix: dict mapping character -> index, e.g. u'月' -> 100
    @return ix2word: dict mapping index -> character, e.g. 100 -> u'月'
    """
    # Fast path: a preprocessed binary archive already exists.
    if os.path.exists(opt.pickle_path):
        cached = np.load(opt.pickle_path, allow_pickle=True)
        return cached['data'], cached['word2ix'].item(), cached['ix2word'].item()

    # Otherwise parse the raw json corpus and build the vocabulary.
    poems = _parseRawData(opt.author, opt.constrain, opt.data_path, opt.category)
    vocab = {ch for poem in poems for ch in poem}
    word2ix = {ch: ix for ix, ch in enumerate(vocab)}
    for marker in ('<EOP>', '<START>', '</s>'):  # stop / start / padding markers
        word2ix[marker] = len(word2ix)
    ix2word = {ix: ch for ch, ix in word2ix.items()}

    # Wrap every poem with start/stop markers, then map characters to
    # indices, e.g. [春,江,花,月,夜] -> [1,2,3,4,5].
    wrapped = [["<START>"] + list(poem) + ["<EOP>"] for poem in poems]
    encoded = [[word2ix[ch] for ch in poem] for poem in wrapped]

    # Left-pad poems shorter than opt.maxlen with '</s>' and clip the tail
    # of longer ones.
    pad_data = pad_sequences(encoded,
                             maxlen=opt.maxlen,
                             padding='pre',
                             truncating='post',
                             value=len(word2ix) - 1)

    # Cache as a binary archive for the fast path above.
    np.savez_compressed(opt.pickle_path,
                        data=pad_data,
                        word2ix=word2ix,
                        ix2word=ix2word)
    return pad_data, word2ix, ix2word
def handle(conf):
    """
    Deployment entry point: called by the platform when another user
    invokes this service.

    Parameter comments follow the platform convention so the config file
    can be generated automatically, e.g.:
        params['key'] = value  # value_type: str # description: some description
    value_type choices: img, video, audio, str, int, float, [int], [str], [float]

    @param conf: dict with keys 'start_words', 'prefix_words', 'max_gen_len'
    @return: dict with key 'result' holding the generated poem as a string
             (matches the output section declared in app_spec.yml)
    """
    start_words = conf['start_words']    # value_type: str # description: first character(s) of the poem
    prefix_words = conf['prefix_words']  # value_type: str # description: desired mood/context of the poem
    max_gen_len = conf['max_gen_len']    # value_type: int # description: desired poem length

    # Bug fixes versus the original:
    # - the body was wrapped in `if __name__ == '__main__':`, which is always
    #   false inside an imported module, so the service returned None;
    # - main.gen accepts keyword arguments only (it forwards **kwargs onto
    #   the Config object), so positional arguments raised TypeError;
    # - `''.json(result)` is an AttributeError — `''.join(result)` was meant;
    # - the return value was the literal string 'result' under a key that
    #   does not match app_spec.yml.
    result = main.gen(start_words=start_words,
                      prefix_words=prefix_words,
                      max_gen_len=max_gen_len)
    # gen() returns a list of single-character strings; join into one poem.
    return {'result': ''.join(result)}
| 19 | ⏎ |
| 0 | 2022-03-30T01:36:55.737591372Z SYSTEM: Preparing env... | |
| 1 | 2022-03-30T01:36:57.042653704Z SYSTEM: Running... | |
| 2 | 2022-03-30T01:37:00.202120599Z Traceback (most recent call last): | |
| 3 | 2022-03-30T01:37:00.202162227Z File "main.py", line 4, in <module> | |
| 4 | 2022-03-30T01:37:00.202308379Z from data import get_data | |
| 5 | 2022-03-30T01:37:00.202332663Z ModuleNotFoundError: No module named 'data' | |
| 6 | 2022-03-30T01:37:00.318630589Z SYSTEM: Finishing... |
| 0 | 2022-03-30T01:34:08.011406308Z SYSTEM: Preparing env... | |
| 1 | 2022-03-30T01:34:08.572530654Z SYSTEM: Running... | |
| 2 | 2022-03-30T01:34:09.509088239Z Traceback (most recent call last): | |
| 3 | 2022-03-30T01:34:09.509150878Z File "main.py", line 4, in <module> | |
| 4 | 2022-03-30T01:34:09.514459087Z from data import get_data | |
| 5 | 2022-03-30T01:34:09.514480993Z ModuleNotFoundError: No module named 'data' | |
| 6 | 2022-03-30T01:34:09.612018351Z SYSTEM: Finishing... |
| 0 | # coding:utf8 | |
| 1 | import sys, os | |
| 2 | import torch as t | |
| 3 | from data import get_data | |
| 4 | from model import PoetryModel | |
| 5 | from torch import nn | |
| 6 | from utils import Visualizer | |
| 7 | import tqdm | |
| 8 | from torchnet import meter | |
| 9 | import ipdb | |
| 10 | ||
| 11 | ||
class Config(object):
    # Training / generation hyper-parameters and file paths.
    data_path = 'tang.npz'  # path of the raw poem text files
    pickle_path = 'tang.npz'  # preprocessed binary archive
    author = None  # learn the poems of a single author only (None = all)
    constrain = None  # per-sentence length constraint (None = no constraint)
    category = 'poet.tang'  # corpus category: Tang poems, or Song (poet.song)
    lr = 1e-3
    weight_decay = 1e-4
    use_gpu = False
    epoch = 20
    batch_size = 128
    maxlen = 125  # longer poems have their tail clipped; shorter ones are left-padded
    plot_every = 20  # visualize once every 20 batches
    # use_env = True  # whether to use visdom
    env = 'poetry'  # visdom env
    max_gen_len = 200  # maximum length of a generated poem
    debug_file = '/tmp/debugp'
    model_path = 'checkpoints/tang_199.pth'  # pretrained model path
    prefix_words = '美女'  # not part of the poem; only steers the mood/context
    start_words = '雨'  # first character(s) of the poem
    acrostic = False  # whether to generate an acrostic poem
    model_prefix = 'checkpoints/tang'  # model save path prefix


opt = Config()  # module-level configuration shared by all functions below
| 37 | ||
| 38 | ||
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
    """
    Given a few opening characters, continue them into a full poem.

    e.g. start_words u'春江潮水连海平' is extended character by character.

    @param model: PoetryModel used for next-character prediction
    @param start_words: the opening characters of the poem
    @param ix2word / word2ix: vocabulary mappings from get_data
    @param prefix_words: optional characters that only set the mood; they
                         are fed through the network but never emitted
    @return: list of single-character strings forming the poem
    """

    results = list(start_words)
    start_word_len = len(start_words)
    # Seed the recurrence with the <START> marker.
    input = t.Tensor([word2ix['<START>']]).view(1, 1).long()
    if opt.use_gpu: input = input.cuda()
    hidden = None

    if prefix_words:
        # Feed each context character through the model so its influence
        # accumulates in `hidden`.
        # Bug fix: the input update below was commented out in the original,
        # so the context characters never actually reached the network —
        # this mirrors the working loop in gen_acrostic.
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = input.data.new([word2ix[word]]).view(1, 1)

    for i in range(opt.max_gen_len):
        output, hidden = model(input, hidden)

        if i < start_word_len:
            # Still inside the user-given opening: force-feed those chars.
            w = results[i]
            input = input.data.new([word2ix[w]]).view(1, 1)
        else:
            # Greedy decoding: take the most likely next character.
            top_index = output.data[0].topk(1)[1][0].item()
            w = ix2word[top_index]
            results.append(w)
            input = input.data.new([top_index]).view(1, 1)
        if w == '<EOP>':
            del results[-1]  # the end marker is not part of the poem
            break
    return results
| 75 | ||
| 76 | ||
def gen_acrostic(model, start_words, ix2word, word2ix, prefix_words=None):
    """
    Generate an acrostic poem: each line starts with one char of start_words.

    start_words : u'深度学习'
    produces, for example:
    深木通中岳,青苔半日脂。
    度山分地险,逆浪到南巴。
    学道兵犹毒,当时燕不移。
    习根通古岸,开镜出清羸。

    @param model: PoetryModel used for next-character prediction
    @param start_words: characters hidden at the head of each line
    @param ix2word / word2ix: vocabulary mappings from get_data
    @param prefix_words: optional context characters (steer the mood only)
    @return: list of single-character strings forming the poem
    """
    results = []
    start_word_len = len(start_words)
    # Seed the recurrence with the <START> marker.
    input = (t.Tensor([word2ix['<START>']]).view(1, 1).long())
    if opt.use_gpu: input = input.cuda()
    hidden = None

    index = 0  # how many acrostic head characters have been emitted so far
    # previously emitted character
    pre_word = '<START>'

    if prefix_words:
        # Feed the context characters through the model to shape `hidden`.
        for word in prefix_words:
            output, hidden = model(input, hidden)
            input = (input.data.new([word2ix[word]])).view(1, 1)

    for i in range(opt.max_gen_len):
        output, hidden = model(input, hidden)
        top_index = output.data[0].topk(1)[1][0].item()
        w = ix2word[top_index]

        if (pre_word in {u'。', u'!', '<START>'}):
            # At a sentence boundary: inject the next acrostic head char.

            if index == start_word_len:
                # All head characters consumed -> the poem is complete.
                break
            else:
                # Feed the head character into the model as the next input.
                w = start_words[index]
                index += 1
                input = (input.data.new([word2ix[w]])).view(1, 1)
        else:
            # Otherwise feed the previously predicted character back in.
            input = (input.data.new([word2ix[w]])).view(1, 1)
        results.append(w)
        pre_word = w
    return results
| 124 | ||
| 125 | ||
def train(**kwargs):
    """
    Train the poetry language model.

    Keyword arguments override attributes of the global `opt` Config,
    e.g. train(epoch=10, use_gpu=True). Saves one checkpoint per epoch
    to '<opt.model_prefix>_<epoch>.pth'.
    """
    for k, v in kwargs.items():
        setattr(opt, k, v)

    opt.device=t.device('cuda') if opt.use_gpu else t.device('cpu')
    device = opt.device
    vis = Visualizer(env=opt.env)

    # Load the dataset.
    data, word2ix, ix2word = get_data(opt)
    data = t.from_numpy(data)
    dataloader = t.utils.data.DataLoader(data,
                                         batch_size=opt.batch_size,
                                         shuffle=True,
                                         num_workers=1)

    # Model definition.
    model = PoetryModel(len(word2ix), 128, 256)
    optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.CrossEntropyLoss()
    if opt.model_path:
        model.load_state_dict(t.load(opt.model_path))
    model.to(device)

    loss_meter = meter.AverageValueMeter()
    for epoch in range(opt.epoch):
        loss_meter.reset()
        for ii, data_ in tqdm.tqdm(enumerate(dataloader)):

            # Training step. Rows are poems; transpose to (seq_len, batch).
            data_ = data_.long().transpose(1, 0).contiguous()
            data_ = data_.to(device)
            optimizer.zero_grad()
            # Input drops the last char, target drops the first: the model
            # learns next-character prediction.
            input_, target = data_[:-1, :], data_[1:, :]
            output, _ = model(input_)
            loss = criterion(output, target.view(-1))
            loss.backward()
            optimizer.step()

            loss_meter.add(loss.item())

            # Visualization every opt.plot_every batches.
            if (1 + ii) % opt.plot_every == 0:

                # Drop into the debugger when the debug marker file exists.
                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()

                vis.plot('loss', loss_meter.value()[0])

                # Show the original poems of the current batch.
                poetrys = [[ix2word[_word] for _word in data_[:, _iii].tolist()]
                           for _iii in range(data_.shape[1])][:16]
                vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]), win=u'origin_poem')

                gen_poetries = []
                # Generate 8 sample poems, each opening with one of these chars.
                for word in list(u'春江花月夜凉如水'):
                    gen_poetry = ''.join(generate(model, word, ix2word, word2ix))
                    gen_poetries.append(gen_poetry)
                vis.text('</br>'.join([''.join(poetry) for poetry in gen_poetries]), win=u'gen_poem')

        t.save(model.state_dict(), '%s_%s.pth' % (opt.model_prefix, epoch))
| 188 | ||
| 189 | ||
def gen(**kwargs):
    """
    Command-line / service interface for generating a poem.

    Keyword arguments override attributes of the global `opt` Config
    (start_words, prefix_words, max_gen_len, acrostic, model_path, ...).

    @return: list of single-character strings forming the poem
    """

    for k, v in kwargs.items():
        setattr(opt, k, v)
    data, word2ix, ix2word = get_data(opt)
    model = PoetryModel(len(word2ix), 128, 256);
    # Map tensors saved on GPU back onto the CPU when loading.
    map_location = lambda s, l: s
    state_dict = t.load(opt.model_path, map_location=map_location)
    model.load_state_dict(state_dict)

    if opt.use_gpu:
        model.cuda()

    # python2 / python3 string compatibility
    if sys.version_info.major == 3:
        if opt.start_words.isprintable():
            start_words = opt.start_words
            prefix_words = opt.prefix_words if opt.prefix_words else None
        else:
            # Undo surrogate-escaped bytes coming from the command line.
            start_words = opt.start_words.encode('ascii', 'surrogateescape').decode('utf8')
            prefix_words = opt.prefix_words.encode('ascii', 'surrogateescape').decode(
                'utf8') if opt.prefix_words else None
    else:
        start_words = opt.start_words.decode('utf8')
        prefix_words = opt.prefix_words.decode('utf8') if opt.prefix_words else None

    # Normalize ASCII punctuation to the full-width forms used in the corpus.
    start_words = start_words.replace(',', u',') \
        .replace('.', u'。') \
        .replace('?', u'?')

    # Acrostic and plain generation share the same calling convention.
    gen_poetry = gen_acrostic if opt.acrostic else generate
    result = gen_poetry(model, start_words, ix2word, word2ix, prefix_words)
    return result
    # print(''.join(result))
| 227 | ||
| 228 | ||
if __name__ == '__main__':
    # Script entry point; the fire-based CLI is disabled in this deployment.
    # import fire
    #
    # fire.Fire()
    gen()
| 0 | # coding:utf8 | |
| 1 | import torch | |
| 2 | import torch.nn as nn | |
| 3 | import torch.nn.functional as F | |
| 4 | ||
| 5 | ||
class PoetryModel(nn.Module):
    """Character-level language model: embedding -> 2-layer LSTM -> linear."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        """
        @param vocab_size: number of distinct characters (incl. markers)
        @param embedding_dim: size of the character embedding vectors
        @param hidden_dim: size of each LSTM layer's hidden state
        """
        super(PoetryModel, self).__init__()
        self.hidden_dim = hidden_dim
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, self.hidden_dim, num_layers=2)
        self.linear1 = nn.Linear(self.hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        """
        @param input: LongTensor of shape (seq_len, batch_size)
        @param hidden: optional (h_0, c_0), each (2, batch_size, hidden_dim);
                       zero-initialized when omitted
        @return: (logits of shape (seq_len * batch_size, vocab_size),
                  new (h, c) hidden state)
        """
        seq_len, batch_size = input.size()
        if hidden is None:
            # Fresh zero state for both LSTM layers, allocated on the same
            # device as the input tensor.
            h_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
            c_0 = input.data.new(2, batch_size, self.hidden_dim).fill_(0).float()
        else:
            h_0, c_0 = hidden

        embedded = self.embeddings(input)                    # (seq_len, batch, embedding_dim)
        lstm_out, hidden = self.lstm(embedded, (h_0, c_0))   # (seq_len, batch, hidden_dim)
        # Flatten time and batch so the classifier scores every position.
        logits = self.linear1(lstm_out.view(seq_len * batch_size, -1))
        return logits, hidden
| 0 | sentencepiece==0.1.91 | |
| 1 | toolz==0.11.1 | |
| 2 | en-core-web-sm==https://files.momodel.cn/en_core_web_sm-2.3.0.tar.gz | |
| 3 | botocore==1.19.25 | |
| 4 | openpyxl==2.6.4 | |
| 5 | google-auth-oauthlib==0.4.3 | |
| 6 | jsonpatch==1.32 | |
| 7 | paddlepaddle==2.0.1 | |
| 8 | tensorboard-plugin-wit==1.8.0 | |
| 9 | gym==0.17.2 | |
| 10 | gensim==3.8.3 | |
| 11 | dm-tree==0.1.6 | |
| 12 | tqdm==4.46.1 | |
| 13 | pyOpenSSL==20.0.1 | |
| 14 | google-auth==1.27.1 | |
| 15 | pytorch-transformers==1.2.0 | |
| 16 | Cython==0.29.20 | |
| 17 | boto3==1.16.25 | |
| 18 | plac==1.1.3 | |
| 19 | backports.entry-points-selectable==1.1.0 | |
| 20 | sympy==1.6.2 | |
| 21 | Augmentor==0.2.8 | |
| 22 | copulas==0.3.3 | |
| 23 | multipledispatch==0.6.0 | |
| 24 | visdom==0.1.8.9 | |
| 25 | pyasn1==0.4.8 | |
| 26 | sacremoses==0.0.45 | |
| 27 | cmake==3.21.1 | |
| 28 | torchfile==0.1.0 | |
| 29 | argon2-cffi==20.1.0 | |
| 30 | certipy==0.1.3 | |
| 31 | configparser==5.0.2 | |
| 32 | jmespath==0.10.0 | |
| 33 | unification==0.2.2 | |
| 34 | plotly==4.8.1 | |
| 35 | opt-einsum==3.3.0 | |
| 36 | word2vec==0.11.1 | |
| 37 | pycparser==2.20 | |
| 38 | metakernel==0.27.5 | |
| 39 | defusedxml==0.7.1 | |
| 40 | xlrd==1.2.0 | |
| 41 | func-timeout==4.3.5 | |
| 42 | ipdb==0.13.2 | |
| 43 | smart-open==5.1.0 | |
| 44 | transformers==4.1.1 | |
| 45 | kanren==0.2.3 | |
| 46 | graphviz==0.14 | |
| 47 | nest-asyncio==1.5.1 | |
| 48 | PyAudio==0.2.11 | |
| 49 | jieba==0.42.1 | |
| 50 | astunparse==1.6.3 | |
| 51 | CairoSVG==2.5.2 | |
| 52 | XlsxWriter==1.4.3 | |
| 53 | tensorflow-model-optimization==0.4.1 | |
| 54 | pyasn1-modules==0.2.8 | |
| 55 | tinycss2==1.1.0 | |
| 56 | mindspore==https://ms-release.obs.cn-north-4.myhuaweicloud.com/1.0.0/MindSpore/cpu/ubuntu_x86/mindspore-1.0.0-cp37-cp37m-linux_x86_64.whl | |
| 57 | tokenizers==0.9.4 | |
| 58 | yellowbrick==1.1 | |
| 59 | matplotlib-inline==0.1.2 | |
| 60 | joblib==1.0.1 | |
| 61 | pyglet==1.5.0 | |
| 62 | tensorflow-addons==0.11.2 | |
| 63 | Shapely==1.7.0 | |
| 64 | minepy==1.2.4 | |
| 65 | PyWavelets==1.1.1 | |
| 66 | networkx==2.6.2 | |
| 67 | mpmath==1.2.1 | |
| 68 | pydot==1.4.1 | |
| 69 | semantic-version==2.8.5 | |
| 70 | cloudpickle==1.3.0 | |
| 71 | cffi==1.14.6 | |
| 72 | imgaug==0.4.0 | |
| 73 | google-pasta==0.2.0 | |
| 74 | jupyterlab-server==0.2.0 | |
| 75 | asttokens==2.0.5 | |
| 76 | srsly==1.0.5 | |
| 77 | svgwrite==1.4.1 | |
| 78 | pyrsistent==0.18.0 | |
| 79 | attrs==19.3.0 | |
| 80 | debugpy==1.4.1 | |
| 81 | websocket-client==1.3.1 | |
| 82 | dlib==19.22.0 | |
| 83 | baytune==0.3.12 | |
| 84 | cryptography==3.4.7 | |
| 85 | tdqm==0.0.1 | |
| 86 | torchnet==0.0.4 | |
| 87 | oauthlib==3.1.0 | |
| 88 | et-xmlfile==1.1.0 | |
| 89 | jsonpointer==2.2 | |
| 90 | jupyterlab-pygments==0.1.1 | |
| 91 | zipp==3.4.1 | |
| 92 | portpicker==1.3.9 | |
| 93 | typing-extensions==3.7.4.3 | |
| 94 | fire==0.4.0 | |
| 95 | scikit-image==0.15.0 | |
| 96 | click==8.0.1 | |
| 97 | spacy==2.3.2 | |
| 98 | pytorch-pretrained-bert==0.6.2 | |
| 99 | cssselect2==0.4.1 | |
| 100 | imageio==2.8.0 | |
| 101 | platformdirs==2.1.0 | |
| 102 | retrying==1.3.3 | |
| 103 | torchvision==0.5.0+cpu | |
| 104 | preshed==3.0.5 | |
| 105 | torch==1.4.0+cpu | |
| 106 | requests-oauthlib==1.3.0 | |
| 107 | easydict==1.9 | |
| 108 | install==1.3.4 | |
| 109 | blis==0.4.1 | |
| 110 | torchtext==0.6.0 | |
| 111 | tensorflow-privacy==0.5.2 | |
| 112 | wasabi==0.8.2 | |
| 113 | cachetools==3.1.1 | |
| 114 | tensorboardX==2.0 | |
| 115 | minio==5.0.10 | |
| 116 | filelock==3.0.12 | |
| 117 | nltk==3.5 | |
| 118 | imbalanced-learn==0.6.2 | |
| 119 | cymem==2.0.5 | |
| 120 | async-generator==1.10 | |
| 121 | distlib==0.3.2 | |
| 122 | murmurhash==1.0.5 | |
| 123 | jdcal==1.4.1 | |
| 124 | typeguard==2.12.1 | |
| 125 | thinc==7.4.1 | |
| 126 | regex==2021.8.3 | |
| 127 | tensorflow-federated==0.17.0 | |
| 128 | nbclient==0.5.0 | |
| 129 | catalogue==1.0.0 | |
| 130 | packaging==21.0 | |
| 131 | tf-slim==1.1.0 | |
| 132 | tensorflow-estimator==2.3.0 | |
| 133 | importlib-metadata==3.7.2 | |
| 134 | pygame==2.0.1 | |
| 135 | s3transfer==0.3.3 | |
| 136 | cairocffi==1.2.0 | |
| 137 | rouge==1.0.0 | |
| 138 | numpyencoder==0.3.0 | |
| 139 | greenlet==1.1.1 | |
| 140 | calysto==1.0.6 | |
| 141 | rsa==4.7.2 | |
| 142 | wrapt==1.12.1 |
| 0 | # coding:utf8 | |
| 1 | import visdom | |
| 2 | import torch as t | |
| 3 | import time | |
| 4 | import torchvision as tv | |
| 5 | import numpy as np | |
| 6 | ||
| 7 | ||
class Visualizer():
    """
    Thin convenience wrapper around visdom.  Anything not defined here is
    forwarded to the underlying `visdom.Visdom` instance via __getattr__,
    so the raw API remains reachable as `self.vis.function`.
    """

    def __init__(self, env='default', **kwargs):
        import visdom
        self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)

        # Per-plot x-coordinate counters: {'loss': 23} means the 23rd
        # 'loss' point has been drawn.
        self.index = {}
        self.log_text = ''

    def reinit(self, env='default', **kwargs):
        """Re-create the visdom connection with a new configuration."""
        self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)
        return self

    def plot_many(self, d):
        """Plot several (name, value) pairs at once, e.g. {'loss': 0.11}."""
        for name, value in d.items():
            self.plot(name, value)

    def img_many(self, d):
        """Show several named images at once."""
        for name, image in d.items():
            self.img(name, image)

    def plot(self, name, y):
        """Append one point to the named line plot: self.plot('loss', 1.00)."""
        step = self.index.get(name, 0)
        self.vis.line(Y=np.array([y]), X=np.array([step]),
                      win=name,
                      opts=dict(title=name),
                      # The very first point creates the window; later
                      # points are appended to it.
                      update=None if step == 0 else 'append')
        self.index[name] = step + 1

    def img(self, name, img_):
        """Show one image tensor: self.img('input_img', t.Tensor(64, 64))."""

        if len(img_.size()) < 3:
            img_ = img_.cpu().unsqueeze(0)  # add a channel axis to 2-D tensors
        self.vis.image(img_.cpu(),
                       win=name,
                       opts=dict(title=name))

    def img_grid_many(self, d):
        """Show several named image batches, one grid each."""
        for name, batch in d.items():
            self.img_grid(name, batch)

    def img_grid(self, name, input_3d):
        """
        Render a batch of images as one grid: input of shape (36, 64, 64)
        becomes a 6x6 grid of 64x64 cells.
        """
        grid = tv.utils.make_grid(
            input_3d.cpu()[0].unsqueeze(1).clamp(max=1, min=0))
        self.img(name, grid)

    def log(self, info, win='log_text'):
        """Append a timestamped entry to the log window: self.log({'loss': 1})."""

        entry = '[{time}] {info} <br>'.format(
            time=time.strftime('%m%d_%H%M%S'),
            info=info)
        self.log_text += entry
        self.vis.text(self.log_text, win=win)

    def __getattr__(self, name):
        # Fall back to the wrapped visdom instance for anything else.
        return getattr(self.vis, name)