Saving weights of epoch 1 at step 92
Browse files- flax_model.msgpack +1 -1
- results_tensorboard/events.out.tfevents.1626415968.t1v-n-8cb15980-w-0.843954.3.v2 +3 -0
- results_tensorboard/events.out.tfevents.1626416209.t1v-n-8cb15980-w-0.845549.3.v2 +3 -0
- src/__pycache__/model_file.cpython-38.pyc +0 -0
- src/prediction.py +1 -0
- src/test.py +1 -0
- src/train.py +2 -2
flax_model.msgpack
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 1419367919
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:51d86bd352715e1623b69a8451f8c752c314bb6cf7669a5d9bb2f7589261d8c3
|
3 |
size 1419367919
|
results_tensorboard/events.out.tfevents.1626415968.t1v-n-8cb15980-w-0.843954.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c7ba3313ae634173f24d2dd23ebaba16f5e2889922e2e54aa323a38a726b39e1
|
3 |
+
size 40
|
results_tensorboard/events.out.tfevents.1626416209.t1v-n-8cb15980-w-0.845549.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10501817e5f6a56a85f8dbb3b994e577cc9eb1629d28bb96dc8ad403f11b4dda
|
3 |
+
size 25152
|
src/__pycache__/model_file.cpython-38.pyc
ADDED
Binary file (9.01 kB). View file
|
|
src/prediction.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import jax
|
|
|
2 |
import jax.numpy as jnp
|
3 |
|
4 |
import flax
|
|
|
1 |
import jax
|
2 |
+
print(jax.local_device_count())
|
3 |
import jax.numpy as jnp
|
4 |
|
5 |
import flax
|
src/test.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import jax
|
|
|
2 |
import jax.numpy as jnp
|
3 |
|
4 |
import flax
|
|
|
1 |
import jax
|
2 |
+
print(jax.local_device_count())
|
3 |
import jax.numpy as jnp
|
4 |
|
5 |
import flax
|
src/train.py
CHANGED
@@ -111,7 +111,7 @@ def main():
|
|
111 |
def eval_function(logits):
|
112 |
return logits.argmax(-1)
|
113 |
|
114 |
-
model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-
|
115 |
|
116 |
state=TrainState.create(apply_fn=model.__call__,
|
117 |
params=model.params,
|
@@ -238,4 +238,4 @@ def main():
|
|
238 |
summary_writer.flush()
|
239 |
|
240 |
if __name__ == "__main__":
|
241 |
-
main()
|
|
|
111 |
def eval_function(logits):
|
112 |
return logits.argmax(-1)
|
113 |
|
114 |
+
model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2-medium',input_shape=(1,4,1))
|
115 |
|
116 |
state=TrainState.create(apply_fn=model.__call__,
|
117 |
params=model.params,
|
|
|
238 |
summary_writer.flush()
|
239 |
|
240 |
if __name__ == "__main__":
|
241 |
+
main()
|