Vivek committed on
Commit
19fd366
·
1 Parent(s): 871dd49

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -5
README.md CHANGED
@@ -36,14 +36,16 @@ dress him every morning?"],
36
  ```bash
37
  # Installing requirements
38
  pip install transformers
 
39
  ```
40
 
41
  ```python
42
  from model_file import FlaxGPT2ForMultipleChoice
 
43
  model_path="flax-community/gpt2-Cosmos"
44
  model = FlaxGPT2ForMultipleChoice.from_pretrained(model_path,input_shape=(1,4,1))
45
 
46
- run_dataset=Dataset.from_csv('......')
47
 
48
  def preprocess(example):
49
  example['context&question']=example['context']+example['question']
@@ -51,17 +53,17 @@ def preprocess(example):
51
  example['second_sentence']=example['answer0'],example['answer1'],example['answer2'],example['answer3']
52
  return example
53
 
54
- run_dataset=run_dataset.map(preprocess)
55
 
56
  def tokenize(examples):
57
  a=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
58
  a['labels']=examples['label']
59
  return a
60
 
61
- run_dataset=run_dataset.map(tokenize)
62
 
63
- input_id=jnp.array(run_dataset['input_ids'])
64
- att_mask=jnp.array(run_dataset['attention_mask'])
65
 
66
  outputs=model(input_id,att_mask)
67
 
 
36
  ```bash
37
  # Installing requirements
38
  pip install transformers
39
+ pip install datasets
40
  ```
41
 
42
  ```python
43
  from model_file import FlaxGPT2ForMultipleChoice
44
+ from datasets import Dataset
45
  model_path="flax-community/gpt2-Cosmos"
46
  model = FlaxGPT2ForMultipleChoice.from_pretrained(model_path,input_shape=(1,4,1))
47
 
48
+ dataset=Dataset.from_csv('......')
49
 
50
  def preprocess(example):
51
  example['context&question']=example['context']+example['question']
 
53
  example['second_sentence']=example['answer0'],example['answer1'],example['answer2'],example['answer3']
54
  return example
55
 
56
+ dataset=dataset.map(preprocess)
57
 
58
  def tokenize(examples):
59
  a=tokenizer(examples['first_sentence'],examples['second_sentence'],padding='max_length',truncation=True,max_length=256,return_tensors='jax')
60
  a['labels']=examples['label']
61
  return a
62
 
63
+ dataset=dataset.map(tokenize)
64
 
65
+ input_id=jnp.array(dataset['input_ids'])
66
+ att_mask=jnp.array(dataset['attention_mask'])
67
 
68
  outputs=model(input_id,att_mask)
69