import random

import torch
import transformers
from datasets import load_dataset
from transformers.generation.streamers import BaseStreamer


class TokenStreamer(BaseStreamer):
    """
    Simple token streamer that prints each generated token on its own line,
    prefixed with `=` and rendered via `repr`, as soon as it is produced.

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `True`):
            Whether to skip the prompt tokens in the output. Useful for chatbots.
    """

    def __init__(self, tokenizer, skip_prompt=True):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.next_tokens_are_prompt = True

    def put(self, value):
        # `generate` pushes token-id tensors here; only batch size 1 is supported.
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TokenStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        # The first call delivers the prompt tokens; skip them if requested.
        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Decode and print each token individually.
        for token_id in value.tolist():
            token_text = self.tokenizer.decode([token_id])
            print(f"={token_text!r}", flush=True)

    def end(self):
        # Reset so the next generation skips its prompt again.
        self.next_tokens_are_prompt = True
        print()


model_id = "../"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# The model is already loaded in bfloat16 and placed by `device_map="auto"`,
# so `model_kwargs` and `device_map` are not passed to the pipeline again.
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
)

# Sample 100 documents from the XSum test split with a fixed seed for reproducibility.
dataset = load_dataset("EdinburghNLP/xsum", split="test")
random.seed(0)
sampled_indices = random.sample(range(len(dataset)), 100)
sampled_dataset = dataset.select(sampled_indices)

streamer = TokenStreamer(tokenizer)

for sample in sampled_dataset:
    document = sample["document"]
    prompt = (
        "Please copy this paragraph: "
        + document
        + " Directly output the copied paragraph here: "
    )
    messages = [{"role": "user", "content": prompt}]

    print("===")
    # The streamer prints tokens as they are generated, so the return value is unused.
    pipeline(
        messages,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.6,
        top_p=1.0,
        streamer=streamer,
    )
    print("===")
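
# Alternative sketch: the stock `TextStreamer` shipped with transformers streams the
# decoded text to stdout as it is generated, instead of one repr'd token per line.
# Assuming the same pipeline and messages as above, it could be swapped in like this:
#
#   from transformers import TextStreamer
#   text_streamer = TextStreamer(tokenizer, skip_prompt=True)
#   pipeline(messages, max_new_tokens=64, do_sample=True, temperature=0.6,
#            top_p=1.0, streamer=text_streamer)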