Upload ffn_allocation/eval_xsum.py with huggingface_hub
Browse files- ffn_allocation/eval_xsum.py +85 -0
    	
        ffn_allocation/eval_xsum.py
    ADDED
    
    | @@ -0,0 +1,85 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import transformers
         | 
| 2 | 
            +
            from transformers import TextStreamer
         | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from transformers.generation.streamers import BaseStreamer
         | 
| 5 | 
            +
            from datasets import load_dataset
         | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            class TokenStreamer(BaseStreamer):
                """
                Minimal streamer that prints every generated token on its own line.

                Each token id is decoded on its own and written to stdout as `=` followed
                by the `repr` of the decoded text (for example `=' the'`), so whitespace
                and special characters stay visible and token boundaries are unambiguous.

                Parameters:
                    tokenizer (`AutoTokenizer`):
                        The tokenizer used to decode the tokens.
                    skip_prompt (`bool`, *optional*, defaults to `True`):
                        Whether to discard the first batch of ids passed to `put`, which is
                        assumed to be the prompt. Useful for chatbots.
                """

                def __init__(self, tokenizer, skip_prompt=True):
                    self.tokenizer = tokenizer
                    self.skip_prompt = skip_prompt
                    # The first `put` call after construction (or after `end`) is treated
                    # as carrying the prompt ids.
                    self.next_tokens_are_prompt = True

                def put(self, value):
                    """
                    Print the token(s) contained in `value`, one per line.

                    Parameters:
                        value:
                            Tensor-like object exposing `.shape` and `.tolist()`. Either
                            one-dimensional, or two-dimensional with a single row.

                    Raises:
                        ValueError: If `value` has more than one dimension and more than
                            one row, i.e. a batch size above 1.
                    """
                    if len(value.shape) > 1:
                        if value.shape[0] > 1:
                            raise ValueError("TokenStreamer only supports batch size 1")
                        # Drop the batch dimension of a single-row input.
                        value = value[0]

                    if self.skip_prompt and self.next_tokens_are_prompt:
                        # Swallow the prompt once; every later call is printed.
                        self.next_tokens_are_prompt = False
                        return

                    for token_id in value.tolist():
                        # Decode ids one at a time so each printed line is exactly one token.
                        token_text = self.tokenizer.decode([token_id])
                        print(f"={token_text!r}", flush=True)

                def end(self):
                    """Re-arm prompt skipping for the next generation and print a blank line."""
                    self.next_tokens_are_prompt = True
                    print()
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            # ---------------------------------------------------------------------
            # Evaluation script: stream per-token summaries for 100 XSum test documents.
            # ---------------------------------------------------------------------

            # Checkpoint location, given as a relative path one directory up.
            model_id = "../"

            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True,
            )
            # Reuse the end-of-sequence token as the padding token.
            tokenizer.pad_token = tokenizer.eos_token

            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
            )

            # NOTE(review): `model` is already instantiated above, so `model_kwargs` and
            # `device_map` here look redundant -- confirm against the installed
            # transformers version before removing them.
            pipeline = transformers.pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map="auto",
                trust_remote_code=True,
            )

            dataset = load_dataset("EdinburghNLP/xsum", split="test")

            # Fixed seed so every run evaluates the same 100 test documents.
            # Only Python's `random` is seeded in this script.
            random.seed(0)
            sampled_indices = random.sample(range(len(dataset)), 100)
            sampled_dataset = dataset.select(sampled_indices)

            # Fixed text placed before and after each document to form the prompt.
            PROMPT_PREFIX = "Please summarize this paragraph into a single sentence: <paragraph>"
            PROMPT_SUFFIX = "</paragraph> Directly output the summarized paragraph here: "
            # Line printed before and after each generation.
            SEPARATOR = "==="

            streamer = TokenStreamer(tokenizer)
            # Generation settings shared by every document.
            generation_kwargs = dict(
                max_new_tokens=64,
                do_sample=True,
                temperature=0.6,
                top_p=1.0,
                streamer=streamer,
            )

            for example in sampled_dataset:
                prompt = "".join((PROMPT_PREFIX, example["document"], PROMPT_SUFFIX))
                messages = [{"role": "user", "content": prompt}]

                print(SEPARATOR)
                # The return value is bound but not used afterwards; the visible output
                # comes from the streamer printing tokens as they are produced.
                outputs = pipeline(messages, **generation_kwargs)
                print(SEPARATOR)
         |