# -*- coding: utf-8 -*-
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import transformers
from tensorboardX import SummaryWriter  # provides TensorBoard-style logging
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
from peft import TaskType, LoraConfig, get_peft_model
MODEL_PATH = r"IDEA-CCNL/Wenzhong-GPT2-110M"
# MODEL_PATH = r"IDEA-CCNL/Wenzhong2.0-GPT2-110M-BertTokenizer-chinese"
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
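# GPT-2 has no dedicated pad token, so the eos token is reused for batch padding below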
tokenizer.pad_token = tokenizer.eos_token
prompts = [
"写一首唐诗:",
"写一首古诗:",
"写一首绝句:",
"生成一首唐诗:",
"生成一首古诗:",
"生成一首绝句:",
"这是一首唐诗:",
"这是一首古诗:",
"这是一首古代诗歌:",
"生成一首压韵的古诗:",
]
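# One of these instruction prompts is drawn at random per batch and prepended to every
# poem (see prepare_inputs), so the model learns to answer several phrasings of the request.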
class PoemDataset(Dataset):
    def __init__(self, path):
        super().__init__()
        with open(path, encoding='utf-8') as f:
            self.poems = f.readlines()  # [:30000]

    def __getitem__(self, idx):
        text = self.poems[idx].strip()
        return text

    def __len__(self):
        return len(self.poems)
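# Each item is a raw poem string; the default DataLoader collate therefore yields a
# list[str] per batch, which prepare_inputs below tokenizes and pads on the fly.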
def prepare_inputs(samples):
    """Prepend one randomly chosen instruction prompt to every poem in the batch and
    build input_ids / attention_mask / label_ids; prompt positions are set to -100 in
    label_ids so they are excluded from the loss."""
    prompt_inputs = tokenizer([random.choice(prompts)] * len(samples),
                              padding="longest", truncation=True, add_special_tokens=False, return_tensors='pt')
    inputs = tokenizer(samples, padding="longest", truncation=True, add_special_tokens=False, return_tensors='pt')
    input_ids = torch.cat([
        prompt_inputs.input_ids,
        inputs.input_ids
    ], dim=1)
    label_ids = torch.cat([
        torch.ones_like(prompt_inputs.input_ids) * -100,  # no loss for prompt tokens
        inputs.input_ids
    ], dim=1)
    attention_mask = torch.cat([prompt_inputs.attention_mask, inputs.attention_mask], dim=1)
    return input_ids, attention_mask, label_ids
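# Quick sanity check (illustrative only; the two poem strings are hypothetical inputs, not
# taken from the dataset): all three returned tensors share the same [batch, seq] shape, and
# label_ids is -100 at every prompt position, so loss is computed on the poem tokens alone.
#   ids, mask, labels = prepare_inputs(["白日依山尽,黄河入海流。", "春眠不觉晓。"])
#   assert ids.shape == mask.shape == labels.shape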
@torch.no_grad()
def inference(model):
    # seed text: an instruction prompt followed by the opening couplet of a Tang poem
    inputs = tokenizer("写一首唐诗:折戟沉沙铁未销,自将磨洗认前朝。", add_special_tokens=False, return_tensors='pt')
    inputs = inputs.to(device)
    outputs = model.generate(
        **inputs,
        return_dict_in_generate=True,
        max_length=150,
        do_sample=True,
        top_p=0.6,  # nucleus sampling: keep the smallest token set with cumulative probability > 0.6
        num_return_sequences=5
    )
    for line in tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True):
        print(line)
@torch.no_grad()
def evaluate(model, test_loader, epoch, **kwargs):
    loss_list = []
    ppl_list = []
    print("-" * 20 + " Evaluating " + "-" * 20)
    model.eval()
    for local_step, inputs in enumerate(tqdm(test_loader)):
        input_ids, attention_mask, label_ids = prepare_inputs(inputs)
        outputs = model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
            labels=label_ids.to(device),
            return_dict=True,
        )
        loss_list.append(outputs.loss.cpu().item())
        # perplexity = exp(mean cross-entropy over the non-masked poem tokens)
        ppl = torch.exp(outputs.loss)
        ppl_list.append(ppl.cpu().item())
    # log
    avg_loss = sum(loss_list) / len(loss_list)
    avg_ppl = sum(ppl_list) / len(ppl_list)
    print(f"Epoch {epoch}/{kwargs['max_epochs']} | loss:{avg_loss:.5f} | ppl:{avg_ppl:.5f}")
    # writer.add_scalar('Test/Loss', sum(loss_list) / len(loss_list), epoch * kwargs['steps_per_epoch'])
def train(model: nn.Module, dataloaders, optimizer, scheduler, **kwargs):
"""训练的主函数"""
train_loader, test_loader = dataloaders
device = kwargs['device']
writer = SummaryWriter(kwargs['logger_name'])
model = model.to(device)
for epoch in range(kwargs['max_epochs']):
# ========== Train ==========
loss_list = []
model.train()
last_time = time.time()
for local_step, inputs in enumerate(train_loader):
step = epoch * kwargs['steps_per_epoch'] + local_step
# prepare prompts
# prompt = generate_prompt(inputs.input_ids.shape[0]).to(device)
# inputs = inputs.to(device)
# input_ids = torch.cat([prompt.input_ids, inputs.input_ids], dim=1)
# no_loss_ids = torch.ones_like(prompt.input_ids) * -100 # Don't cal loss of prompts
# label_ids = torch.cat([no_loss_ids, inputs.input_ids], dim=1)
# attention_mask = torch.cat([prompt.attention_mask, inputs.attention_mask], dim=1)
input_ids, attention_mask, label_ids = prepare_inputs(inputs)
            optimizer.zero_grad()
            with torch.autocast(device_type='cuda'):  # mixed-precision (fp16) forward pass
                outputs = model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    labels=label_ids.to(device),
                    return_dict=True,
                )
                loss = outputs.loss
            # note: autocast is usually paired with torch.cuda.amp.GradScaler to guard
            # against fp16 gradient underflow; scaling is omitted here
            loss.backward()
            optimizer.step()
            scheduler.step()
            # log
            loss_list.append(loss.detach().cpu().item())
            if (local_step % 50 == 0 and local_step != 0) or local_step == kwargs['steps_per_epoch'] - 1:
                avg_loss = sum(loss_list) / len(loss_list)
                n_step_time = time.time() - last_time
                left_time = (kwargs['steps_per_epoch'] - local_step) // 50 * n_step_time
                print("Epoch {}/{} | Step {}/{} | loss:{:.5f} time:{:.1f}s left:{:.1f}m".format(
                    epoch, kwargs['max_epochs'], local_step, kwargs['steps_per_epoch'],
                    avg_loss, n_step_time, left_time / 60
                ))
                last_time = time.time()
            writer.add_scalar('Train/Loss', loss, step)
            writer.add_scalar('Epoch', epoch, step)
        # save the LoRA adapter weights after every epoch
        model.save_pretrained("/home/jovyan/work/checkpoint_lora_v4")
        # ========== Eval ==========
        evaluate(model, test_loader, epoch, **kwargs)
        print("=" * 53)
        # ========== Inference ==========
        inference(model)

# prepare data
train_dataset = PoemDataset("/home/jovyan/work/data/train_poems_v2.txt")
test_dataset = PoemDataset("/home/jovyan/work/data/test_poems_v2.txt")
# Hyperparameters
batch_size = 32
lr = 1e-4
device = torch.device('cuda')
max_epochs = 4
num_warmup_steps = 200
num_training_steps = max_epochs * math.ceil(len(train_dataset) / batch_size)
seed = 2023
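# apply the seed to the Python, NumPy, and PyTorch RNGs so runs are reproducible
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)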
logger_name = "/home/jovyan/work/results/logs"
# prepare model
gpt2_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
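# LoRA configuration below: r is the rank of the low-rank update matrices, lora_alpha
# scales the update (effective scale lora_alpha / r = 0.5 here), and bias='all' makes
# every bias term trainable alongside the LoRA matrices.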
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=32, lora_alpha=16, lora_dropout=0.1, bias='all'
)
gpt2_model = get_peft_model(gpt2_model, peft_config)
gpt2_model.print_trainable_parameters()
# prepare optimizer
optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, gpt2_model.parameters()), lr=lr)
sche = transformers.get_linear_schedule_with_warmup(optim, num_warmup_steps, num_training_steps)
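# linear schedule with warmup: the learning rate ramps from 0 to lr over the first
# num_warmup_steps (200) optimizer steps, then decays linearly to 0 at num_training_steps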
# train
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
train(
    gpt2_model,
    (train_dataloader, val_dataloader),
    optimizer=optim,
    scheduler=sche,
    device=device,
    max_epochs=max_epochs,
    logger_name=logger_name,
    steps_per_epoch=len(train_dataloader)
)
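
# To reuse the adapter saved in train() for generation later, it can be re-attached to a
# fresh base model; a minimal sketch (kept commented out, path as used above):
#
#   from peft import PeftModel
#   base = AutoModelForCausalLM.from_pretrained(MODEL_PATH)
#   lora_model = PeftModel.from_pretrained(base, "/home/jovyan/work/checkpoint_lora_v4").to(device)
#   inference(lora_model)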