Upload gpt2.py with huggingface_hub
gpt2.py
ADDED
@@ -0,0 +1,94 @@
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, GPT2Config, GPT2LMHeadModel, pipeline
import torch

torch.set_default_dtype(torch.float32)

class SLMGPT2:
    def __init__(self, tokenizer_id, model_id):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.dataset = load_dataset("tiny_shakespeare", trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)

        # GPT-2 tokenizers ship without a padding token; add a dedicated [PAD]
        # token so padding can be masked out of the labels later.
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

        # Pretrained reference model loaded from the checkpoint (not used during training).
        self.distilled_model = AutoModelForCausalLM.from_pretrained(model_id)

        # len(tokenizer) includes the added [PAD] token; tokenizer.vocab_size does not,
        # which would leave the pad id outside the new model's embedding table.
        self.vocab_size = len(self.tokenizer)
        self.n_positions = 128
        self.n_ctx = 128
        self.n_embd = 256
        self.n_layer = 4
        self.n_head = 4
        print(f"device set: {self.device}")

        self.config = GPT2Config(
            vocab_size=self.vocab_size,
            n_positions=self.n_positions,
            n_ctx=self.n_ctx,
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=self.n_head
        )
        self.config.pad_token_id = self.tokenizer.pad_token_id
        print(f"Vocab size: {self.config.vocab_size}, type: {type(self.config.vocab_size)}")
        print(f"Embedding dim: {self.config.n_embd}, type: {type(self.config.n_embd)}")

        # Small GPT-2 model trained from scratch with the config above.
        self.model = GPT2LMHeadModel(self.config).to(device=self.device)

    def tokenize(self, example):
        # Called with batched=True, so example["text"] is a list of strings and
        # truncation keeps only the first 128 tokens of each record.
        tokenized = self.tokenizer(
            example["text"],
            truncation=True,
            padding="max_length",
            max_length=128
        )
        tokenized["labels"] = tokenized["input_ids"].copy()

        # Mask padding tokens in the labels so they do not contribute to the loss.
        # With batched inputs each entry is a sequence, so mask token by token.
        pad_token_id = self.tokenizer.pad_token_id
        tokenized["labels"] = [
            [(token if token != pad_token_id else -100) for token in seq]
            for seq in tokenized["labels"]
        ]
        return tokenized

    def train(self):
        tokenized_dataset = self.dataset.map(self.tokenize, batched=True)
        training_args = TrainingArguments(
            output_dir="./results",
            num_train_epochs=500,
            per_device_train_batch_size=4,
            per_device_eval_batch_size=4,
            logging_steps=10,
            save_steps=500,
            learning_rate=5e-4,
            weight_decay=0.01,
            save_total_limit=1,
            report_to="none"  # disable external loggers; logging_dir="none" would just create a directory named "none"
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=tokenized_dataset["train"]
        )
        trainer.train()
        print("Training completed!")
        self.tokenizer.save_pretrained("./saved_model")
        self.model.save_pretrained("./saved_model")
        print("Model and tokenizer saved to ./saved_model")

    def inference(self):
        model_path = "./saved_model"
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path)
        # Generate a short completion with the freshly trained model.
        text_gen = pipeline("text-generation", model=model, tokenizer=tokenizer)
        result = text_gen("The universe is", max_length=128)
        print(result[0]["generated_text"])

if __name__ == "__main__":
    model_id = "gpt2"
    tokenizer_id = "distilgpt2"
    sml = SLMGPT2(model_id=model_id, tokenizer_id=tokenizer_id)
    sml.train()
    sml.inference()