# -*- coding: utf-8 -*-
"""Jiayi-GPT2.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/18rdH-DK5ukrs9Xhy6R0kcZkwMWJ5_MVK
"""

# !pip install transformers
# !pip install tensorboardX
# !pip install peft

# !rm -r Jiayi_GPT2

import time
import math
import random

import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import transformers
from tensorboardX import SummaryWriter  # provides TensorBoard logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
from peft import TaskType, LoraConfig, get_peft_model, PeftModel

# MODEL_PATH = r"IDEA-CCNL/Wenzhong-GPT2-110M"
MODEL_PATH = r"IDEA-CCNL/Wenzhong2.0-GPT2-3.5B-chinese"
# MODEL_PATH = r"IDEA-CCNL/Wenzhong2.0-GPT2-110M-BertTokenizer-chinese"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 tokenizers define no pad token, so reuse eos for padding


class TagPoemDataset(Dataset):
    def __init__(self, path):
        super().__init__()
        with open(path, encoding='utf-8') as f:
            self.poems = f.readlines()

    def __getitem__(self, idx):
        tag, poem = self.poems[idx].strip().split('|')
        tag = "写一首关于" + tag.replace(' ', '、') + "的古诗:"
        return tag, poem

    def __len__(self):
        return len(self.poems)

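# Each line of the data file is expected to look like "tag1 tag2|poem"; the tags
# become the prompt "写一首关于tag1、tag2的古诗:" and the poem is the training target.
# An illustrative line (not taken from the actual tag_poems_v2.txt):
#   思乡 女子|春花秋月何时了,往事知多少。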

def prepare_inputs(samples):
    tags, poems = samples
    bs = len(tags)

    prompt_inputs = tokenizer(tags, add_special_tokens=False)
    poems_inputs = tokenizer(poems, add_special_tokens=False)
    prompt_len = [len(i) for i in prompt_inputs.input_ids]
    input_ids_list = [x1 + x2 for x1, x2 in zip(prompt_inputs.input_ids, poems_inputs.input_ids)]
    max_len = max([len(i) for i in input_ids_list])

    # Right-pad every sequence to the batch max length; padded positions get the
    # pad token id and attention_mask 0.
    input_ids = torch.ones(bs, max_len, dtype=torch.long) * tokenizer.pad_token_id
    attention_mask = torch.zeros(bs, max_len, dtype=torch.long)
    for i in range(bs):
        input_ids[i, :len(input_ids_list[i])] = torch.tensor(input_ids_list[i])
        attention_mask[i, :len(input_ids_list[i])] = 1
    # Mask the prompt portion of the labels with -100 (the ignore_index of
    # CrossEntropyLoss) so the loss is computed only on the poem tokens. Padded
    # positions keep the eos id, so the model also learns to emit eos after the poem.
    label_ids = input_ids.clone()
    for i in range(bs):
        label_ids[i, :prompt_len[i]] = -100

    return input_ids, attention_mask, label_ids

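# A quick sanity check of prepare_inputs (commented out; the example strings are
# illustrative, not from the dataset). Only the poem tokens carry labels:
#
# _ids, _mask, _labels = prepare_inputs((["写一首关于月亮的古诗:"], ["床前明月光,疑是地上霜。"]))
# print(_ids.shape, _mask.shape, _labels.shape)  # all (1, seq_len)
# print((_labels == -100).sum())                 # number of masked prompt tokens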

@torch.no_grad()
def inference(model):
    """Generate a few sample poems from a fixed prompt to monitor training progress."""
    inputs = tokenizer("写一首关于思乡、女子的古诗:", add_special_tokens=False, return_tensors='pt')
    inputs = inputs.to(device)
    outputs = model.generate(
        **inputs,
        return_dict_in_generate=True,
        max_length=150,
        do_sample=True,
        top_p=0.9,
        num_return_sequences=5,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))


def train(model: nn.Module, dataloaders, optimizer, scheduler, **kwargs):
    """训练的主函数"""
    train_loader = dataloaders[0]
    device = kwargs['device']
    # writer = SummaryWriter(kwargs['logger_name'])
    model = model.to(device)

    for epoch in range(kwargs['max_epochs']):
        # ==========  Train  ==========
        loss_list = []
        model.train()
        last_time = time.time()
        for local_step, inputs in enumerate(train_loader):
            step = epoch * kwargs['steps_per_epoch'] + local_step
            input_ids, attention_mask, label_ids = prepare_inputs(inputs)

            optimizer.zero_grad()

            with torch.autocast(device_type='cuda'):  # fp16
                outputs = model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    labels=label_ids.to(device),
                    return_dict=True,
                )
                loss = outputs.loss

            loss.backward()
            optimizer.step()
            scheduler.step()

            # log
            loss_list.append(loss.detach().cpu().item())
            if (local_step % 50 == 0 and local_step != 0) or local_step == kwargs['steps_per_epoch'] - 1:
                avg_loss = sum(loss_list) / len(loss_list)
                n_step_time = time.time() - last_time
                left_time = (kwargs['steps_per_epoch'] - local_step) // 50 * n_step_time
                print("Epoch {}/{} | Step {}/{} | loss:{:.5f} time:{:.1f}s left:{:.1f}m".format(
                    epoch, kwargs['max_epochs'], local_step, kwargs['steps_per_epoch'],
                    avg_loss, n_step_time, left_time / 60
                ))
                last_time = time.time()
            # writer.add_scalar('Train/Loss', loss, step)
            # writer.add_scalar('Epoch', epoch, step)

            if local_step % 100 == 0:
                inference(model)

        model.save_pretrained("ft_checkpoint_lora_v4")  # save the LoRA adapter weights after each epoch


# prepare data
train_dataset = TagPoemDataset("tag_poems_v2.txt")
# test_dataset = TagPoemDataset("test_poems_v3.txt")

# Hyperparameters
batch_size = 4
lr = 5e-4
device = torch.device('cuda')
max_epochs = 10
num_warmup_steps = 50
num_training_steps = max_epochs * math.ceil(len(train_dataset) / batch_size)
seed = 2023
# logger_name = "Jiayi_GPT2_202307091400"

# prepare model
gpt2_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
# gpt2_model = PeftModel.from_pretrained(gpt2_model, 'checkpoint_lora_v4')
# gpt2_model.print_trainable_parameters()
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32,               # LoRA rank
    lora_alpha=16,      # LoRA scaling factor
    lora_dropout=0.1,
    bias='all',         # also train every bias parameter
)
gpt2_model = get_peft_model(gpt2_model, peft_config)
gpt2_model.print_trainable_parameters()
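# Note: no target_modules are given, so peft falls back to its built-in default for
# the gpt2 model type (the fused c_attn projection); this default is version-dependent,
# so pin target_modules explicitly if reproducibility matters.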

# prepare optimizer
optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, gpt2_model.parameters()), lr=lr)
sche = transformers.get_linear_schedule_with_warmup(optim, num_warmup_steps, num_training_steps)

# train
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# val_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

train(
    gpt2_model,
    (train_dataloader, None),
    optimizer=optim,
    scheduler=sche,
    device=device,
    max_epochs=max_epochs,
    # logger_name=logger_name,
    steps_per_epoch=len(train_dataloader)
)
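
# After training, the saved LoRA adapter can be loaded on top of the base model for
# inference (a minimal sketch, commented out; path taken from the save_pretrained
# call in train()):
#
# base = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
# ft_model = PeftModel.from_pretrained(base, "ft_checkpoint_lora_v4").to(device)
# ft_model.eval()
# inference(ft_model)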

# tokenizer("写一首唐诗:", padding="longest", truncation=True, return_tensors='pt')