xuan-luo committed · verified
Commit 305a7b1 · 1 Parent(s): e11766d

Upload ffn_allocation/eval_copy.py with huggingface_hub

Files changed (1)
  1. ffn_allocation/eval_copy.py +88 -0
ffn_allocation/eval_copy.py ADDED
@@ -0,0 +1,88 @@
+ import random
+ import torch
+ import transformers
+ from transformers.generation.streamers import BaseStreamer
+ from datasets import load_dataset
+
+
+ class TokenStreamer(BaseStreamer):
+     """
+     Simple token streamer that prints each decoded token on its own line,
+     prefixed with `=` and rendered with `repr()`, as soon as it is generated.
+
+     Parameters:
+         tokenizer (`AutoTokenizer`):
+             The tokenizer used to decode the tokens.
+         skip_prompt (`bool`, *optional*, defaults to `True`):
+             Whether to skip the prompt tokens in the output. Useful for chatbots.
+     """
+
+     def __init__(self, tokenizer, skip_prompt=True):
+         self.tokenizer = tokenizer
+         self.skip_prompt = skip_prompt
+         self.next_tokens_are_prompt = True
+
+     def put(self, value):
+         # `generate` calls `put` with the prompt ids first, then once per new token.
+         if len(value.shape) > 1 and value.shape[0] > 1:
+             raise ValueError("TokenStreamer only supports batch size 1")
+         elif len(value.shape) > 1:
+             value = value[0]
+
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+             return
+
+         for token_id in value.tolist():
+             token_text = self.tokenizer.decode([token_id])
+             print(f"={repr(token_text)}", end="\n", flush=True)
+
+     def end(self):
+         self.next_tokens_are_prompt = True
+         print()
+
+
+ model_id = "../"
+ tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ model = transformers.AutoModelForCausalLM.from_pretrained(
+     model_id,
+     torch_dtype=torch.bfloat16,
+     device_map="auto",
+     trust_remote_code=True,
+ )
+
+ # The model is already on its devices, so the pipeline needs no model_kwargs or device_map.
+ pipeline = transformers.pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     trust_remote_code=True,
+ )
+
+ dataset = load_dataset("EdinburghNLP/xsum", split="test")
+
+ # Draw a fixed random sample of 100 test documents for the copy evaluation.
+ random.seed(0)
+ sampled_indices = random.sample(range(len(dataset)), 100)
+ sampled_dataset = dataset.select(sampled_indices)
+
+ streamer = TokenStreamer(tokenizer)
+ for sample in sampled_dataset:
+     document = sample["document"]
+     prompt = "Please copy this paragraph: <paragraph>" + document + "</paragraph> Directly output the copied paragraph here: "
+
+     messages = [
+         {"role": "user", "content": prompt}
+     ]
+
+     print("===")
+     outputs = pipeline(
+         messages,
+         max_new_tokens=64,
+         do_sample=True,
+         temperature=0.6,
+         top_p=1.0,
+         streamer=streamer,
+     )
+     print("===")
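For reference, TokenStreamer follows the BaseStreamer contract: generate() calls put() with the prompt ids first, then once per new token, and end() when generation finishes. It can therefore also be driven by model.generate directly, without the pipeline. A minimal sketch under that assumption, reusing the model, tokenizer, and streamer defined above (the "..." prompt text is a placeholder):

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Please copy this paragraph: ..."}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

model.generate(
    input_ids,
    max_new_tokens=64,
    do_sample=True,
    temperature=0.6,
    top_p=1.0,
    streamer=streamer,  # generate() invokes put()/end() as tokens arrive
)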
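Note that eval_copy.py only streams the model's attempt at copying each paragraph; nothing in the file scores it. A minimal sketch of how the loop's outputs could be scored afterwards, assuming a transformers version whose chat-style pipeline calls return the full conversation under "generated_text" (the copy_score helper below is hypothetical, not part of the committed file):

import difflib

def copy_score(document: str, generated: str) -> float:
    # Hypothetical helper: similarity between the generated text and the
    # equally long prefix of the source document. With max_new_tokens=64 the
    # model can at best reproduce a prefix, so the reference is truncated.
    reference = document[: len(generated)]
    return difflib.SequenceMatcher(None, reference, generated).ratio()

# Inside the loop, after pipeline(...) returns:
generated = outputs[0]["generated_text"][-1]["content"]
print(copy_score(document, generated))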