Text Generation
Transformers
Safetensors
English
ddllama
conversational
custom_code
xuan-luo committed on
Commit f409ba5 · verified · 1 Parent(s): 725f3e7

Upload ffn_allocation/eval_xsum.py with huggingface_hub

Files changed (1)
  1. ffn_allocation/eval_xsum.py +85 -0
ffn_allocation/eval_xsum.py ADDED
@@ -0,0 +1,85 @@
+ import random
+
+ import torch
+ import transformers
+ from datasets import load_dataset
+ from transformers.generation.streamers import BaseStreamer
+
+
+ class TokenStreamer(BaseStreamer):
+     """
+     Simple token streamer that prints each decoded token on its own line as soon as it is generated.
+
+     Parameters:
+         tokenizer (`AutoTokenizer`):
+             The tokenizer used to decode the tokens.
+         skip_prompt (`bool`, *optional*, defaults to `True`):
+             Whether to skip the prompt tokens in the output. Useful for chatbots.
+     """
+
+     def __init__(self, tokenizer, skip_prompt=True):
+         self.tokenizer = tokenizer
+         self.skip_prompt = skip_prompt
+         self.next_tokens_are_prompt = True
+
+     def put(self, value):
+         # `value` is a tensor of token ids; only batch size 1 is supported.
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("TokenStreamer only supports batch size 1")
+         elif len(value.shape) > 1:
+             value = value[0]
+
+         # The first call delivers the prompt tokens; skip them if requested.
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+             return
+
+         for token_id in value.tolist():
+             token_text = self.tokenizer.decode([token_id])
+             print(f"={repr(token_text)}", end="\n", flush=True)
+
+     def end(self):
+         self.next_tokens_are_prompt = True
+         print()
+
+
+ model_id = "../"
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+
+ # The model is already loaded and device-mapped above, so dtype and device
+ # settings are not repeated in the pipeline call.
+ pipeline = transformers.pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     trust_remote_code=True,
+ )
+
+ # Sample 100 XSum test documents with a fixed seed for reproducibility.
+ dataset = load_dataset("EdinburghNLP/xsum", split="test")
+
+ random.seed(0)
+ sampled_indices = random.sample(range(len(dataset)), 100)
+ sampled_dataset = dataset.select(sampled_indices)
+
+ streamer = TokenStreamer(tokenizer)
+ for sample in sampled_dataset:
+     document = sample["document"]
+     prompt = (
+         "Please summarize this paragraph into a single sentence: <paragraph>"
+         + document
+         + "</paragraph> Directly output the summarized paragraph here: "
+     )
+
+     messages = [{"role": "user", "content": prompt}]
+
+     print("===")
+     outputs = pipeline(
+         messages,
+         max_new_tokens=64,
+         do_sample=True,
+         temperature=0.6,
+         top_p=1.0,
+         streamer=streamer,
+     )
+     print("===")
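For reference, `TokenStreamer` implements the `BaseStreamer` interface (`put`/`end`), so it can also be passed to `model.generate` directly instead of going through the pipeline. A minimal sketch, reusing the `model`, `tokenizer`, and `TokenStreamer` defined in the script above; the prompt string here is illustrative only, not part of the evaluation:

inputs = tokenizer(
    "Please summarize this paragraph into a single sentence: ...",
    return_tensors="pt",
).to(model.device)
streamer = TokenStreamer(tokenizer, skip_prompt=True)
# generate() calls streamer.put() once with the prompt ids, then once per
# decoding step, and streamer.end() when generation finishes.
output_ids = model.generate(
    **inputs,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
    streamer=streamer,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))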