Vivek commited on
Commit
a25072c
·
1 Parent(s): 3adb47c

Saving weights of epoch 1 at step 92

Browse files
flax_model.msgpack CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:fd2e92a7233e29a509eb154637ad78cef153ebf1065ed18cb43fa82960412b88
3
  size 1419367919
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51d86bd352715e1623b69a8451f8c752c314bb6cf7669a5d9bb2f7589261d8c3
3
  size 1419367919
results_tensorboard/events.out.tfevents.1626415968.t1v-n-8cb15980-w-0.843954.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7ba3313ae634173f24d2dd23ebaba16f5e2889922e2e54aa323a38a726b39e1
3
+ size 40
results_tensorboard/events.out.tfevents.1626416209.t1v-n-8cb15980-w-0.845549.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10501817e5f6a56a85f8dbb3b994e577cc9eb1629d28bb96dc8ad403f11b4dda
3
+ size 25152
src/__pycache__/model_file.cpython-38.pyc ADDED
Binary file (9.01 kB). View file
 
src/prediction.py CHANGED
@@ -1,4 +1,5 @@
1
  import jax
 
2
  import jax.numpy as jnp
3
 
4
  import flax
 
1
  import jax
2
+ print(jax.local_device_count())
3
  import jax.numpy as jnp
4
 
5
  import flax
src/test.py CHANGED
@@ -1,4 +1,5 @@
1
  import jax
 
2
  import jax.numpy as jnp
3
 
4
  import flax
 
1
  import jax
2
+ print(jax.local_device_count())
3
  import jax.numpy as jnp
4
 
5
  import flax
src/train.py CHANGED
@@ -111,7 +111,7 @@ def main():
111
  def eval_function(logits):
112
  return logits.argmax(-1)
113
 
114
- model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-large',input_shape=(1,4,1))
115
 
116
  state=TrainState.create(apply_fn=model.__call__,
117
  params=model.params,
@@ -238,4 +238,4 @@ def main():
238
  summary_writer.flush()
239
 
240
  if __name__ == "__main__":
241
- main()
 
111
  def eval_function(logits):
112
  return logits.argmax(-1)
113
 
114
+ model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-medium',input_shape=(1,4,1))
115
 
116
  state=TrainState.create(apply_fn=model.__call__,
117
  params=model.params,
 
238
  summary_writer.flush()
239
 
240
  if __name__ == "__main__":
241
+ main()