# ffn_allocation/eval_copy.py
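"""Copy evaluation: sample 100 documents from the XSum test split and prompt
the model to copy each paragraph verbatim, streaming every generated token on
its own line via a custom streamer."""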
import random

import torch
import transformers
from datasets import load_dataset
from transformers.generation.streamers import BaseStreamer


class TokenStreamer(BaseStreamer):
    """
    Simple token streamer that prints the decoded text of each token on its
    own line as soon as it is generated.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `True`):
            Whether to skip the prompt tokens in the output. Useful for chatbots.
    """

    def __init__(self, tokenizer, skip_prompt=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.next_tokens_are_prompt = True

    def put(self, value):
        # `generate` pushes tensors of token ids here; only batch size 1 is supported.
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]
        # The first call delivers the prompt tokens; skip them if requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return
        for token_id in value.tolist():
            token_text = self.tokenizer.decode([token_id])
            # Print each token's repr on its own line so whitespace and
            # special characters stay visible.
            print(f"={repr(token_text)}", flush=True)

    def end(self):
        # Called when generation finishes: reset the prompt flag for the next call.
        self.next_tokens_are_prompt = True
        print()
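

# Minimal usage sketch (commented out; not part of the evaluation below). The
# streamer can also be passed directly to `model.generate`; the prompt string
# here is an illustrative placeholder.
#
#     inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
#     model.generate(**inputs, max_new_tokens=8, streamer=TokenStreamer(tokenizer))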


# Load the tokenizer and model from the repository root (this script lives in
# the ffn_allocation/ subfolder).
model_id = "../"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the pad token

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# Wrap the preloaded model in a text-generation pipeline. The model is already
# in bfloat16 and device-mapped, so those options are not repeated here.
pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
)

# Sample 100 documents from the XSum test split with a fixed seed so the
# evaluation set is reproducible.
dataset = load_dataset("EdinburghNLP/xsum", split="test")
random.seed(0)
sampled_indices = random.sample(range(len(dataset)), 100)
sampled_dataset = dataset.select(sampled_indices)

streamer = TokenStreamer(tokenizer)

# Ask the model to copy each sampled document verbatim; the streamer prints
# each generated token as it arrives, so the pipeline's return value is unused.
# Each sample's output is delimited by "===" lines.
for sample in sampled_dataset:
    document = sample["document"]
    prompt = (
        f"Please copy this paragraph: <paragraph>{document}</paragraph> "
        "Directly output the copied paragraph here: "
    )
    messages = [{"role": "user", "content": prompt}]
    print("===")
    pipe(
        messages,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.6,
        top_p=1.0,
        streamer=streamer,
    )
    print("===")