Vivek committed
Commit a63e06d · 1 Parent(s): a5159cb

final draft

Files changed (1)
  1. src/test.py +85 -0
src/test.py ADDED
@@ -0,0 +1,85 @@
+ import jax
+ print(jax.local_device_count())
+ import jax.numpy as jnp
+
+ import flax
+ import flax.linen as nn
+ from flax.core.frozen_dict import FrozenDict, unfreeze
+ from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
+
+ from typing import Any, Optional, Tuple
+
+ from transformers import GPT2Config
+
+ import transformers
+ from transformers import GPT2Tokenizer
+
+ # GPT-2 has no pad token by default, so the end-of-text token doubles as padding.
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')
+
+ from datasets import load_dataset, load_metric
+
+ from model_file import FlaxGPT2ForMultipleChoice
+
+ import logging
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.INFO)
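+
+ # Load the Cosmos QA dataset; its test split has 6,963 examples,
+ # which len_test_dataset below hard-codes.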
+ dataset = load_dataset('cosmos_qa')
+
+ len_test_dataset = 6963
+
+ test_dataset = dataset['test'].select(range(len_test_dataset))
+
+ def preprocess(example):
+     example['context&question'] = example['context'] + example['question']
+     # Pair the shared context+question with each of the four answer candidates.
+     example['first_sentence'] = [example['context&question']] * 4
+     example['second_sentence'] = [example['answer0'], example['answer1'], example['answer2'], example['answer3']]
+     return example
+
+ test_dataset = test_dataset.map(preprocess)
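+
+ # Tokenize every (context+question, answer) pair to a fixed length of 256 tokens.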
+ def tokenize(examples):
+     a = tokenizer(examples['first_sentence'], examples['second_sentence'], padding='max_length', truncation=True, max_length=256, return_tensors='jax')
+     a['labels'] = examples['label']
+     return a
+
+ test_dataset = test_dataset.map(tokenize)
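+
+ # Keep only the tokenized inputs; the raw-text columns (and 'labels',
+ # which the prediction loop below never uses) are dropped.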
+ remov_col = ['id', 'context', 'question', 'answer0', 'answer1', 'answer2', 'answer3', 'labels', 'context&question', 'first_sentence', 'second_sentence']
+
+ test_dataset = test_dataset.remove_columns(remov_col)
+
+ seed = 0
+ total_batch_size = 32
+
+ model = FlaxGPT2ForMultipleChoice.from_pretrained("flax-community/gpt2-Cosmos", input_shape=(1, 4, 1))
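+
+ # Batching helper in the style of the Flax GLUE example (hence the name):
+ # shuffle with a JAX PRNG key, drop the remainder, and yield dicts of jnp arrays.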
+ def glue_train_data_loader(rng, dataset, batch_size):
+     steps_per_epoch = len_test_dataset // batch_size
+     perms = jax.random.permutation(rng, len(dataset))
+     perms = perms[:steps_per_epoch * batch_size]
+     perms = perms.reshape((steps_per_epoch, batch_size))
+     for perm in perms:
+         batch = dataset[perm]
+         batch = {k: jnp.array(v) for k, v in batch.items()}
+         # Note: the forward pass below is not pmapped, so the batch is fed
+         # whole instead of being shard()-ed across local devices.
+         yield batch
+
+ rng = jax.random.PRNGKey(seed)
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
+
+ # Whole-dataset arrays, kept for the single-shot variant commented out below.
+ input_id = jnp.array(test_dataset['input_ids'])
+ att_mask = jnp.array(test_dataset['attention_mask'])
+
+ restored_output = []
+ rng, input_rng = jax.random.split(rng)
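+
+ # Predict batch by batch: argmax over the choice logits (assuming the custom
+ # multiple-choice head returns (batch, 4) logits) gives each predicted answer index.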
+ for idx, batch in enumerate(glue_train_data_loader(input_rng, test_dataset, total_batch_size)):
+     outputs = model(batch['input_ids'], batch['attention_mask'])
+     final_output = jnp.argmax(outputs, axis=-1)
+     restored_output.append(final_output)
+
+ # Single-shot alternative over the full test set:
+ # outputs = model(input_id, att_mask)
+ # final_output = jnp.argmax(outputs, axis=-1)
+
+ logger.info(f"the predictions on the test dataset: {restored_output[:30]}")
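+
+ # If a single flat prediction array is needed, the per-batch results can be
+ # concatenated, e.g. predictions = jnp.concatenate(restored_output).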