# Cosmos QA (gpt2) 
> This is part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/train-a-gpt2-model-for-contextual-common-sense-reasoning-using-the-cosmos-qa-dataset/7463), organized by [HuggingFace](https://huggingface.co/), with TPU usage sponsored by Google.

## Team Members
- Rohan V Kashyap ([Rohan](https://huggingface.co/Rohan))
- Vivek V Kashyap ([Vivek](https://huggingface.co/Vivek))

## Dataset

[Cosmos QA: Machine Reading Comprehension with Contextual Commonsense Reasoning](https://huggingface.co/datasets/cosmos_qa). This dataset contains 35,600 problems that require commonsense-based reading comprehension, formulated as multiple-choice questions. Understanding narratives requires reading between the lines, which in turn requires interpreting the likely causes and effects of events, even when they are not mentioned explicitly. Unlike prior work, which focuses on factual and literal understanding of the context paragraph, Cosmos QA focuses on reading between the lines over a diverse collection of people's everyday narratives.


### Example

```json
{
  "Context": ["It's a very humbling experience when you need someone to dress you every morning, tie your shoes, and put your hair up. Every menial task takes an unprecedented amount of effort. It made me appreciate Dan even more. But anyway I shan't dwell on this (I'm not dying after all) and not let it detract from my lovely 5 days with my friends visiting from Jersey."],

  "Question": ["What's a possible reason the writer needed someone to dress him every morning?"],

  "Multiple Choice": [
    "A: The writer doesn't like putting effort into these tasks.",
    "B: The writer has a physical disability.",
    "C: The writer is bad at doing his own hair.",
    "D: None of the above choices."
  ],
  "link": "https://arxiv.org/pdf/1909.00277.pdf"
}
```

## How to use

```bash
# Installing requirements
pip install transformers
pip install datasets
# The Flax model and the jnp calls below also need JAX and Flax
pip install jax flax
```

```python
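import jax.numpy as jnp
from datasets import Dataset
from transformers import GPT2Tokenizer

# FlaxGPT2ForMultipleChoice is the custom model class defined in this repository's model_file.py
from model_file import FlaxGPT2ForMultipleChoice

model_path = "flax-community/gpt2-Cosmos"
# input_shape is the dummy (batch_size, num_choices, sequence_length) shape used to initialise the model
model = FlaxGPT2ForMultipleChoice.from_pretrained(model_path, input_shape=(1, 4, 1))

# GPT2 has no padding token, so reuse the end-of-text token for padding
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Path to a CSV file with the columns context, question, answer0-answer3 and label
# ("cosmos_qa.csv" is a placeholder name; point it at your own file)
dataset = Dataset.from_csv("./cosmos_qa.csv")

def preprocess(example):
    # First input: context followed by the question, repeated once per answer choice
    context_question = example["context"] + " " + example["question"]
    example["first_sentence"] = [context_question] * 4
    # Second input: the four candidate answers
    example["second_sentence"] = [example[f"answer{i}"] for i in range(4)]
    return example

dataset = dataset.map(preprocess)

def tokenize(example):
    # Tokenize the 4 (context+question, answer) pairs; each field becomes a 4 x 256 list
    encoded = tokenizer(
        example["first_sentence"],
        example["second_sentence"],
        padding="max_length",
        truncation=True,
        max_length=256,
    )
    encoded["labels"] = example["label"]
    return encoded

dataset = dataset.map(tokenize)

# Shape: (num_examples, 4, 256)
input_ids = jnp.array(dataset["input_ids"])
attention_mask = jnp.array(dataset["attention_mask"])

# One score per answer choice; the predicted answer is the highest-scoring choice
outputs = model(input_ids, attention_mask)
final_output = jnp.argmax(outputs, axis=-1)

print(f"The prediction of the dataset: {final_output}")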
```

```
The correct answer: Option 1
```
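
The model is loaded with `input_shape=(1, 4, 1)` and predictions come from `jnp.argmax(outputs, axis=-1)`, which suggests a (batch, num_choices, seq_len) input and one logit per answer choice. Multiple-choice heads in 🤗 Transformers typically flatten the choices into the batch, score each sequence with a single logit, and reshape back to (batch, num_choices). We have not verified that `FlaxGPT2ForMultipleChoice` in `model_file` works exactly this way; the plain-Python sketch below only illustrates that shape bookkeeping, and `score_sequence` is a hypothetical stand-in for the GPT2 encoder plus classification head.

```python
from typing import Callable, List, Sequence


def multiple_choice_logits(
    batch: Sequence[Sequence[Sequence[int]]],
    score_sequence: Callable[[Sequence[int]], float],
) -> List[List[float]]:
    """Score a (batch, num_choices, seq_len) nested list of token ids."""
    num_choices = len(batch[0])
    # Flatten (batch, num_choices, seq_len) -> (batch * num_choices, seq_len)
    flat = [seq for example in batch for seq in example]
    # One scalar logit per (context + question, answer) sequence
    flat_scores = [score_sequence(seq) for seq in flat]
    # Reshape back to (batch, num_choices) so argmax runs over the choices
    return [flat_scores[i:i + num_choices] for i in range(0, len(flat_scores), num_choices)]


def argmax(row: Sequence[float]) -> int:
    # Index of the largest score (first one wins on ties, like jnp.argmax)
    return max(range(len(row)), key=lambda i: row[i])


# Toy usage: a hypothetical scorer that just sums the token ids
batch = [[[1, 2], [3, 4], [0, 1], [2, 2]]]
logits = multiple_choice_logits(batch, score_sequence=lambda seq: float(sum(seq)))
print(logits)                          # [[3.0, 7.0, 1.0, 4.0]]
print([argmax(row) for row in logits])  # [1]
```

Folding the choices into the batch lets a single-sequence scorer handle all four answers at once, and the final reshape is what makes `argmax(axis=-1)` pick among choices rather than among tokens.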

## Preprocessing

The texts are tokenized using the GPT2 tokenizer. To build the multiple-choice inputs, we concatenated the context and question as the first input and paired it with each of the 4 candidate answers as the second input to the tokenizer.
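
The pairing step can be sketched in plain Python as follows. It assumes the column names used in the "How to use" code (`context`, `question`, `answer0` to `answer3`); the function name `build_choice_pairs` and the single-space separator between context and question are our own choices for illustration, not necessarily what the training script used.

```python
from typing import Dict, List, Tuple


def build_choice_pairs(example: Dict[str, str], num_choices: int = 4) -> List[Tuple[str, str]]:
    """Pair the context+question string with each candidate answer."""
    # First input: context and question joined (the space separator is an assumption)
    first = example["context"] + " " + example["question"]
    # Second input: one candidate answer per pair (columns answer0..answer3)
    return [(first, example[f"answer{i}"]) for i in range(num_choices)]


example = {
    "context": "It's a very humbling experience when you need someone to dress you every morning.",
    "question": "What's a possible reason the writer needed someone to dress him every morning?",
    "answer0": "The writer doesn't like putting effort into these tasks.",
    "answer1": "The writer has a physical disability.",
    "answer2": "The writer is bad at doing his own hair.",
    "answer3": "None of the above choices.",
}

pairs = build_choice_pairs(example)
print(len(pairs))   # 4
print(pairs[1][1])  # The writer has a physical disability.
```

Passing the first and second elements as two lists, `tokenizer(firsts, seconds, ...)`, then yields one token sequence per (context + question, answer) pair.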

## Evaluation

The following table summarizes the accuracy obtained by **GPT2-CosmosQA**. Models marked with (^) are the baselines.

| Model             | Dev Acc (%) | Test Acc (%) |
|:-----------------:|:-----------:|:------------:|
| BERT-FT Multiway^ | 68.3        | 68.4         |
| GPT-FT^           | 54.0        | 54.4         |
| GPT2-CosmosQA     | 60.3        | 59.7         |
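
Accuracy here is the share of questions whose highest-scoring choice matches the gold label. The evaluation script is not included in this card, so the plain-Python sketch below is only our assumption of how such a score is computed from per-choice logits; the toy logits and labels are made up for illustration and are not model outputs.

```python
from typing import List, Sequence


def predictions_from_logits(logits: Sequence[Sequence[float]]) -> List[int]:
    # Same as jnp.argmax(logits, axis=-1) for a (batch, num_choices) array
    return [max(range(len(row)), key=row.__getitem__) for row in logits]


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Percentage of examples whose argmax choice equals the label."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("logits and labels must be non-empty and of equal length")
    preds = predictions_from_logits(logits)
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return 100.0 * correct / len(labels)


# Toy example: 4 questions, 4 choices each
toy_logits = [
    [0.1, 2.0, -1.0, 0.3],
    [1.5, 0.2, 0.0, 0.1],
    [0.0, 0.0, 3.0, 1.0],
    [0.2, 0.1, 0.0, 0.9],
]
toy_labels = [1, 0, 3, 3]
print(accuracy(toy_logits, toy_labels))  # 75.0
```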

## Inference

This project set out to test the commonsense understanding of the GPT2 model. We fine-tuned it on Cosmos QA, a dataset whose questions require reasoning beyond the exact text spans in the context. The results above show that GPT2-CosmosQA outperforms the GPT-FT baseline (59.7 vs. 54.4 test accuracy) but still trails BERT-FT Multiway (68.4), a respectable result given that GPT2's pre-training objective is only to predict the next word.


## Credits
Huge thanks to HuggingFace 🤗 and the Google JAX/Flax team for such a wonderful community week, and especially for providing such massive computing resources. Big thanks to [@patil-suraj](https://github.com/patil-suraj) and [@patrickvonplaten](https://github.com/patrickvonplaten) for mentoring us throughout the week.