---

This model is a 4-bit quantization of LLaDA-8B-Base produced with [GPTQModel](https://github.com/ModelCloud/GPTQModel). Key quantization parameters are listed below, followed by a minimal loading example.

- **bits**: 4
- **dynamic**: null
- **damp_percent**: 0.1
- **damp_auto_increment**: 0.0015
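
The snippet below is a minimal loading sketch; the repository id is the one used in the evaluation command at the end of this card, and text generation itself relies on the diffusion-style sampler shown in the Example section.

```python
# Minimal loading sketch (the repo id is taken from the evaluation command below).
from transformers import AutoTokenizer
from gptqmodel import GPTQModel

model_path = "FunAGI/LLaDA-8B-Base-gptqmodel-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = GPTQModel.load(model_path, device="cuda", trust_remote_code=True)
model.eval()
# LLaDA is a masked-diffusion LM: logits come from model(input_ids).logits, and text is
# produced by the iterative unmasking loop shown in the Example section.
```
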
## Benchmark

### Performance of Quantized Models

| Dataset       | GPTQ-4bit | FP16 |
|---------------|-----------|------|
| mmlu          | ✓         | ✓    |
| cmmlu         | ✓         | ✓    |
| arc_challenge | ✓         | ✓    |

## Example:
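
The script below registers the quantized checkpoint as a custom `llada_dist` model for lm-evaluation-harness: log-likelihoods are estimated by Monte Carlo sampling over randomly masked target tokens, and generation uses the block-wise iterative unmasking procedure from the LLaDA paper.
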
```python
'''
This file is inspired by the code from https://github.com/ML-GSAI/SMDM
'''
import accelerate
import torch
import re
from pathlib import Path
import random
import numpy as np
import torch.nn.functional as F
from datasets import Dataset
from lm_eval.__main__ import cli_evaluate
from lm_eval.api.instance import Instance
from lm_eval.api.model import LM
from lm_eval.models.huggingface import HFLM
from lm_eval.api.registry import register_model
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModel
from gptqmodel import GPTQModel

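
# ---------------------------------------------------------------------------
# Note: the original snippet calls set_seed, add_gumbel_noise and
# get_num_transfer_tokens without defining them. The minimal versions below
# follow the LLaDA sampling procedure and are added here (as an assumption,
# not part of the original card) so the file runs standalone.
# ---------------------------------------------------------------------------

def set_seed(seed):
    # Make sampling reproducible across torch, numpy and python's random.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def add_gumbel_noise(logits, temperature):
    # Gumbel-max style sampling; with temperature == 0 this reduces to argmax of the logits.
    if temperature == 0.:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (-torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
    # Spread the number of masked tokens to reveal as evenly as possible over the sampling steps.
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    remainder = mask_num % steps
    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1
    return num_transfer_tokens
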
@register_model("llada_dist")
class LLaDAEvalHarness(LM):
    def __init__(
        self,
        model_path='',
        mask_id=126336,
        max_length=4096,
        block_length=4096,
        steps=128,
        batch_size=32,
        mc_num=128,
        is_check_greedy=True,
        cfg=0.,
        device="cuda",
        gptqmodel=True,
    ):
        """
        Args:
            model_path: LLaDA-8B-Base model path.
            mask_id: The token id of [MASK] is 126336.
            max_length: the max sequence length.
            batch_size: mini batch size.
            mc_num: Monte Carlo estimation iterations.
            is_check_greedy: For certain metrics like LAMBADA, the evaluation requires the model to verify whether the answer
                is generated through greedy sampling conditioned on the prompt (note that this differs from conditional
                generation). We implement this verification through the suffix_greedy_prediction() function, which
                returns a True/False judgment used for accuracy calculation.
                When is_check_greedy is set to True, the lm-evaluation-harness library automatically invokes this function.
                However, since none of the metrics in the LLaDA paper (https://arxiv.org/abs/2502.09992) require this functionality,
                we recommend setting is_check_greedy to False. This configuration causes suffix_greedy_prediction() to return False
                by default, significantly accelerating the evaluation process.
            cfg: Unsupervised classifier-free guidance scale.
        """
        super().__init__()

        accelerator = accelerate.Accelerator()
        if accelerator.num_processes > 1:
            self.accelerator = accelerator
        else:
            self.accelerator = None

        model_kwargs = {}
        if self.accelerator is not None:
            model_kwargs.update({'device_map': {'': f'{self.accelerator.device}'}})

        # The quantized checkpoint is loaded through GPTQModel instead of AutoModel.
        # self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16, gptqmodel=gptqmodel, **model_kwargs)
        self.model = GPTQModel.load(model_path, device='cuda', trust_remote_code=True)
        self.model.eval()

        self.device = torch.device(device)
        if self.accelerator is not None:
            self.model = self.accelerator.prepare(self.model)
            self.device = torch.device(f'{self.accelerator.device}')
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes

        self.mask_id = mask_id
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        self.mc_num = mc_num
        self.batch_size = int(batch_size)
        assert mc_num % self.batch_size == 0
        self.sampling_eps = 0.
        self.max_length = max_length
        self.block_length = block_length
        self.steps = steps
        self.is_check_greedy = is_check_greedy

        self.cfg = cfg
        print(f'model: {model_path}')
        print(f'Is check greedy: {is_check_greedy}')
        print(f'cfg: {cfg}')

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def _forward_process(self, batch, prompt_index):
        # Randomly mask part of the response tokens (the prompt is never masked);
        # used for the Monte Carlo estimate of the log-likelihood.
        b, l = batch.shape

        target_len = (l - prompt_index.sum()).item()
        k = torch.randint(1, target_len + 1, (), device=batch.device)

        x = torch.round(torch.linspace(float(k), k + (b - 1) * (target_len / b), steps=b, device=batch.device)).long()
        x = ((x - 1) % target_len) + 1
        assert x.min() >= 1 and x.max() <= target_len

        indices = torch.arange(target_len, device=batch.device).repeat(b, 1)
        is_mask = indices < x.unsqueeze(1)

        for i in range(b):
            is_mask[i] = is_mask[i][torch.randperm(target_len)]

        is_mask = torch.cat((torch.zeros(b, prompt_index.sum(), dtype=torch.bool, device=batch.device), is_mask), dim=1)

        noisy_batch = torch.where(is_mask, self.mask_id, batch)

        return noisy_batch, (x / target_len).unsqueeze(1).repeat(1, l)

    @torch.no_grad()
    def get_logits(self, batch, prompt_index):
        if self.cfg > 0.:
            assert len(prompt_index) == batch.shape[1]
            prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
            un_batch = batch.clone()
            un_batch[prompt_index] = self.mask_id
            batch = torch.cat([batch, un_batch])

        logits = self.model(batch).logits

        if self.cfg > 0.:
            # Classifier-free guidance: combine conditional and unconditional logits.
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + (self.cfg + 1) * (logits - un_logits)
        return logits[:, :batch.shape[1]]

    @torch.no_grad()
    def get_loglikelihood(self, prefix, target):
        seq = torch.concatenate([prefix, target])[None, :]
        seq = seq.repeat((self.batch_size, 1)).to(self.device)

        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)

        loss_acc = []
        for _ in range(self.mc_num // self.batch_size):
            perturbed_seq, p_mask = self._forward_process(seq, prompt_index)

            mask_indices = perturbed_seq == self.mask_id

            logits = self.get_logits(perturbed_seq, prompt_index)

            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
            loss = loss.sum() / self.batch_size
            loss_acc.append(loss.item())

        return -sum(loss_acc) / len(loss_acc)

    @torch.no_grad()
    def suffix_greedy_prediction(self, prefix, target):
        if not self.is_check_greedy:
            return False

        seq = torch.full((1, len(prefix) + len(target)), self.mask_id, device=self.device)
        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
        prefix, target = prefix.to(self.device), target.to(self.device)
        seq[0, :len(prefix)] = prefix

        for i in range(len(target)):
            mask_index = (seq == self.mask_id)
            logits = self.get_logits(seq, prompt_index)[mask_index]
            x0 = torch.argmax(logits, dim=-1)

            p = torch.softmax(logits.to(torch.float32), dim=-1)
            confidence = torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)).squeeze(dim=-1)
            _, index = torch.sort(confidence, descending=True)
            # Keep only the most confident prediction per step; re-mask the rest.
            x0[index[1:]] = self.mask_id
            seq[mask_index] = x0.clone()
        correct = target == seq[0, len(prefix):]
        correct = torch.all(correct)
        return correct

    def _encode_pair(self, context, continuation):
        n_spaces = len(context) - len(context.rstrip())
        if n_spaces > 0:
            continuation = context[-n_spaces:] + continuation
            context = context[:-n_spaces]

        whole_enc = self.tokenizer(context + continuation)["input_ids"]
        context_enc = self.tokenizer(context)["input_ids"]

        context_enc_len = len(context_enc)
        continuation_enc = whole_enc[context_enc_len:]

        return context_enc, continuation_enc

    def loglikelihood(self, requests):
        def _tokenize(e):
            prefix, target = self._encode_pair(e["prefix"], e["target"])
            return {
                "prefix_text": e["prefix"],
                "target_text": e["target"],
                "prefix": prefix,
                "target": target,
            }

        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
        ds = Dataset.from_list(ds)
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")
        prompt_len = [len(x["prefix"]) + len(x["target"]) for x in ds]

        assert max(prompt_len) <= 4096

        out = []
        with torch.no_grad():
            for elem in tqdm(ds, desc="Computing likelihood..."):
                prefix = elem["prefix"]
                target = elem["target"]

                ll = self.get_loglikelihood(prefix, target)

                is_target_greedy_dec = self.suffix_greedy_prediction(prefix, target)

                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
                print('=' * 20)
                print('prefix: ', elem['prefix_text'])
                print('target: ', elem['target_text'])
                print(ll, is_target_greedy_dec)
                print('=' * 20, end='\n\n')
            torch.cuda.empty_cache()
        return out

    def loglikelihood_rolling(self, requests):
        raise NotImplementedError

    def generate_until(self, context, max_length, stop, **generation_kwargs):
        raise NotImplementedError

    @torch.no_grad()
    def _model_generate(self, context, max_length, stop, **generation_kwargs):
        '''
        Args:
            model: Mask predictor.
            prompt: A tensor of shape (1, l).
            steps: Sampling steps, less than or equal to gen_length.
            gen_length: Generated answer length.
            block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi-autoregressive remasking.
            temperature: Categorical distribution sampling temperature.
            cfg_scale: Unsupervised classifier-free guidance scale.
            remasking: Remasking strategy. 'low_confidence' or 'random'.
            mask_id: The token id of [MASK] is 126336.
        '''

        # using the hyperparameters from the original paper
        prompt = context

        gen_length = self.max_length
        block_length = self.block_length
        steps = self.max_length
        temperature = 0.
        cfg_scale = 0.
        remasking = 'low_confidence'
        mask_id = 126336

        x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(self.model.device)
        x[:, :prompt.shape[1]] = prompt.clone()

        prompt_index = (x != mask_id)

        assert gen_length % block_length == 0
        num_blocks = gen_length // block_length

        assert steps % num_blocks == 0
        steps = steps // num_blocks

        for num_block in range(num_blocks):
            block_mask_index = (x[:, prompt.shape[1] + num_block * block_length: prompt.shape[1] + (num_block + 1) * block_length] == mask_id)
            num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)
            for i in range(steps):
                mask_index = (x == mask_id)
                if cfg_scale > 0.:
                    un_x = x.clone()
                    un_x[prompt_index] = mask_id
                    x_ = torch.cat([x, un_x], dim=0)
                    logits = self.model(x_).logits
                    logits, un_logits = torch.chunk(logits, 2, dim=0)
                    logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
                else:
                    logits = self.model(x).logits

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1)  # b, l

                if remasking == 'low_confidence':
                    p = F.softmax(logits.to(torch.float64), dim=-1)
                    x0_p = torch.squeeze(
                        torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)  # b, l
                elif remasking == 'random':
                    x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
                else:
                    raise NotImplementedError(remasking)

                # Tokens outside the current block are never selected for unmasking.
                x0_p[:, prompt.shape[1] + (num_block + 1) * block_length:] = -np.inf

                x0 = torch.where(mask_index, x0, x)
                confidence = torch.where(mask_index, x0_p, -np.inf)

                # Keep the most confident predictions for this step and leave the rest masked.
                transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                for j in range(confidence.shape[0]):
                    _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j, i])
                    transfer_index[j, select_index] = True
                x[transfer_index] = x0[transfer_index]

        return x


if __name__ == "__main__":
    set_seed(1234)
    cli_evaluate()
```
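
Save the script above as, for example, `eval_llada_gptq.py` (the filename used in the command below) and launch the evaluation through lm-evaluation-harness:
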
```bash
accelerate launch eval_llada_gptq.py --tasks arc_challenge --num_fewshot 0 --model llada_dist --batch_size 8 --model_args model_path=FunAGI/LLaDA-8B-Base-gptqmodel-4bit,cfg=0.5,is_check_greedy=False,mc_num=128
```
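
`is_check_greedy=False` skips the greedy-decoding check described in the class docstring, which speeds up evaluation considerably, and `mc_num` must remain divisible by the batch size (the constructor asserts `mc_num % batch_size == 0`).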