master
/ poem_generate.py

poem_generate.py @master raw · history · blame

# -*- coding: utf-8 -*-
import time
import math
import random

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import transformers
from tensorboardX import SummaryWriter  # 这个库提供Tensorboard的日志功能
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
from peft import TaskType, LoraConfig, get_peft_model

# Pretrained Chinese GPT-2 checkpoint (Hugging Face Hub id).
MODEL_PATH = "IDEA-CCNL/Wenzhong-GPT2-110M"
# Alternative checkpoint: "IDEA-CCNL/Wenzhong2.0-GPT2-110M-BertTokenizer-chinese"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Reuse the EOS token for padding so that batches of different lengths can be
# padded to a common length.
tokenizer.pad_token = tokenizer.eos_token

# Instruction prefixes. prepare_inputs() draws one at random per batch and
# prepends it to every poem in that batch.
prompts = [
    # "write a ..." variants
    "写一首唐诗:", "写一首古诗:", "写一首绝句:",
    # "generate a ..." variants
    "生成一首唐诗:", "生成一首古诗:", "生成一首绝句:",
    # "this is a ..." variants
    "这是一首唐诗:", "这是一首古诗:", "这是一首古代诗歌:",
    # "generate a rhyming ancient poem"
    "生成一首压韵的古诗:",
]


class PoemDataset(Dataset):
    """Text dataset with one poem per line.

    Items are whitespace-stripped strings. Tokenization happens later, per
    batch, in ``prepare_inputs``.
    """

    def __init__(self, path, limit=None):
        """Load poems from a UTF-8 text file.

        Args:
            path: path to a UTF-8 text file containing one poem per line.
            limit: optional maximum number of poems to keep (handy for quick
                experiments); ``None`` keeps every poem.
        """
        super().__init__()
        # Use a context manager so the file handle is closed after loading.
        with open(path, encoding='utf-8') as f:
            stripped = [line.strip() for line in f]
        # Blank lines would otherwise turn into empty training samples.
        self.poems = [poem for poem in stripped if poem]
        if limit is not None:
            self.poems = self.poems[:limit]

    def __getitem__(self, idx):
        """Return the ``idx``-th poem as a stripped string."""
        return self.poems[idx]

    def __len__(self):
        """Return the number of (non-blank) poems loaded."""
        return len(self.poems)


# def collate_fn(samples):
#     return tokenizer(samples, padding="longest", truncation=True, return_tensors='pt')
#
#
# def generate_prompt(bs, prompt="写一首唐诗:"):
#     return tokenizer([prompt] * bs, padding="longest", truncation=True, return_tensors='pt')


def prepare_inputs(samples):
    """Build model inputs for a batch of poems prefixed by one random prompt.

    The same prompt (drawn with ``random.choice(prompts)``) is prepended to
    every sample in the batch. Each poem gets an explicit EOS token appended,
    so the model learns where a poem ends.

    Args:
        samples: list of poem strings (one batch).

    Returns:
        ``(input_ids, attention_mask, label_ids)`` tensors of shape
        ``(batch, prompt_len + poem_len)``. ``label_ids`` is -100 on prompt
        and padding positions, so only real poem tokens (including the
        appended EOS) contribute to the loss.
    """
    prompt = random.choice(prompts)
    prompt_inputs = tokenizer([prompt] * len(samples),
                              padding="longest", truncation=True, add_special_tokens=False, return_tensors='pt')
    # pad_token == eos_token, so padding alone cannot mark the end of a poem.
    # Append a real EOS (attention_mask == 1) before padding instead.
    texts = [sample + tokenizer.eos_token for sample in samples]
    inputs = tokenizer(texts, padding="longest", truncation=True, add_special_tokens=False, return_tensors='pt')

    input_ids = torch.cat([prompt_inputs.input_ids, inputs.input_ids], dim=1)
    # Padding positions (attention_mask == 0) must not be supervised.
    # Otherwise the loss is dominated by predicting pad/EOS tokens.
    poem_labels = inputs.input_ids.masked_fill(inputs.attention_mask == 0, -100)
    label_ids = torch.cat([
        torch.full_like(prompt_inputs.input_ids, -100),  # no loss for prompt
        poem_labels,
    ], dim=1)
    attention_mask = torch.cat([prompt_inputs.attention_mask, inputs.attention_mask], dim=1)
    return input_ids, attention_mask, label_ids


