Vivek commited on
Commit
7f56d6f
·
1 Parent(s): 7abf85b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +39 -0
README.md CHANGED
@@ -31,4 +31,43 @@ dress him every morning?"],
31
  }
32
  ```
33
 
 
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  }
32
  ```
33
 
34
+ ##How to use
35
 
36
+ ```bash
37
+ # Installing requirements
38
+ pip install transformers
39
+ ```
40
+
41
+ ```python
42
+ from model_file import FlaxGPT2ForMultipleChoice
43
+ model_path="flax-community/gpt2-Cosmos"
44
+ model = FlaxGPT2ForMultipleChoice.from_pretrained(model_path,input_shape=(1,4,1))
45
+
46
+ run_dataset=Dataset.from_csv('......')
47
+
48
+ def preprocess(example):
49
+ example['context&question']=example['context']+example['question']
50
+ example['first_sentence']=[example['context&question'],example['context&question'],example['context&question'],example['context&question']]
51
+ example['second_sentence']=example['answer0'],example['answer1'],example['answer2'],example['answer3']
52
+ return example
53
+
54
+ run_dataset=run_dataset.map(preprocess)
55
+
56
+ def tokenize(examples):
57
+ a=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
58
+ a['labels']=examples['label']
59
+ return a
60
+
61
+ run_dataset=run_dataset.map(tokenize)
62
+
63
+ input_id=jnp.array(run_dataset['input_ids'])
64
+ att_mask=jnp.array(run_dataset['attention_mask'])
65
+
66
+ outputs=model(input_id,att_mask)
67
+
68
+ final_output=jnp.argmax(outputs,axis=-1)
69
+
70
+ Print(f"the predction of the dataset : {final_output}")
71
+ ```
72
+
73
+