from __future__ import annotations

from dataclasses import asdict, dataclass, field
from glob import glob
from pathlib import Path
from typing import (
    Any,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

import torch
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from omegaconf.errors import OmegaConfBaseException
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

from .aliases import PathOrStr
from .beam_search import Sampler
from .exceptions import OLMoConfigurationError
from .util import StrEnum
					
						
__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
    "OptimizerType",
    "OptimizerConfig",
    "SchedulerType",
    "SchedulerConfig",
    "DataConfig",
    "EvaluatorConfig",
    "TokenizerConfig",
    "TrainConfig",
    "PaddingDirection",
    "TruncationDirection",
    "SpeedMonitorConfig",
    "WandbConfig",
    "CompilerConfig",
    "FSDPPrecision",
    "FSDPWrapStrategy",
    "FSDPConfig",
    "CheckpointType",
]

C = TypeVar("C", bound="BaseConfig")
D = TypeVar("D", bound="DictConfig|ListConfig")
					
						
class BaseConfig:
    @classmethod
    def _register_resolvers(cls, validate_paths: bool = True):
        # Expand one or more glob patterns into a sorted, flattened list of matching paths.
        def path_glob(*paths) -> List[str]:
            out = []
            for path in paths:
                matches = sorted(glob(path))
                if not matches and validate_paths:
                    raise FileNotFoundError(f"{path} does not match any files or dirs")
                out.extend(matches)
            return out

        # Return the first path that is a URL or exists on the local filesystem.
        def path_choose(*paths) -> str:
            from .util import is_url

            for path in paths:
                if is_url(path) or Path(path).exists():
                    return path
            if validate_paths:
                raise FileNotFoundError(", ".join(paths))
            else:
                return ""

        # Resolve a (local or remote) checkpoint directory to its most recent checkpoint.
        def path_last_checkpoint(path) -> str:
            from .util import find_latest_checkpoint

            latest_checkpoint = find_latest_checkpoint(path)
            if latest_checkpoint is None:
                if validate_paths:
                    raise FileNotFoundError(f"Could not find a latest checkpoint at {path}")
                else:
                    return ""
            else:
                return str(latest_checkpoint)

        om.register_new_resolver("path.glob", path_glob, replace=True)
        om.register_new_resolver("path.choose", path_choose, replace=True)
        om.register_new_resolver("path.last_checkpoint", path_last_checkpoint, replace=True)

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        """
        Update the legacy config settings whose schemas have undergone backwards-incompatible changes.
        """
        return config

    @classmethod
    def new(cls: Type[C], **kwargs) -> C:
        cls._register_resolvers()
        conf = om.structured(cls)
        try:
            if kwargs:
                conf = om.merge(conf, kwargs)
            return cast(C, om.to_object(conf))
        except OmegaConfBaseException as e:
            raise OLMoConfigurationError(str(e))

    @classmethod
    def load(
        cls: Type[C],
        path: PathOrStr,
        overrides: Optional[List[str]] = None,
        key: Optional[str] = None,
        validate_paths: bool = True,
    ) -> C:
        """Load from a YAML file."""
        cls._register_resolvers(validate_paths=validate_paths)
        schema = om.structured(cls)
        try:
            raw = om.load(str(path))
            if key is not None:
                raw = raw[key]
            raw = cls.update_legacy_settings(raw)
            conf = om.merge(schema, raw)
            if overrides:
                conf = om.merge(conf, om.from_dotlist(overrides))
            return cast(C, om.to_object(conf))
        except OmegaConfBaseException as e:
            raise OLMoConfigurationError(str(e))

    def save(self, path: PathOrStr) -> None:
        """Save to a YAML file."""
        om.save(config=self, f=str(path))

    def asdict(self, exclude: Optional[Iterable[str]] = None) -> Dict[str, Any]:
        out = asdict(self)
        if exclude is not None:
            for name in exclude:
                if name in out:
                    del out[name]
        return out
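
# Illustrative usage sketch (not part of the original module): how a ``BaseConfig``
# subclass such as ``TrainConfig`` (defined below) is typically loaded from YAML
# with dot-list overrides and the ``path.last_checkpoint`` resolver. The YAML path,
# override values, and bucket URL are hypothetical.
#
#     cfg = TrainConfig.load(
#         "configs/my-run.yaml",
#         overrides=[
#             "optimizer.learning_rate=2e-4",
#             "load_path=${path.last_checkpoint:s3://my-bucket/checkpoints/my-run}",
#         ],
#     )
#     cfg.save("/tmp/my-run/config.yaml")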
					
						
class LayerNormType(StrEnum):
    default = "default"
    """
    The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    """

    low_precision = "low_precision"
    """
    A low-precision version of the default LayerNorm.
    """

    rms = "rms"
    """
    An RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """


class ActivationType(StrEnum):
    gelu = "gelu"
    relu = "relu"
    swiglu = "swiglu"


class BlockType(StrEnum):
    sequential = "sequential"

    llama = "llama"
    """
    A block similar to the sequential block with slightly different
    implementations of operations like attention to imitate the behavior of Llama.
    """


class InitFnType(StrEnum):
    mitchell = "mitchell"
    """
    The strategy suggested to us by Mitchell Wortsman from UW.
    This uses a truncated normal distribution with an adaptive standard deviation that depends
    on the size of the weights as well as the depth of the layer.
    """

    normal = "normal"
    """
    All weights are initialized from the same normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    All weights are initialized with the Kaiming method from a normal distribution.
    Note this currently won't work with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
    is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    This is what metaseq calls "full megatron init". It is the init used for Llama 2.
    """
					
						
@dataclass
class ModelConfig(BaseConfig):
    """
    OLMo (model) configuration.
    """

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    """

    clip_qkv: Optional[float] = None
    """
    Clip QKV to this value when set.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Deprecated. Use n_kv_heads instead.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    include_bias: bool = True
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal". Setting this to None means values are not cut off.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    ternary: bool = False
    """
    Use ternary BitLinear layer from "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (https://arxiv.org/pdf/2402.17764.pdf)
    """

    @property
    def effective_n_kv_heads(self) -> int:
        if self.n_kv_heads is None:
            if self.multi_query_attention is True:
                return 1
            else:
                return self.n_heads
        else:
            if self.multi_query_attention is None:
                return self.n_kv_heads
            if self.multi_query_attention:
                n_kv_heads_should_be = 1
            else:
                n_kv_heads_should_be = self.n_heads
            if self.n_kv_heads == n_kv_heads_should_be:
                return n_kv_heads_should_be
            else:
                raise OLMoConfigurationError(
                    "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
                )
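
# Illustrative sketch (not part of the original module): how ``n_heads`` and
# ``n_kv_heads`` interact through ``effective_n_kv_heads``. The sizes below are
# made up for the example.
#
#     mha = ModelConfig(d_model=2048, n_heads=16)                 # multi-head: 16 KV heads
#     mqa = ModelConfig(d_model=2048, n_heads=16, n_kv_heads=1)   # multi-query: 1 KV head
#     gqa = ModelConfig(d_model=2048, n_heads=16, n_kv_heads=4)   # grouped-query: 4 KV heads
#     assert (mha.effective_n_kv_heads, mqa.effective_n_kv_heads, gqa.effective_n_kv_heads) == (16, 1, 4)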
					
						
class OptimizerType(StrEnum):
    lionw = "lionw"
    adamw = "adamw"


@dataclass
class OptimizerConfig(BaseConfig):
    name: OptimizerType = OptimizerType.lionw
    learning_rate: float = 1.0e-4
    weight_decay: float = 0.01
    betas: Tuple[float, float] = (0.9, 0.95)

    no_decay_norm_and_bias: Optional[bool] = None
    """
    Deprecated. Use ``decay_norm_and_bias`` and ``decay_embeddings`` instead.
    """

    decay_norm_and_bias: bool = False
    decay_embeddings: bool = False
    metrics_log_interval: Optional[int] = None
    """
    The interval with which to collect and log detailed parameter-specific metrics.
    This only applies when logging to W&B, since these metrics won't be logged to the console.
    If not set, defaults to the wandb `log_interval`.
    """

    def __post_init__(self):
        self.betas = tuple(self.betas)

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        new_config = config.copy()
        if om.is_dict(new_config):
            assert isinstance(new_config, DictConfig)

            if hasattr(new_config, "name") and new_config.name == "decoupled_lionw":
                new_config.name = "lionw"
                if hasattr(new_config, "eps"):
                    del new_config.eps

        return new_config
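
# Illustrative sketch (not part of the original module): what
# ``OptimizerConfig.update_legacy_settings`` does to an old-style optimizer block.
# The values shown are hypothetical.
#
#     legacy = om.create({"name": "decoupled_lionw", "learning_rate": 1e-4, "eps": 1e-8})
#     updated = OptimizerConfig.update_legacy_settings(legacy)
#     assert updated.name == "lionw" and "eps" not in updated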
					
						
class SchedulerType(StrEnum):
    cosine_with_warmup = "cosine_with_warmup"
    linear_with_warmup = "linear_with_warmup"
    inverse_sqrt_with_warmup = "inverse_sqrt_with_warmup"
    max_scheduler = "max_scheduler"
    constant = "constant"


class SchedulerUnits(StrEnum):
    steps = "steps"
    tokens = "tokens"


@dataclass
class SchedulerConfig(BaseConfig):
    name: SchedulerType = SchedulerType.cosine_with_warmup
    units: SchedulerUnits = SchedulerUnits.steps
    t_warmup: Union[int, float] = 100
    t_max: Optional[Union[int, float]] = None
    alpha_f: float = 0.1

    grad_clip_warmup_steps: Optional[Union[int, float]] = None
    """
    The warmup period for which the max grad norm (or norm ratio) will be set to its
    warmup value of `max_grad_norm * grad_clip_warmup_factor`.
    """

    grad_clip_warmup_factor: Optional[float] = None
    """
    The ratio of the max allowed gradient norm (or norm ratio) for clipping during the warmup period
    vs after the warmup period.
    """
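
# Illustrative sketch (not part of the original module): a scheduler whose warmup
# and horizon are expressed in tokens rather than steps. The numbers are made up,
# and ``alpha_f`` is assumed here to be the final learning-rate fraction.
#
#     scheduler = SchedulerConfig(
#         name=SchedulerType.linear_with_warmup,
#         units=SchedulerUnits.tokens,
#         t_warmup=2_000_000_000,    # warm up over 2B tokens
#         t_max=300_000_000_000,     # anneal over 300B tokens
#         alpha_f=0.1,
#     )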
					
						
class PaddingDirection(StrEnum):
    right = "right"
    left = "left"


@dataclass
class DataConfig(BaseConfig):
    paths: Optional[List[str]] = None
    datasets: Optional[Dict[str, List[str]]] = None
    label_mask_paths: Optional[List[str]] = None
    pad_direction: PaddingDirection = PaddingDirection.right
    generate_attention_mask: bool = False
    num_workers: int = 0
    drop_last: bool = False
    pin_memory: bool = False
    prefetch_factor: Optional[int] = None
    persistent_workers: bool = False
    timeout: int = 0
    seed: Optional[int] = None


class EvaluatorType(StrEnum):
    downstream = "downstream"
    lm = "lm"


@dataclass
class EvaluatorConfig(BaseConfig):
    label: str
    type: EvaluatorType = EvaluatorType.lm
    data: DataConfig = field(default_factory=DataConfig)
    device_eval_batch_size: Optional[int] = None
    subset_num_batches: Optional[int] = None
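
# Illustrative sketch (not part of the original module): an "lm" (perplexity)
# evaluator over a held-out set of pre-tokenized files. The label and paths are
# hypothetical.
#
#     evaluator = EvaluatorConfig(
#         label="c4-validation",
#         type=EvaluatorType.lm,
#         data=DataConfig(
#             paths=["/data/c4-validation/part-0.npy", "/data/c4-validation/part-1.npy"],
#             num_workers=2,
#             drop_last=True,
#         ),
#         device_eval_batch_size=8,
#         subset_num_batches=100,
#     )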
					
						
class TruncationDirection(StrEnum):
    right = "right"
    left = "left"


@dataclass
class TokenizerConfig(BaseConfig):
    identifier: str = "gpt2"
    truncate_direction: TruncationDirection = TruncationDirection.right


@dataclass
class WandbConfig(BaseConfig):
    project: Optional[str] = None
    entity: Optional[str] = "ai2-llm"
    group: Optional[str] = None
    name: Optional[str] = None
    tags: Optional[List[str]] = field(default_factory=lambda: ["watching"])
    log_artifacts: bool = False
    rank_zero_only: bool = True
    log_interval: int = 1


@dataclass
class SpeedMonitorConfig(BaseConfig):
    window_size: int = 100
    gpu_flops_available: Optional[Union[float, int]] = None


@dataclass
class CompilerConfig(BaseConfig):
    mode: Optional[str] = None
    """
    The mode to compile the model in. At the moment this can be "default",
    "reduce-overhead" (useful for smaller models/batches), or "max-autotune"
    (the fastest for larger models, but takes a long time to compile).
    """

    fullgraph: bool = False
    """
    If ``True``, require the model to be compiled as a single graph, i.e. it is not OK to break
    the model into several subgraphs when compiling. Note that ``fullgraph=True`` is not compatible with FSDP.
    """

    backend: str = "inductor"
    """
    The backend to use.
    """
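
# Illustrative sketch (not part of the original module): enabling ``torch.compile``
# through the trainer config with the default inductor backend.
#
#     compile_cfg = CompilerConfig(mode="default", fullgraph=False, backend="inductor")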
					
						
class FSDPWrapStrategy(StrEnum):
    by_block = "by_block"
    """
    Wrap each OLMo block with its own FSDP instance.
    """

    by_block_and_size = "by_block_and_size"
    """
    Like 'by_block' but `wte` and `ff_out` will be wrapped separately as well.
    """

    by_block_group = "by_block_group"
    """
    Wrap each block group together into its own FSDP instance.
    This requires :attr:`~ModelConfig.block_group_size` to be bigger than 1.
    """

    by_block_group_and_size = "by_block_group_and_size"
    """
    Like 'by_block_group' but `wte` and `ff_out` will be wrapped separately as well.
    """

    size_based = "size_based"
    """
    Use PyTorch's default size-based auto wrap policy.
    """

    one_in_two = "one_in_two"
    one_in_three = "one_in_three"
    one_in_four = "one_in_four"
    one_in_five = "one_in_five"


class FSDPPrecision(StrEnum):
    pure = "pure"
    """
    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype``, ``reduce_dtype``,
    and ``buffer_dtype`` all set to the autocast precision data type.
    """

    mixed = "mixed"
    """
    Equivalent to :class:`torch.distributed.fsdp.MixedPrecision` with ``param_dtype`` and ``buffer_dtype``
    set to the autocast precision data type, while ``reduce_dtype`` is set to fp32.
    """


@dataclass
class FSDPConfig(BaseConfig):
    use_orig_params: bool = True
    """
    This must be ``True`` if using ``compile`` or if you want to track the parameter norm during training.
    """

    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD

    wrapping_strategy: Optional[FSDPWrapStrategy] = None
    """
    The wrapping strategy to use. If ``None``, the default, the model is wrapped with a single top-level
    FSDP instance.
    """

    precision: FSDPPrecision = FSDPPrecision.pure
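
# Illustrative sketch (not part of the original module): a common FSDP setup that
# wraps each transformer block in its own FSDP instance and keeps gradient
# reductions in fp32.
#
#     fsdp_cfg = FSDPConfig(
#         use_orig_params=True,                          # needed for torch.compile / param-norm tracking
#         sharding_strategy=ShardingStrategy.FULL_SHARD,
#         wrapping_strategy=FSDPWrapStrategy.by_block,
#         precision=FSDPPrecision.mixed,                 # reduce_dtype stays fp32 (see TrainConfig.fsdp_precision)
#     )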
					
						
class CheckpointType(StrEnum):
    sharded = "sharded"
    unsharded = "unsharded"
    sharded_ephemeral = "sharded_ephemeral"


class ShardedCheckpointerType(StrEnum):
    torch_new = "torch_new"
    torch_legacy = "torch_legacy"
    local = "local"


class ActivationCheckpointingStrategy(StrEnum):
    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one in two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one in three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one in four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Focus checkpointing on the operations that are cheap to recompute and save the most memory.
    """
					
						
@dataclass
class TrainConfig(BaseConfig):
    """
    OLMo training configuration.
    """

    run_name: Optional[str] = None
    """
    The name of the run.
    """

    seed: int = 6198
    """
    Used to seed all initial RNG states.
    """

    epoch: Optional[int] = None
    """
    Increment this when starting a new epoch.
    """

    dry_run: bool = False
    """
    If ``True``, don't actually train.
    """

    model: ModelConfig = field(default_factory=ModelConfig)
    """
    OLMo Model configuration.
    """

    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
    """
    Optimizer configuration.
    """

    scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
    """
    Learning rate scheduler configuration.
    """

    data: DataConfig = field(default_factory=DataConfig)
    """
    Training data configuration.
    """

    restore_dataloader: bool = True
    """
    When restarting, restore the data loader to where it left off.
    If you are restarting in order to train on a different dataset, set this to ``False``.
    """

    fast_forward_batches: Optional[int] = None
    """
    When restarting, use this to fast-forward the dataloader beyond the last checkpoint.
    This can be useful when restarting due to a loss spike in order to skip the data that
    corresponded to the spike.
    """

    evaluators: List[EvaluatorConfig] = field(default_factory=list)
    """
    Evaluation configurations.
    """

    eval_interval: int = 1000
    """
    How often (in terms of batches) to run evaluations.
    """

    tokenizer: TokenizerConfig = field(default_factory=TokenizerConfig)
    """
    Tokenizer configuration.
    """

    save_folder: str = "./"
    """
    The directory to save checkpoints to.
    """

    remote_save_folder: Optional[str] = None
    """
    A folder in a cloud bucket to upload saved checkpoints to.
    """

    canceled_check_interval: int = 50
    """
    How often (in batches) to check if the run has been canceled or reached its time limit.
    """

    save_interval: int = 1000
    """
    How often (in terms of steps) to save sharded training state checkpoints.
    """

    save_interval_unsharded: Optional[int] = None
    """
    How often (if at all) to save unsharded training state checkpoints.
    For large models it can be costly to save these, so it usually makes sense to save
    these less often than regular (sharded) training checkpoints.
    """

    save_interval_ephemeral: Optional[int] = None
    """
    How often (if at all) to save ephemeral sharded checkpoints. These checkpoints are the same
    as those saved every `save_interval` except that at most only the most recent one of these is kept.
    This is useful when you want to checkpoint often for restarts in case of failures, but don't
    want to keep the majority of these checkpoints.

    For example, suppose you want to keep your checkpoints at every 1000 steps, but you also want to save
    a temporary checkpoint every 100 steps in case your job fails. In that case you would
    set `save_interval=1000` and `save_interval_ephemeral=100`.
    """

    save_num_checkpoints_to_keep: int = -1
    """
    How many sharded checkpoints to keep.
    """

    save_num_unsharded_checkpoints_to_keep: int = -1
    """
    How many unsharded checkpoints to keep.
    """

    save_overwrite: bool = False
    """
    If ``True``, overwrite any conflicting checkpoint files.
    """

    force_save_unsharded: bool = False
    """
    Save an unsharded checkpoint before training (even during a dry run).
    Use this option with `--load_path={PATH}` and `--dry_run` to convert a sharded
    checkpoint into an unsharded checkpoint.
    """

    no_pre_train_checkpoint: bool = False
    """
    Skip saving pre-train checkpoint.
    """

    load_path: Optional[str] = None
    """
    The path to a training checkpoint to restore/resume from.

    Note that you can make use of the "path.last_checkpoint" OmegaConf YAML resolver here, which takes
    a local or remote directory and resolves to the latest checkpoint (sharded or unsharded) in that directory.
    For example,

    ```bash
    --load_path='${path.last_checkpoint:s3://ai2-llm/checkpoints/7b/v1_5-mix-run-001}'
    ```
    """

    load_path_sharded_checkpointer: Optional[ShardedCheckpointerType] = None
    """
    The sharded checkpointer type to use to load the initial checkpoint from ``load_path``.
    """

    reset_optimizer_state: bool = False
    """
    When this is set, we restore the model from a checkpoint (if given), but we leave the optimizer uninitialized.
    We also set a new learning rate schedule that does a new warmup, such that it intercepts the original learning
    curve (according to the current learning rate schedule settings), and continues from there.
    """

    reset_trainer_state: bool = False
    """
    When this is set we don't restore the trainer state from a checkpoint.
    """

    sharded_checkpointer: ShardedCheckpointerType = ShardedCheckpointerType.torch_legacy
    """
    The name of the sharded checkpointer to use to save (sharded) checkpoints throughout training.
    """

    new_style_checkpoints: Optional[bool] = None
    """
    Deprecated. Use ``sharded_checkpointer`` instead.
    """

    max_duration: Union[int, str] = 10000
    """
    How long to train for.

    If specified without a unit (the default), the units are assumed to be steps.
    You can also specify this in terms of tokens, for example: `max_duration="2e12T"` means train until
    2 trillion tokens.
    """

    global_train_batch_size: int = 512
    """
    The effective global batch size.
    """

    device_train_batch_size: Optional[int] = None
    """
    Don't set this manually. This will be set to ``global_train_batch_size // world_size``.
    """

    device_train_microbatch_size: int = 16
    """
    The number of instances passed to the model in a single forward-backward pass. You should set
    this as large as you can based on available GPU memory.
    """

    device_eval_batch_size: int = 16
    """
    The number of evaluation instances passed to the model in a single forward pass on each device.
    """

    eval_subset_num_batches: int = -1
    """
    The number of batches to use for downstream evaluation from each dataset.
    """

    eval_on_load: bool = False
    """
    When resuming from a checkpoint, run the evaluation loop right away.
    """

    device_train_grad_accum: Optional[int] = None
    """
    Don't set this manually. This will be set to ``device_train_batch_size // device_train_microbatch_size``.
    """

    max_grad_norm: Optional[float] = None
    """
    Clip gradient norms to this value if set.
    """

    max_grad_norm_ratio: Optional[float] = None
    """
    If set, gradient norms will be clipped to `max_grad_norm_ratio * exp_avg(norm(grad))`.
    This takes priority over `max_grad_norm` when set.
    """

    precision: Optional[str] = None
    """
    Precision to train with (e.g. "amp_bf16", "amp_fp16", or "fp32").
    """

    wandb: Optional[WandbConfig] = None
    """
    Weights & Biases configuration.
    """

    speed_monitor: SpeedMonitorConfig = field(default_factory=SpeedMonitorConfig)
    """
    Speed monitor configuration.
    """

    console_log_interval: int = 1
    """
    How often to log to the console.
    """

    compile: Optional[CompilerConfig] = None
    """
    Settings for compiling the model with ``torch.compile()``.
    """

    fsdp: FSDPConfig = field(default_factory=FSDPConfig)
    """
    Fully sharded data parallel settings.
    """

    softmax_auxiliary_loss: bool = False
    """
    If ``True``, we add the auxiliary loss function from PaLM that encourages the softmax
    normalizing term to be close to 0.
    """

    time_limit: Optional[float] = 60 * 60 * 47.5
    """
    The maximum amount of time to train for before saving a checkpoint and ending early.
    On LUMI we have 48 hours max per job, so we default to just under 48 hours to give us time
    to write out a final checkpoint.
    """

    extra_steps_after_cancel: int = 10
    """
    Under certain conditions when a run is canceled we train for a few extra steps after saving
    the final checkpoint so that when the run is restarted from the latest checkpoint we have some
    overlap in metrics.
    """

    early_stopping_factor: Optional[float] = None

    save_data_indices: bool = True
    """
    Save training data indices from each batch for each worker.
    """

    python_profiling: bool = False
    """
    Whether to run the Python profiler on batches 6, 7, and 8.
    """

    torch_profiling: bool = False
    """
    Whether to run the PyTorch profiler on batches 6, 7, and 8.
    """

    stop_at: Optional[int] = None
    """
    Stop at a specific step.
    """

    stop_after: Optional[int] = None
    """
    Stop after a specific number of steps.
    """

    activation_checkpointing: Optional[ActivationCheckpointingStrategy] = None
    """
    The activation checkpointing strategy to use.
    """

    fused_loss: Optional[bool] = None
    """
    Whether to use the fused CE loss function from `flash-attn`.
    """

    @property
    def autocast_precision(self) -> torch.dtype:
        if self.precision == "amp_bf16":
            return torch.bfloat16
        elif self.precision == "amp_fp16":
            return torch.float16
        elif self.precision == "fp32":
            return torch.float32
        else:
            raise ValueError(f"Unexpected precision type '{self.precision}'")

    @property
    def fsdp_precision(self) -> MixedPrecision:
        if self.fsdp.precision == FSDPPrecision.pure:
            return MixedPrecision(
                param_dtype=self.autocast_precision,
                reduce_dtype=self.autocast_precision,
                buffer_dtype=self.autocast_precision,
            )
        elif self.fsdp.precision == FSDPPrecision.mixed:
            return MixedPrecision(
                param_dtype=self.autocast_precision,
                reduce_dtype=torch.float32,
                buffer_dtype=self.autocast_precision,
            )
        else:
            raise NotImplementedError(f"{self.fsdp.precision}")

    @classmethod
    def update_legacy_settings(cls, config: D) -> D:
        new_config = config.copy()
        if om.is_dict(new_config):
            assert isinstance(new_config, DictConfig)

            if hasattr(new_config, "activation_checkpointing"):
                if new_config.activation_checkpointing is False:
                    new_config.activation_checkpointing = None
                if new_config.activation_checkpointing is True:
                    new_config.activation_checkpointing = ActivationCheckpointingStrategy.whole_layer

            if hasattr(new_config, "optimizer"):
                new_config.optimizer = OptimizerConfig.update_legacy_settings(new_config.optimizer)

        return new_config
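
# Illustrative sketch (not part of the original module): building a config in code
# and inspecting the derived precision settings. The field values are hypothetical.
#
#     cfg = TrainConfig.new(
#         run_name="tiny-test",
#         precision="amp_bf16",
#         model={"d_model": 256, "n_heads": 4, "n_layers": 4, "max_sequence_length": 512},
#     )
#     assert cfg.autocast_precision == torch.bfloat16
#     mp = cfg.fsdp_precision  # MixedPrecision with all dtypes bf16, since FSDPPrecision.pure is the default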