@torch.no_grad()
def inference(model, prompt="写一首唐诗:折戟沉沙铁未销,自将磨洗认前朝。",
              num_return_sequences=5, max_length=150):
    """Sample continuations of ``prompt`` with nucleus sampling and print them.

    Args:
        model: causal LM supporting ``generate`` (here a PEFT-wrapped GPT-2).
        prompt: text to continue; tokenized without special tokens.
        num_return_sequences: number of independent samples to print.
        max_length: maximum total length in tokens, prompt included.

    Side effects:
        Switches ``model`` to eval mode (disables dropout) and prints each
        decoded sample on its own line.
    """
    model.eval()
    # Send the prompt to wherever the model's weights live, rather than
    # relying on a module-level ``device`` global.
    model_device = next(model.parameters()).device
    inputs = tokenizer(prompt, add_special_tokens=False, return_tensors='pt').to(model_device)
    outputs = model.generate(
        **inputs,
        return_dict_in_generate=True,
        max_length=max_length,
        do_sample=True,
        top_p=0.6,
        num_return_sequences=num_return_sequences,
        # The tokenizer pads with EOS. Say so explicitly, so generate() does
        # not have to guess and warn.
        pad_token_id=tokenizer.eos_token_id,
    )
    for line in tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True):
        print(line)


@torch.no_grad()
def evaluate(model, test_loader, epoch, **kwargs):
    """Compute token-level loss and perplexity of ``model`` on ``test_loader``.

    Args:
        model: causal LM whose forward accepts ``input_ids``,
            ``attention_mask`` and ``labels`` and returns an object with
            ``.loss``.
        test_loader: iterable of batches of raw poem strings.
        epoch: current epoch index, used only in the printed summary.
        **kwargs: must contain ``max_epochs`` (printed). If ``device`` is
            given, batches are moved there; otherwise they go to the device
            of the model's parameters.

    Returns:
        ``(avg_loss, ppl)``, where ``avg_loss`` is the mean negative
        log-likelihood per supervised token and ``ppl = exp(avg_loss)``.
        Returns ``(nan, nan)`` when there is nothing to evaluate.
    """
    print("-" * 20 + "   Evaluating   " + "-" * 20)
    model.eval()
    device = kwargs.get('device')
    if device is None:
        device = next(model.parameters()).device

    total_nll = 0.0
    total_tokens = 0
    for inputs in tqdm(test_loader):
        input_ids, attention_mask, label_ids = prepare_inputs(inputs)
        outputs = model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            labels=label_ids.to(device),
            return_dict=True,
        )
        # The causal-LM loss is a mean over shifted labels (labels[:, 1:])
        # that are not -100. Weight each batch by that token count, so that
        # exp(mean NLL) is a true per-token perplexity.
        n_tokens = int((label_ids[:, 1:] != -100).sum())
        if n_tokens == 0:
            continue  # loss would be NaN for a batch with no supervised tokens
        total_nll += outputs.loss.item() * n_tokens
        total_tokens += n_tokens

    if total_tokens == 0:
        print(f"Epoch {epoch}/{kwargs['max_epochs']} | no evaluation tokens")
        return float('nan'), float('nan')

    avg_loss = total_nll / total_tokens
    ppl = math.exp(avg_loss)
    print(f"Epoch {epoch}/{kwargs['max_epochs']} | loss:{avg_loss:.5f} | ppl: {ppl:.5f}")
    return avg_loss, ppl


