chentianqi committed (verified)
Commit b202053 · 1 Parent(s): 673ac4e

Update README.md

Files changed (1):
  1. README.md +343 -2
README.md CHANGED
@@ -20,7 +20,7 @@ tags:
 
---


This model is a 4-bit GPTQ quantization of the LLaDA-8B-Base model, produced with [GPTQModel](https://github.com/ModelCloud/GPTQModel).

- **bits**: 4
- **dynamic**: null

@@ -38,8 +38,349 @@ This model has been quantized using [GPTQModel](https://github.com/ModelCloud/GPTQModel).
- **damp_percent**: 0.1
- **damp_auto_increment**: 0.0015
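
For reference, the listed settings map onto GPTQModel's `QuantizeConfig`. The snippet below is only a minimal sketch of how a similar checkpoint could be produced: the calibration data, `group_size`, source model id, and output path are illustrative assumptions, not the exact recipe used for this upload.

```python
from datasets import load_dataset
from gptqmodel import GPTQModel, QuantizeConfig

# Illustrative calibration set (assumption); the calibration data actually used is not documented here.
calibration = load_dataset(
    "allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", split="train"
).select(range(1024))["text"]

# bits / damp_percent / damp_auto_increment mirror the settings listed above;
# group_size=128 is an assumed example value.
quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    damp_percent=0.1,
    damp_auto_increment=0.0015,
)

# GSAI-ML/LLaDA-8B-Base is assumed as the source checkpoint; LLaDA requires trust_remote_code.
model = GPTQModel.load("GSAI-ML/LLaDA-8B-Base", quant_config, trust_remote_code=True)
model.quantize(calibration, batch_size=2)
model.save("LLaDA-8B-Base-gptqmodel-4bit")
```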
 
## Benchmark
### Performance of Quantized Models

| Dataset       | GPTQ-4bit | FP16 |
|---------------|-----------|------|
| mmlu          | ✓         | ✓    |
| cmmlu         | ✓         | ✓    |
| arc_challenge | ✓         | ✓    |

## Example:

The script below wraps LLaDA in a custom `llada_dist` model class for `lm-evaluation-harness`, loading the 4-bit GPTQ checkpoint through `GPTQModel.load`:

```python
'''
This file is inspired by the code from https://github.com/ML-GSAI/SMDM
'''
import accelerate
import torch
import re
from pathlib import Path
import random
import numpy as np
import torch.nn.functional as F
from datasets import Dataset
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel
from gptqmodel import GPTQModel


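# set_seed(), add_gumbel_noise() and get_num_transfer_tokens() are referenced later but were not
# defined in this snippet; the minimal sketches below follow LLaDA's reference sampling procedure
# so the script runs standalone. Treat them as illustrative, not as part of the original upload.
def set_seed(seed):
    # Seed every RNG the evaluation touches, for reproducibility.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def add_gumbel_noise(logits, temperature):
    # Gumbel-max style sampling; with temperature == 0 the logits are returned unchanged,
    # so the subsequent argmax is plain greedy decoding.
    if temperature == 0.:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    # Distribute the masked positions of each sequence evenly across the sampling steps.
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = base.expand(-1, steps).clone()
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens

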
@register_model("llada_dist")
class LLaDAEvalHarness(LM):
    def __init__(
        self,
        model_path='',
        mask_id=126336,
        max_length=4096,
        block_length=4096,
        steps=128,
        batch_size=32,
        mc_num=128,
        is_check_greedy=True,
        cfg=0.,
        device="cuda",
        gptqmodel=True,
    ):
        """
        Args:
            model_path: LLaDA-8B-Base model path.
            mask_id: The token id of [MASK] is 126336.
            max_length: The max sequence length.
            batch_size: Mini-batch size.
            mc_num: Monte Carlo estimation iterations.
            is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer
                is generated through greedy sampling conditioned on the prompt (note that this differs from conditional
                generation). We implement this verification through the suffix_greedy_prediction() function, which
                returns a True/False judgment used for accuracy calculation.
                When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function.
                However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality,
                we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False
                by default, significantly accelerating the evaluation process.
            cfg: Unsupervised classifier-free guidance scale.
        """
        super().__init__()

        accelerator = accelerate.Accelerator()
        if accelerator.num_processes > 1:
            self.accelerator = accelerator
        else:
            self.accelerator = None

        model_kwargs = {}
        if self.accelerator is not None:
            model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})

        # The unquantized model would be loaded with AutoModel:
        # self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, **model_kwargs)
        # Here the GPTQ-quantized checkpoint is loaded through GPTQModel instead.
        self.model = GPTQModel.load(model_path, device='cuda', trust_remote_code=True)
        self.model.eval()

        self.device = torch.device(device)
        if self.accelerator is not None:
            self.model = self.accelerator.prepare(self.model)
            self.device = torch.device(f'{self.accelerator.device}')
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes

        self.mask_id = mask_id
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        self.mc_num = mc_num
        self.batch_size = int(batch_size)
        assert mc_num % self.batch_size == 0
        self.sampling_eps = 0.
        self.max_length = max_length
        self.block_length = block_length
        self.steps = steps
        self.is_check_greedy = is_check_greedy

        self.cfg = cfg
        print(f'model: {model_path}')
        print(f'Is check greedy: {is_check_greedy}')
        print(f'cfg: {cfg}')

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

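    # _forward_process: sample how many answer tokens to mask (the prompt is never masked), shuffle
    # the mask within the answer region, and return the noised batch together with the per-token
    # masking probability p_mask used to reweight the masked-token loss in the Monte Carlo estimate.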
    def _forward_process(self, batch, prompt_index):
        b, l = batch.shape

        target_len = (l - prompt_index.sum()).item()
        k = torch.randint(1, target_len + 1, (), device=batch.device)

        x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
        x = ((x - 1) % target_len) + 1
        assert x.min() >= 1 and x.max() <= target_len

        indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
        is_mask = indices < x.unsqueeze(1)

        for i in range(b):
            is_mask[i] = is_mask[i][torch.randperm(target_len)]

        is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)

        noisy_batch = torch.where(is_mask, self.mask_id, batch)

        return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)

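    # get_logits: one forward pass through the quantized model; when cfg > 0 an unconditional copy
    # (prompt tokens replaced by [MASK]) is batched alongside and classifier-free guidance is applied
    # as un_logits + (cfg + 1) * (logits - un_logits).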
    @torch.no_grad()
    def get_logits(self, batch, prompt_index):
        if self.cfg > 0.:
            assert len(prompt_index) == batch.shape[1]
            prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
            un_batch = batch.clone()
            un_batch[prompt_index] = self.mask_id
            batch = torch.cat([batch, un_batch])

        logits = self.model(batch).logits

        if self.cfg > 0.:
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (self.cfg + 1) * (logits - un_logits)
        return logits[:, :batch.shape[1]]

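    # get_loglikelihood: Monte Carlo estimate of log p(target | prefix), averaged over mc_num
    # independently re-masked copies of the sequence, as consumed by lm-eval loglikelihood requests.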
    @torch.no_grad()
    def get_loglikelihood(self, prefix, target):
        seq = torch.concatenate([prefix, target])[None, :]
        seq = seq.repeat((self.batch_size, 1)).to(self.device)

        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)

        loss_acc = []
        for _ in range(self.mc_num // self.batch_size):
            perturbed_seq, p_mask = self._forward_process(seq, prompt_index)

            mask_indices = perturbed_seq == self.mask_id

            logits = self.get_logits(perturbed_seq, prompt_index)

            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
            loss = loss.sum() / self.batch_size
            loss_acc.append(loss.item())

        return -sum(loss_acc) / len(loss_acc)

    @torch.no_grad()
    def suffix_greedy_prediction(self, prefix, target):
        if not self.is_check_greedy:
            return False

        seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device)
        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
        prefix, target = prefix.to(self.device), target.to(self.device)
        seq[0, :len(prefix)] = prefix

        for i in range(len(target)):
            mask_index = (seq == self.mask_id)
            logits = self.get_logits(seq, prompt_index)[mask_index]
            x0 = torch.argmax(logits, dim=-1)

            p = torch.softmax(logits.to(torch.float32), dim=-1)
            confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1)
            _, index = torch.sort(confidence, descending=True)
            x0[index[1:]] = self.mask_id
            seq[mask_index] = x0.clone()
        correct = target == seq[0, len(prefix):]
        correct = torch.all(correct)
        return correct

    def _encode_pair(self, context, continuation):
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = self.tokenizer(context + continuation)["input_ids"]
        context_enc = self.tokenizer(context)["input_ids"]

        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(self, requests):
        def _tokenize(e):
            prefix, target = self._encode_pair(e["prefix"], e["target"])
            return {
                "prefix_text": e["prefix"],
                "target_text": e["target"],
                "prefix": prefix,
                "target": target,
            }

        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
        ds = Dataset.from_list(ds)
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")
        prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]

        assert max(prompt_len) <= 4096

        out = []
        with torch.no_grad():
            for elem in tqdm(ds, desc="Computing likelihood..."):
                prefix = elem["prefix"]
                target = elem["target"]

                ll = self.get_loglikelihood(prefix, target)

                is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)

                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
                print('=' * 20)
                print('prefix: ', elem['prefix_text'])
                print('target: ', elem['target_text'])
                print(ll, is_target_greedy_dec)
                print('=' * 20, end='\n\n')
        torch.cuda.empty_cache()
        return out

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError

    def generate_until(self, context, max_length, stop, **generation_kwargs):
        raise NotImplementedError

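    # _model_generate: LLaDA's block-wise diffusion sampling. The answer region starts fully masked;
    # at each step the highest-confidence predictions inside the current block are committed while
    # the remaining positions stay masked ('low_confidence' remasking), until gen_length tokens are filled.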
    @torch.no_grad()
    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        '''
        Args:
            context: The prompt, a tensor of shape (1, l); the model acts as the mask predictor.
            steps: Sampling steps, less than or equal to gen_length.
            gen_length: Generated answer length.
            block_length: Block length, less than or equal to gen_length. If less than gen_length, semi-autoregressive remasking is used.
            temperature: Categorical distribution sampling temperature.
            cfg_scale: Unsupervised classifier-free guidance scale.
            remasking: Remasking strategy, 'low_confidence' or 'random'.
            mask_id: The token id of [MASK] is 126336.
        '''
        # Use the hyperparameters from the original paper.
        prompt = context

        gen_length = self.max_length
        block_length = self.block_length
        steps = self.max_length
        temperature = 0.
        cfg_scale = 0.
        remasking = 'low_confidence'
        mask_id = 126336

        x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(self.model.device)
        x[:, :prompt.shape[1]] = prompt.clone()

        prompt_index = (x != mask_id)

        assert gen_length % block_length == 0
        num_blocks = gen_length // block_length

        assert steps % num_blocks == 0
        steps = steps // num_blocks

        for num_block in range(num_blocks):
            block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
            num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
            for i in range(steps):
                mask_index = (x == mask_id)
                if cfg_scale > 0.:
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    x_ = torch.cat([x, un_x], dim=0)
                    logits = self.model(x_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = self.model(x).logits

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

                if remasking == 'low_confidence':
                    p = F.softmax(logits.to(torch.float64), dim=-1)
                    x0_p = torch.squeeze(
                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
                elif remasking == 'random':
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                else:
                    raise NotImplementedError(remasking)

                x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(confidence.shape[0]):
                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                    transfer_index[j, select_index] = True
                x[transfer_index] = x0[transfer_index]

        return x


if __name__ == "__main__":
    set_seed(1234)
    cli_evaluate()
```

Run the evaluation (with the script above saved as `eval_llada_gptq.py`):

```bash
accelerate launch eval_llada_gptq.py --tasks arc_challenge --num_fewshot 0 --model llada_dist --batch_size 8 --model_args model_path=FunAGI/LLaDA-8B-Base-gptqmodel-4bit,cfg=0.5,is_check_greedy=False,mc_num=128
  ```
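
For a quick sanity check outside the harness, the `LLaDAEvalHarness` wrapper from the example can also be instantiated directly. This is a minimal sketch, assuming the script above is saved as `eval_llada_gptq.py` and a CUDA device is available; the prompt and answer strings are arbitrary.

```python
# Minimal smoke test for the LLaDAEvalHarness wrapper defined in the example script.
import torch

from eval_llada_gptq import LLaDAEvalHarness

lm = LLaDAEvalHarness(
    model_path="FunAGI/LLaDA-8B-Base-gptqmodel-4bit",
    batch_size=8,
    mc_num=8,                # a handful of Monte Carlo samples is enough for a smoke test
    is_check_greedy=False,
)

# Score an arbitrary (prefix, target) pair with the quantized model.
prefix = torch.tensor(lm.tokenizer("Question: What is 2 + 2?\nAnswer:")["input_ids"])
target = torch.tensor(lm.tokenizer(" 4")["input_ids"])
print("approx. log-likelihood of the answer:", lm.get_loglikelihood(prefix, target))
```

For actual benchmark runs, use the `accelerate launch` command above rather than this snippet.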