Improve model card for MARS optimizer: add metadata, paper, code, and usage
This PR improves the model card for the MARS optimizer by:
- Adding `pipeline_tag: image-classification` to reflect the optimizer's evaluation on relevant vision tasks like mini-imagenet and CIFAR, helping users discover it when filtering for this pipeline.
- Adding `library_name: transformers`, as the optimizer is used to train models from the Hugging Face Transformers library (e.g., GPT-2), enabling a predefined code snippet for usage.
- Updating the main title of the model card to the paper's official title and linking it to the Hugging Face paper page for clearer attribution.
- Adding a prominent link to the GitHub repository for easy access to the source code.
- Expanding the model card content by integrating comprehensive sections from the official GitHub README, including "About MARS," "Instantiations," detailed "Performance Comparisons" across various tasks (LLMs and Vision), "Training GPT-2 from Scratch," and instructions for "Reproducing Results."
- Including a "Customized Training" section with a Python code snippet from the official repository, demonstrating how to integrate the MARS optimizer into a PyTorch training loop.
- Adding the official Citation and Acknowledgements.
These updates provide a more complete, discoverable, and user-friendly overview of the MARS optimizer for the Hugging Face community.

---
datasets:
- timm/mini-imagenet
license: apache-2.0
pipeline_tag: image-classification
library_name: transformers
---

# [MARS: Unleashing the Power of Variance Reduction for Training Large Models](https://huggingface.co/papers/2411.10438)

**Code Repository**: https://github.com/AGI-Arena/MARS

## About MARS

**MARS** (**M**ake v**A**riance **R**eduction **S**hine) is a unified optimization framework designed to address the inherent challenges of training large models. Traditional adaptive gradient methods like Adam and AdamW often suffer from high stochastic gradient variance, while variance reduction techniques have struggled to gain practical impact in deep learning. At its core, **MARS** comprises two major components: (1) a scaled stochastic recursive momentum, which provides a variance-reduced estimator of the full gradient for better gradient complexity; and (2) the preconditioned update, which approximates the second-order Newton's method for better per-iteration complexity. By combining preconditioned gradient methods with variance reduction, **MARS** achieves the best of both worlds, accelerating the search for critical points in optimization.

The **MARS** framework is built on the following preconditioned variance-reduced updates

$$
\mathbf{c}\_t = \nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)+\underbrace{{\color{red}\gamma_t} \frac{\beta_{1}}{1-\beta_{1}} \left(\nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)-\nabla f(\mathbf{x}\_{t-1}, \mathbf{\xi}\_t)\right)}_{\text{scaled gradient correction}}
$$

$$
\tilde{\mathbf{c}}_t = \text{Clip}(\mathbf{c}_t,1) = \begin{cases}
\frac{\mathbf{c}_t}{\|\mathbf{c}_t\|_2} & \text{if } \|\mathbf{c}_t\|_2 > 1,\\
\mathbf{c}_t & \text{otherwise}.
\end{cases}
$$

$$
\mathbf{m}\_t = \beta_1 \mathbf{m}\_{t-1} + (1-\beta_{1})\tilde{\mathbf{c}}\_t
$$

$$
\mathbf{x}\_{t+1} = \arg\min_{\mathbf{x} \in \mathbb{R}^d} \left\{\eta_t \left\langle \mathbf{m}_t, \mathbf{x} \right\rangle + \frac{1}{2} \|\mathbf{x} - \mathbf{x}\_t\|_{\mathbf{H}_t}^2\right\}
$$

Here ${\color{red}\gamma_t}$ is a scaling parameter that controls the strength of gradient correction.
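
For orientation, the variance-reduced momentum above can be written out for a single parameter tensor as the short PyTorch sketch below. This is an illustration of the equations only, not the repository's `mars.py` implementation (which additionally handles per-parameter state, bias correction, and weight decay).

```python
import torch

def mars_momentum_step(m, grad, prev_grad, beta1=0.95, gamma=0.025):
    """One variance-reduced momentum update for a single tensor (illustrative sketch).

    m         -- momentum buffer m_{t-1}
    grad      -- stochastic gradient at x_t on the current minibatch
    prev_grad -- gradient at x_{t-1} on the same minibatch (or its approximation)
    """
    # c_t: gradient plus the scaled gradient correction term
    c = grad + gamma * (beta1 / (1.0 - beta1)) * (grad - prev_grad)
    # clip c_t to unit l2 norm
    c_norm = c.norm()
    if c_norm > 1.0:
        c = c / c_norm
    # exponential moving average m_t
    return beta1 * m + (1.0 - beta1) * c

# toy usage on a single tensor
m = torch.zeros(10)
g_prev, g = torch.randn(10), torch.randn(10)
m = mars_momentum_step(m, g, g_prev)
```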

### Instantiations of **MARS**

Under the **MARS** framework, we provide three instantiations based on different Hessian matrix approximations: **MARS-AdamW**, **MARS-Lion**, and **MARS-Shampoo**. Please note that the hyperparameters in this framework are tuned on **MARS-AdamW**. When using other instantiations, it is essential to tune the hyperparameters—particularly the learning rates—for optimal performance.

#### MARS-AdamW

(Enable with `mars_type="mars-adamw"` in `mars.py`)

The Hessian matrix approximation is defined as:

$$
\mathbf{v}\_t =\beta_2 \mathbf{v}\_{t-1}+(1-\beta_2) \big(\nabla f(\mathbf{x}\_t, \mathbf{\xi}\_t)\big)^2
$$

$$
\mathbf{H}_t := \sqrt{\text{diag}\Big(\mathbf{v}_t\Big)}\cdot \frac{1 - \beta_1^t}{\sqrt{1 - \beta_2^t}}.
$$
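
Combined with the variance-reduced momentum, this diagonal H_t reduces the argmin step above to an Adam-style element-wise update. The sketch below illustrates that reduction only; the small `eps` term is added for numerical stability and weight decay is omitted, so refer to `mars.py` for the actual implementation.

```python
import torch

def mars_adamw_precond_step(x, m, v, grad, lr=6e-3, beta1=0.95, beta2=0.99, eps=1e-8, step=1):
    """Preconditioned MARS-AdamW-style step for one tensor (illustrative sketch).

    m is the variance-reduced momentum m_t from the previous sketch;
    grad is the raw stochastic gradient used for the second-moment estimate v_t.
    """
    # v_t: exponential moving average of the squared gradient
    v.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
    # with a diagonal H_t the argmin reduces to an Adam-like element-wise step;
    # this factor is the bias correction implied by H_t's definition above
    bias = (1.0 - beta2 ** step) ** 0.5 / (1.0 - beta1 ** step)
    x.sub_(lr * bias * m / (v.sqrt() + eps))
    return x, v

# toy usage
x, m, v, g = torch.randn(10), torch.randn(10), torch.ones(10), torch.randn(10)
x, v = mars_adamw_precond_step(x, m, v, g)
```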

#### MARS-Lion

(Enable with `mars_type="mars-lion"` in `mars.py`)

The Hessian matrix approximation is defined as:

$$
\mathbf{H}_t := \sqrt{\text{diag}(\mathbf{m}_t^2)}.
$$
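
Because this H_t is just the element-wise magnitude of m_t, the preconditioned step reduces to a sign update on the variance-reduced momentum, as in the sketch below (illustration only):

```python
import torch

def mars_lion_precond_step(x, m, lr=3e-3):
    # H_t = sqrt(diag(m_t^2)) = |m_t|, so m_t / H_t = sign(m_t)
    x.sub_(lr * torch.sign(m))
    return x
```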

#### MARS-Shampoo

(Enable with `mars_type="mars-shampoo"` in `mars.py`)

The preconditioner can be seen as an [orthogonal mapping](https://arxiv.org/abs/2409.20325) operator:

$$
\mathbf{U}\_t, \mathbf{\Sigma}\_t, \mathbf{V}\_t = \text{SVD}(\mathbf{G}\_t),\qquad
\mathbf{x}\_{t+1} =\mathbf{x}\_t-\eta_t\mathbf{U}_t\mathbf{V}\_t^\top.
$$

In practice, we use the [Newton-Schulz iteration](https://github.com/KellerJordan/modded-nanogpt) to accelerate and approximate the solution of the SVD problem.
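
For reference, the sketch below shows the classical cubic Newton-Schulz iteration that drives a matrix toward its orthogonal factor U Vᵀ; the repository and modded-nanogpt use a tuned quintic variant with different coefficients, so treat this purely as an illustration of the idea.

```python
import torch

def newton_schulz_orthogonalize(G, steps=30, eps=1e-7):
    """Approximate U V^T for G = U S V^T via the cubic Newton-Schulz iteration."""
    # normalizing by the Frobenius norm keeps the spectral norm <= 1,
    # which is required for the iteration to converge
    X = G / (G.norm() + eps)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

G = torch.randn(64, 48)
O = newton_schulz_orthogonalize(G)
print(torch.dist(O.T @ O, torch.eye(48)).item())  # shrinks toward 0 as steps grow
```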

## Comparisons of timm Optimizers w/ Caution

This section presents summaries of several sets of experiments comparing a number of optimizers with and without caution (https://huggingface.co/papers/2411.16085) enabled.

The runs were all performed training a smaller ViT (`vit_wee_patch16_reg1_gap_256`) for 200 epochs (10M samples seen) from scratch on the `timm` 'mini-imagenet' dataset, a 100-class subset of ImageNet with the same image sizes as the originals.
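
The optimizer names in the tables below (`mars`, `cmars`, `adamw`, `cadamw`, `laprop`, `claprop`) are `timm` optimizer strings, with the `c` prefix selecting the cautious variant. Assuming a recent `timm` release that registers MARS and the cautious variants, a comparable run could be set up roughly as follows; the learning rate and weight decay here are placeholders, not the exact train args of these experiments.

```python
import timm
from timm.optim import create_optimizer_v2

# 'mars' / 'cmars' require a recent timm release; lr and weight_decay are placeholders
model = timm.create_model("vit_wee_patch16_reg1_gap_256", pretrained=False, num_classes=100)
optimizer = create_optimizer_v2(model, opt="cmars", lr=1e-3, weight_decay=0.05)
print(type(optimizer).__name__)
```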

# LaProp

|optim |best_epoch|train_loss |eval_loss |eval_top1 |eval_top5 |lr |
|---|---|---|---|---|---|---|
|claprop, lr=1e-03 |204.0 |2.2173619270324707|1.0931779468536378|73.920000390625 |91.33000009765624|0.0 |
|claprop, lr=5e-04 |183.0 |2.262192726135254 |1.0912627222061158|73.77000073242188|91.22000260009766|1.3478660293113704e-05|
|laprop, lr=5e-04 |198.0 |2.2425642013549805|1.1426102781295775|71.73000213623047|90.55000146484376|1.109508849230001e-06 |
|laprop, lr=1e-03 |179.0 |2.290040969848633 |1.168387135314941 |71.15000104980469|90.18000189208983|3.806023374435663e-05 |
|claprop, lr=2e-04 |195.0 |2.546172380447388 |1.2475446645736694|68.30000163574219|89.15000153808593|9.97634228344235e-07 |
|laprop, lr=2e-04 |204.0 |2.6702351570129395|1.309178423690796 |67.07999990234374|88.67000270996094|0.0 |
|claprop, lr=2e-03 |193.0 |2.678058862686157 |1.5239886917114258|62.08000177001953|84.8 |1.4890673845226132e-05|
|laprop, lr=2e-03 |200.0 |2.70467209815979 |1.522907255935669 |61.46000135498047|85.28000162353516|1.9732715717284413e-06|

## LaProp Top-1 Evaluation Accuracy on Mini-ImageNet


# AdamW

|optim |best_epoch|train_loss |eval_loss |eval_top1 |eval_top5 |
|---|---|---|---|---|---|
|cadamw, lr=1e-03 |184.0|2.2688851356506348|1.0868136840820313|73.52000141601563|91.60000036621092|
|cadamw, lr=5e-04 |199.0|2.163278102874756 |1.0976034646987916|73.3900005859375 |91.31000137939454|
|cadamw, lr=1e-03, clip grads|203.0|2.1360626220703125|1.1043113907814026|73.33000158691407|91.41000042724608|
|adamw, lr=1e-03, clip grads |195.0|2.2746386528015137|1.142998440361023 |72.11000151367188|90.47000052490236|
|adamw, lr=5e-04 |185.0|2.3040246963500977|1.1535791856765747|71.50000120849609|90.4800001953125 |
|adamw, lr=1e-03 |199.0|2.223684310913086 |1.1657958560943604|71.22999993896484|90.30999958496092|
|cadamw, lr=2e-04 |189.0|2.538627862930298 |1.2325929063796996|68.94999995117188|89.61000139160156|
|adamw, lr=2e-04 |203.0|2.579624652862549 |1.3085522148132325|67.11000026855469|88.66000164794922|

## AdamW Top-1 Evaluation Accuracy on Mini-ImageNet


# MARS

|optim |best_epoch|train_loss |eval_loss |eval_top1 |eval_top5 |
|---|---|---|---|---|---|
|cmars, lr=1e-03|198.0 |2.054780960083008 |1.0435627010345458|74.91000185546875|92.08000146484376|
|cmars, lr=2e-03|203.0 |2.0272469520568848|1.0705795244216918|74.31000185546876|91.54000092773435|
|mars, lr=1e-03 |184.0 |2.219767808914185 |1.07215625667572 |74.06000178222656|91.6200013671875 |
|mars, lr=2e-03 |197.0 |2.1453990936279297|1.0963781481742858|73.73000098876953|91.1500006225586 |
|cmars, lr=5e-04|198.0 |2.2018630504608154|1.083557384109497 |73.32000045166015|91.67000092773438|
|mars, lr=5e-04 |189.0 |2.322845220565796 |1.1199828132629397|72.02999995117187|90.86000173339843|

## MARS Top-1 Evaluation Accuracy on Mini-ImageNet

## MARS Train Loss


### Performance of MARS Compared to Baselines

#### Experiments on OpenWebText

Experimental results for **MARS** are based on the **MARS-AdamW** instantiation, unless otherwise stated. In our experiments, gradients are calculated once per sample and per update (**MARS**-approx in our [paper](https://huggingface.co/papers/2411.10438)). Performing exact gradient computation with two evaluations per update, as in the exact form of **MARS**, can slightly enhance performance but at the cost of doubling the computational expense. For more details, refer to our [paper](https://huggingface.co/papers/2411.10438).
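
As a rough illustration of the difference (the repository exposes this choice through the `is_approx` flag in `mars.py`; the helper below is a sketch, not the repo's code): the exact form re-evaluates the gradient at the previous iterate on the current minibatch, while the approximate form reuses the gradient stored from the previous step.

```python
import torch

def gradients_for_mars(model, prev_model, loss_fn, batch, exact=False, prev_grads=None):
    """Return (grad_t, grad_prev) for the MARS correction term (illustrative sketch).

    exact=True  -> MARS-exact: two backward passes on the same minibatch, once through
                   the current model (x_t) and once through a frozen copy of last
                   step's parameters (x_{t-1}).
    exact=False -> MARS-approx: reuse the gradients saved at the previous step, so only
                   one backward pass per update is needed.
    """
    x, y = batch
    model.zero_grad()
    loss_fn(model(x), y).backward()
    grad_t = [p.grad.detach().clone() for p in model.parameters()]
    if exact:
        prev_model.zero_grad()
        loss_fn(prev_model(x), y).backward()
        grad_prev = [p.grad.detach().clone() for p in prev_model.parameters()]
    else:
        grad_prev = prev_grads  # gradients stored at the previous iteration
    return grad_t, grad_prev

# toy usage
model, prev_model = torch.nn.Linear(4, 1), torch.nn.Linear(4, 1)
batch = (torch.randn(8, 4), torch.randn(8, 1))
g_t, g_prev = gradients_for_mars(model, prev_model, torch.nn.functional.mse_loss, batch, exact=True)
```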

**MARS** consistently outperforms AdamW and the [Muon](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e) optimizers across GPT-2 models:

| **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** |
|---|---|---|
| <img src="assets/val_small.png" width="350"> | <img src="assets/val_medium.png" width="350"> | <img src="assets/val_large.png" width="350"> |

| Best Val Loss | GPT-2 Small (5B tokens) | GPT-2 Medium (5B tokens) | GPT-2 Large (5B tokens) | GPT-2 Small (20B tokens) | GPT-2 Medium (20B tokens) | GPT-2 Large (20B tokens) | GPT-2 Small (50B tokens) | GPT-2 Medium (50B tokens) | GPT-2 Large (50B tokens) |
|---|---|---|---|---|---|---|---|---|---|
| AdamW | 3.193 | 3.084 | 3.013 | 3.024 | 2.821 | 2.741 | 2.885 | 2.691 | 2.561 |
| Muon | 3.165 | 3.009 | 2.915 | 3.006 | 2.813 | 2.691 | 2.901 | 2.688 | 2.573 |
| **MARS**-exact | **3.107** | - | - | 2.980 | - | - | **2.847** | - | - |
| **MARS**-approx | 3.108 | **2.969** | **2.876** | **2.981** | **2.763** | **2.647** | **2.849** | **2.636** | **2.518** |

#### Efficiency of MARS

The **MARS** algorithm can achieve better performance not only within the same number of training steps, but also within the same training time:

| **GPT-2 small** | **GPT-2 medium** | **GPT-2 large** |
|---|---|---|
| <img src="assets/time_small.png" width="350"> | <img src="assets/time_medium.png" width="350"> | <img src="assets/time_large.png" width="350"> |

---

#### Experiments on FineWeb-Edu

Below are the training and validation loss curves for both GPT‑2 Small and GPT‑2 XL when using our MARS approach versus AdamW. As you can see, MARS often yields faster convergence and consistently lower losses across different training steps.

| Model | **GPT-2 small** | **GPT-2 XL** |
|---|---|---|
| **Train Loss** | <img src="assets/small_train.png" width="350"> | <img src="assets/xl_train.png" width="350"> |
| **Validation Loss** | <img src="assets/small_val.png" width="350"> | <img src="assets/xl_val.png" width="350"> |

##### Evaluation Metrics

Below, we present the evaluation metrics on the FineWeb-Edu dataset for both GPT‑2 Small and GPT‑2 XL, comparing the OpenAI GPT-2 baseline, AdamW, and our MARS-AdamW optimizer.

<img src="assets/fineweb_hella.png" width="350">

**Results on GPT-2 small**

MARS-AdamW shows a clear improvement over AdamW and the OpenAI baseline across multiple tasks, with the **highest average score** of 45.93 on GPT‑2 Small.

| Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| OpenAI-Comm. | 39.48 | 22.70 | 48.72 | 31.14 | 27.20 | 62.51 | **51.62** | 22.92 | 64.40 | 41.19 |
| AdamW | 51.43 | 26.54 | 55.78 | 36.26 | 30.60 | 64.53 | 50.36 | **24.49** | **71.50** | 45.72 |
| MARS-AdamW | **52.23** | **27.39** | **55.84** | **36.91** | **32.20** | **64.80** | 49.96 | 22.95 | 71.10 | **45.93** |

**Results on GPT-2 XL**

On GPT‑2 XL, MARS-AdamW continues to outperform AdamW across most tasks, delivering an impressive **HellaSwag accuracy of 56.52**.

| Method/Task | ARC-E | ARC-C | BoolQ | HellaSwag | OBQA | PIQA | WG | MMLU | SciQ | Avg. |
|---|---|---|---|---|---|---|---|---|---|---|
| OpenAI-Comm. | 51.05 | 28.50 | 61.77 | 50.89 | 32.00 | 70.51 | **58.33** | 25.24 | 76.00 | 50.48 |
| AdamW | **68.22** | 38.40 | 61.13 | 53.93 | 39.00 | 72.69 | 54.78 | **25.47** | 85.30 | 55.43 |
| MARS-AdamW | 66.54 | **39.85** | **63.82** | **56.52** | **41.20** | **73.34** | 56.59 | 23.86 | **86.00** | **56.41** |

---

#### Experiments on Vision Tasks

**MARS** can achieve better test loss and accuracy than AdamW and the [Muon](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e) optimizers on the CIFAR-10 and CIFAR-100 datasets with ResNet-18 and a `MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)` scheduler (we display the best result for each optimizer from a grid search over base learning rates in [1e-5, ..., 1e-1]):

| Dataset | **CIFAR-10** | **CIFAR-100** |
|---|---|---|
| **Test loss** | <img src="assets/cifar10_test_loss.png" width="350"> | <img src="assets/cifar100_test_loss.png" width="350"> |
| **Test Accuracy** | <img src="assets/cifar10_test_acc.png" width="350"> | <img src="assets/cifar100_test_acc.png" width="350"> |

| Best Test loss | CIFAR-10 | CIFAR-100 |
|---|---|---|
| AdamW | 0.306 | 2.608 |
| Muon | 0.230 | 1.726 |
| **MARS**-approx | **0.199** | **0.971** |

| Best Test Accuracy (%) | CIFAR-10 | CIFAR-100 |
|---|---|---|
| AdamW | 94.81 | 73.7 |
| Muon | 95.08 | 74.64 |
| **MARS**-approx | **95.29** | **76.97** |
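
The pieces of that setup can be wired together roughly as below. This is a sketch for orientation only; the base learning rate, betas, and epoch count are illustrative stand-ins rather than the exact settings behind the tables above.

```python
import torchvision
from torch.optim.lr_scheduler import MultiStepLR

from mars import MARS  # optimizer class from the MARS repository

# ResNet-18 with a CIFAR-sized classification head (10 or 100 classes)
model = torchvision.models.resnet18(num_classes=10)

# base lr and betas are illustrative; the tables above report the best lr from a grid search
optimizer = MARS(model.parameters(), lr=1e-3, betas=(0.95, 0.99), gamma=0.025)
scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

num_epochs = 200  # assumption: a typical epoch budget for these milestones
for epoch in range(num_epochs):
    # ... one training epoch over CIFAR-10/100, calling optimizer.step(), goes here ...
    scheduler.step()
```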

## Training GPT-2 from Scratch

### Install Dependencies

```
$ pip install torch==2.1.2 transformers==4.33.0 datasets tiktoken numpy==1.26.4 wandb
```

### Data Preparation

Prepare the [OpenWebText](https://huggingface.co/datasets/openwebtext) data following [nanoGPT](https://github.com/karpathy/nanoGPT/):

```
$ python data/openwebtext/prepare.py
```

### **Start Training**

To train a model using the **MARS** optimizer, run the following command:

```bash
$ torchrun --standalone --nproc_per_node=8 MARS/train_mars.py config/${your_config_file}
```

This command initiates the training of a GPT-2 model on the OpenWebText dataset using the **MARS** optimizer. All relevant hyperparameters—training, model, and optimizer—are specified in the configuration file (`${your_config_file}`). These parameters can be adjusted directly in the configuration file or through the bash script.
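
For orientation, a nanoGPT-style config file is simply a Python module of top-level assignments that the training script reads to override its defaults. The values below are illustrative, assembled from the hyperparameter notes on this card rather than copied from `config/train_gpt2_small_mars.py`, so consult the actual config files in the repository.

```python
# hypothetical excerpt of a nanoGPT-style config (cf. config/train_gpt2_small_mars.py);
# the real files in the repository are authoritative
n_layer = 12                      # GPT-2 Small
n_head = 12
n_embd = 768

learning_rate = 6e-3              # MARS lr for GPT-2 Small (see the hyperparameter table below)
lr_1d = 3e-3                      # AdamW lr for 1D parameters
weight_decay = 1e-2
mars_type = "mars-adamw"
schedule = "cosine"

batch_size = 15
gradient_accumulation_steps = 4   # 15 * 4 * 8 GPUs = 480 total batch size
```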

### **Hyperparameter Details**

#### **Model Hyperparameters**:

- **n_layer**: Number of transformer layers: 12 for GPT-2 Small, 24 for GPT-2 Medium, 36 for GPT-2 Large
- **n_head**: Number of attention heads: 12 for GPT-2 Small, 16 for GPT-2 Medium, 20 for GPT-2 Large
- **n_embd**: Embedding dimension: 768 for GPT-2 Small, 1024 for GPT-2 Medium, 1280 for GPT-2 Large
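
These sizes match the standard GPT-2 configurations; for example, with the Hugging Face Transformers library the GPT-2 Small shape can be expressed as below (shown for orientation only; the training code in this repository builds its model nanoGPT-style rather than through `transformers`).

```python
from transformers import GPT2Config, GPT2LMHeadModel

# GPT-2 Small: 12 layers, 12 heads, 768-dim embeddings (~124M parameters)
config = GPT2Config(n_layer=12, n_head=12, n_embd=768)
model = GPT2LMHeadModel(config)
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
```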

#### **Optimizer Hyperparameters**:

- **`learning_rate`**: Learning rate for the **MARS** optimizer.
- **`weight_decay`**: Weight decay for the **MARS** optimizer.
- **`beta1, beta2`**: Weights for exponential moving average.
  - Default: `beta1=0.95, beta2=0.99`
- **`mars_type`**: Type of optimizer to use:
  - Options: `mars-adamw`, `mars-lion`, `mars-shampoo`
  - Default: `mars-adamw`
- **`optimize_1d`**: Whether **MARS** should optimize 1D parameters (e.g., layer norm parameters in GPT-2).
  - If `False`, AdamW will be used for optimizing 1D parameters.
  - Default: `False`
- **`lr_1d`**: Learning rate for AdamW when **`optimize_1d`** is set to `False`.
- **`betas_1d`**: Weights for exponential moving average in AdamW optimizer.
  - Default: `(0.9, 0.95)`
- **`is_approx`**: Whether to use approximate gradient calculation (**MARS**-approx).
  - Default: `True`
- **`gamma`**: The scaling parameter that controls the strength of gradient correction.
  - Default: 0.025
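
Putting these knobs together, constructing the optimizer directly might look like the sketch below. The `lr`, `betas`, and `gamma` arguments are confirmed by the customized-training example later on this card; treating the remaining names (`weight_decay`, `mars_type`, `optimize_1d`, `lr_1d`, `betas_1d`, `is_approx`) as keyword arguments is an assumption based on the list above, so check `mars.py` for the exact constructor signature.

```python
import torch
from mars import MARS  # optimizer class from the MARS repository

model = torch.nn.Linear(768, 768)  # stand-in for your GPT-2 model

# keyword names other than lr/betas/gamma are assumed from the list above
optimizer = MARS(
    model.parameters(),
    lr=6e-3,                # learning_rate for MARS-managed parameters
    betas=(0.95, 0.99),     # beta1, beta2
    weight_decay=1e-2,
    mars_type="mars-adamw",
    optimize_1d=False,      # 1D params (e.g. layer norms) fall back to AdamW
    lr_1d=3e-3,
    betas_1d=(0.9, 0.95),
    is_approx=True,
    gamma=0.025,
)
```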

#### **Training Hyperparameters**:

- **`batch_size`**: Mini-batch size per device (for example, GPT-2 Small on an A100 GPU typically uses a batch size of 15).
- **`gradient_accumulation_steps`**: Gradient accumulation steps to ensure the total effective batch size matches the desired scale (for example, for a total batch size of 480: $15 \times 4 \times 8 \, \text{GPUs}$).
- **`schedule`**: Learning rate schedule.
  - Default: `cosine`

For more detailed hyperparameter examples, refer to:

- `config/train_gpt2_small_mars.py`
- `scripts/run_mars_small.sh`

---

### Reproducing Our Results

#### **Reproducing GPT-2 Small (125M) Results**

Train with MARS using

```
$ bash scripts/run_mars_small.sh
```

or

```
$ torchrun --standalone --nproc_per_node=8 \
      MARS/train_mars.py \
      config/train_gpt2_small_mars.py \
      --batch_size=15 \
      --gradient_accumulation_steps=4
```

#### Reproducing GPT-2 Medium (355M) Results

Train with MARS using

```
$ bash scripts/run_mars_medium.sh
```

or

```
$ torchrun --standalone --nproc_per_node=8 \
      MARS/train_mars.py \
      config/train_gpt2_medium_mars.py \
      --batch_size=15 \
      --gradient_accumulation_steps=4
```

#### Reproducing GPT-2 Large (770M) Results

Train with MARS using

```
$ bash scripts/run_mars_large.sh
```

or

```
$ torchrun --standalone --nproc_per_node=8 \
      MARS/train_mars.py \
      config/train_gpt2_large_mars.py \
      --batch_size=5 \
      --gradient_accumulation_steps=12
```

#### **Reproducing GPT-2 XL (1.5B) Results on FineWeb-Edu**

```
$ bash scripts/run_mars_xl_fw.sh
```

or

```
$ torchrun --standalone --nproc_per_node=8 \
      MARS/train_mars_fw.py \
      config/train_gpt2_xl_mars.py \
      --batch_size=5 \
      --gradient_accumulation_steps=12
```

#### Reproducing Baseline Results

To reproduce the AdamW baseline:

```
bash scripts/run_adamw_{small/medium/large}.sh
```

To reproduce the AdamW baseline on FineWeb-Edu:

```
bash scripts/run_adamw_{small/xl}_fw.sh
```

To reproduce the Muon baseline following [modded-nanogpt](https://github.com/KellerJordan/modded-nanogpt/tree/e01b457c7c52e1cd0c592920499a016f5289a69e):

```
bash scripts/run_muon_{small/medium/large}.sh
```

Please adjust `nproc_per_node`, `batch_size`, and `gradient_accumulation_steps` accordingly if you use a different hardware setup; make sure their product equals 480.
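
As a quick sanity check for a different setup, you can solve for the accumulation steps from the fixed total batch size of 480 (an illustrative helper, not part of the repository):

```python
def grad_accum_steps(total_batch_size=480, nproc_per_node=8, batch_size=15):
    """Solve total = nproc_per_node * batch_size * accum for the accumulation steps."""
    per_step = nproc_per_node * batch_size
    assert total_batch_size % per_step == 0, "pick batch_size so the product can reach 480"
    return total_batch_size // per_step

print(grad_accum_steps())                                 # 4  (GPT-2 Small default: 8 GPUs x 15)
print(grad_accum_steps(nproc_per_node=4, batch_size=12))  # 10 on a 4-GPU node
```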

#### Hyperparameters for GPT-2 models

| Model Name | Model Size | lr for AdamW | lr for Muon | lr for MARS | lr_1d for MARS | wd for AdamW | wd for Muon | wd for MARS |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| GPT-2 small | 125M | 6e-4 | 2e-2 | 6e-3 | 3e-3 | 1e-1 | 0.0 | 1e-2 |
| GPT-2 medium | 355M | 3e-4 | 1e-2 | 3e-3 | 1.5e-3 | 1e-1 | 0.0 | 1e-2 |
| GPT-2 large | 770M | 2e-4 | 6.67e-3 | 2e-3 | 1e-3 | 1e-1 | 0.0 | 1e-2 |
| GPT-2 xl | 1.5B | 2e-4 | - | 2e-3 | 1e-3 | 1e-1 | - | 1e-2 |

### Customized Training

To build your own training pipeline on other architectures and datasets, use the following template as an example:

```python
import torch
import torch.nn.functional as F
from mars import MARS

# init model, loss function and input data
model = Model()
data_loader = ...

# init the optimizer
optimizer = MARS(model.parameters(), lr=1e-3, betas=(0.9, 0.95), gamma=0.025)

total_bs = len(data_loader)
bs = total_bs * block_size
k = 10
iter_num = -1

# training loop
for epoch in range(epochs):
    for X, Y in data_loader:
        # standard training code
        logits, loss = model(X, Y)
        loss.backward()
        optimizer.step(bs=bs)
        optimizer.zero_grad(set_to_none=True)
        optimizer.update_last_grad()
        iter_num += 1
```

## Star History

[Star History Chart](https://www.star-history.com/#AGI-Arena/MARS&Date)

## Citation

If you find this repo useful for your research, please consider citing the paper:

```tex
@article{yuan2024mars,
  title={MARS: Unleashing the Power of Variance Reduction for Training Large Models},
  author={Yuan, Huizhuo and Liu, Yifeng and Wu, Shuang and Zhou, Xun and Gu, Quanquan},
  journal={arXiv preprint arXiv:2411.10438},
  year={2024}
}
```

## Acknowledgements

This repo is built upon [nanoGPT](https://github.com/karpathy/nanoGPT/), [levanter](https://github.com/stanford-crfm/levanter/), and [Sophia](https://github.com/Liuhong99/Sophia); we thank the authors for their great work!