def train(model: nn.Module, dataloaders, optimizer, scheduler, **kwargs):
    """Main training loop: fine-tune, then checkpoint, evaluate and sample each epoch.

    Args:
        model: causal LM to train. Its forward takes ``input_ids``,
            ``attention_mask`` and ``labels`` and returns ``.loss``.
        dataloaders: ``(train_loader, test_loader)`` pair yielding batches of
            poem strings.
        optimizer: optimizer over the trainable parameters.
        scheduler: LR scheduler, stepped once per training step.
        **kwargs: required keys are ``device``, ``logger_name`` (TensorBoard
            log dir), ``max_epochs`` and ``steps_per_epoch``. Optional keys:
            ``save_dir`` (checkpoint directory, default
            "/home/jovyan/work/checkpoint_lora_v4") and ``log_every``
            (progress print interval in steps, default 50). All kwargs are
            forwarded to ``evaluate``.
    """
    train_loader, test_loader = dataloaders
    device = torch.device(kwargs['device'])
    max_epochs = kwargs['max_epochs']
    steps_per_epoch = kwargs['steps_per_epoch']
    save_dir = kwargs.get('save_dir', "/home/jovyan/work/checkpoint_lora_v4")
    log_every = kwargs.get('log_every', 50)

    # fp16 autocast is only used on CUDA. It needs loss scaling, otherwise
    # small gradients underflow to zero in half precision.
    use_amp = device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    writer = SummaryWriter(kwargs['logger_name'])
    model = model.to(device)
    try:
        for epoch in range(max_epochs):
            # ==========  Train  ==========
            loss_list = []
            model.train()
            last_time = time.time()
            last_report_step = -1  # the first window starts before step 0
            for local_step, inputs in enumerate(train_loader):
                step = epoch * steps_per_epoch + local_step
                input_ids, attention_mask, label_ids = prepare_inputs(inputs)

                optimizer.zero_grad()
                with torch.autocast(device_type=device.type, enabled=use_amp):
                    outputs = model(
                        input_ids=input_ids.to(device),
                        attention_mask=attention_mask.to(device),
                        labels=label_ids.to(device),
                        return_dict=True,
                    )
                    loss = outputs.loss

                scaler.scale(loss).backward()
                scaler.step(optimizer)  # skips the update if grads overflowed
                scaler.update()
                scheduler.step()

                # log
                loss_value = loss.item()
                loss_list.append(loss_value)
                if (local_step % log_every == 0 and local_step != 0) or local_step == steps_per_epoch - 1:
                    avg_loss = sum(loss_list) / len(loss_list)  # running mean over this epoch
                    n_step_time = time.time() - last_time
                    # ETA = measured seconds per step * steps still to run.
                    sec_per_step = n_step_time / (local_step - last_report_step)
                    left_time = (steps_per_epoch - 1 - local_step) * sec_per_step
                    print("Epoch {}/{} | Step {}/{} | loss:{:.5f} time:{:.1f}s left:{:.1f}m".format(
                        epoch, max_epochs, local_step, steps_per_epoch,
                        avg_loss, n_step_time, left_time / 60
                    ))
                    last_time = time.time()
                    last_report_step = local_step
                writer.add_scalar('Train/Loss', loss_value, step)
                writer.add_scalar('Epoch', epoch, step)

            model.save_pretrained(save_dir)
            # ==========  Eval  ==========
            evaluate(model, test_loader, epoch, **kwargs)
            print("=" * 53)
            # ==========  Inference  ==========
            inference(model)
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()


if __name__ == "__main__":
    # prepare data
    train_dataset = PoemDataset("/home/jovyan/work/data/train_poems_v2.txt")
    test_dataset = PoemDataset("/home/jovyan/work/data/test_poems_v2.txt")

    # Hyperparameters
    batch_size = 32
    lr = 1e-4
    # Fall back to CPU so the script can still run without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    max_epochs = 4
    num_warmup_steps = 200
    num_training_steps = max_epochs * math.ceil(len(train_dataset) / batch_size)
    seed = 2023
    logger_name = "/home/jovyan/work/results/logs"

    # Seed every RNG the run uses (prompt choice via `random`, numpy, and
    # torch for LoRA init / shuffling / dropout) so runs are reproducible.
    # Seeding happens before the model is built so LoRA init is covered too.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # prepare model
    gpt2_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=32, lora_alpha=16, lora_dropout=0.1, bias='all'
    )
    gpt2_model = get_peft_model(gpt2_model, peft_config)
    gpt2_model.print_trainable_parameters()

    # prepare optimizer (only the parameters PEFT left trainable)
    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, gpt2_model.parameters()), lr=lr)
    sche = transformers.get_linear_schedule_with_warmup(optim, num_warmup_steps, num_training_steps)

    # train
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    train(
        gpt2_model,
        (train_dataloader, val_dataloader),
        optimizer=optim,
        scheduler=sche,
        device=device,
        max_epochs=max_epochs,
        logger_name=logger_name,
        steps_per_epoch=len(train_dataloader)
    )