When loading a pretrained model in PyTorch, the usual workflow looks like this:
import torch
my_model = ModelClass(...)
state_dict = torch.load(checkpoint_file)
my_model.load_state_dict(state_dict)

In plain English, those steps are:

1. Create the model
2. Load in memory its pretrained weights
3. Load those pretrained weights in the model created in step 1
While this works very well for regularly sized models, this workflow has some clear limitations when we deal with a huge model: in step 1, we load a full version of the model in RAM and spend some time randomly initializing the weights (which will be discarded in step 3). In step 2, we load another full version of the model in RAM, with the pretrained weights. If you’re loading a model with 6 billion parameters, you will need 24GB of RAM for each copy of the model, so 48GB in total (half of that to load the model in FP16).
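The arithmetic above can be sketched in a few lines of plain Python (the per-dtype byte sizes are standard; the helper itself is purely illustrative):

```python
# Bytes needed to store one parameter, per common dtype.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}

def model_size_gb(num_params: int, dtype: str = "fp32") -> float:
    """Approximate RAM (in GB) needed to hold one copy of the weights."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

six_b = 6_000_000_000
print(model_size_gb(six_b, "fp32"))      # one FP32 copy: 24.0 GB
print(2 * model_size_gb(six_b, "fp32"))  # two copies during loading: 48.0 GB
print(model_size_gb(six_b, "fp16"))      # one FP16 copy: 12.0 GB
```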
This API is quite new and still in its experimental stage. While we strive to provide a stable API, it’s possible some small parts of the public API will change in the future.
The first tool Accelerate introduces to help with big models is a context manager init_empty_weights() that helps you initialize a model without using any RAM, so that step 1 can be done on models of any size. Here is how it works:
from accelerate import init_empty_weights
with init_empty_weights():
my_model = ModelClass(...)

For instance:
with init_empty_weights():
model = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])

initializes an empty model with a bit more than 100B parameters. Behind the scenes, this relies on the meta device introduced in PyTorch 1.9. During the initialization under the context manager, each time a parameter is created, it is instantly moved to that device.
You can’t move a model initialized like this to the CPU or another device directly, since it doesn’t have any data. It’s also very likely that a forward pass with that empty model will fail, as not all operations are supported on the meta device.
It’s possible your model is so big that even a single copy won’t fit in RAM. That doesn’t mean it can’t be loaded: if you have one or several GPUs, that means more memory is available to store your model. In this case, it’s better if your checkpoint is split into several smaller files that we call checkpoint shards.
Accelerate will handle sharded checkpoints as long as they follow this format: your checkpoint should be in a folder, with several files containing the partial state dicts, and there should be an index in JSON format containing a dictionary that maps parameter names to the file holding their weights. For instance, we could have a folder containing:
first_state_dict.bin index.json second_state_dict.bin
with index.json being the following file:
{
"linear1.weight": "first_state_dict.bin",
"linear1.bias": "first_state_dict.bin",
"linear2.weight": "second_state_dict.bin",
"linear2.bias": "second_state_dict.bin"
}

and first_state_dict.bin containing the weights for "linear1.weight" and "linear1.bias", and second_state_dict.bin those for "linear2.weight" and "linear2.bias".
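As a sketch of how such an index can be consumed, the snippet below (stdlib only; in practice the dict would come from json.load on the index file) groups parameter names by the shard that stores them:

```python
from collections import defaultdict

# The index maps each parameter name to the shard file holding its weights.
index = {
    "linear1.weight": "first_state_dict.bin",
    "linear1.bias": "first_state_dict.bin",
    "linear2.weight": "second_state_dict.bin",
    "linear2.bias": "second_state_dict.bin",
}

# Invert the mapping: for each shard file, list the parameters it stores.
shards = defaultdict(list)
for param_name, shard_file in index.items():
    shards[shard_file].append(param_name)

# Each shard can now be loaded one at a time, keeping peak RAM usage
# at roughly the size of the biggest shard.
for shard_file, params in shards.items():
    print(shard_file, params)
```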
The second tool Accelerate introduces is a function, load_checkpoint_and_dispatch(), which lets you load a checkpoint inside your empty model. This supports full checkpoints (a single file containing the whole state dict) as well as sharded checkpoints. It will also automatically dispatch those weights across the devices you have available (GPUs, CPU RAM), so if you are loading a sharded checkpoint, the maximum RAM usage will be the size of the biggest shard.
Here is how we can use this to load the GPT-J-6B model. You clone the sharded version of this model with:
git clone https://huggingface.co/sgugger/sharded-gpt-j-6B
cd sharded-gpt-j-6B
git-lfs install
git pull

then we can initialize the model with
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
checkpoint = "EleutherAI/gpt-j-6B"
config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)

and load the checkpoint we just downloaded with:
from accelerate import load_checkpoint_and_dispatch
model = load_checkpoint_and_dispatch(
model, "sharded-gpt-j-6B", device_map="auto", no_split_module_classes=["GPTJBlock"]
)

By passing device_map="auto", we tell Accelerate to determine automatically where to put each layer of the model depending on the available resources:

- first, we use the maximum space available on the GPU(s)
- if we still need space, we store the remaining weights on the CPU
- if there is not enough room on the CPU, we store the remaining weights on the hard drive as memory-mapped tensors
no_split_module_classes=["GPTJBlock"] indicates that the modules that are GPTJBlock should not be split across different devices. You should set here all blocks that include a residual connection of some kind.
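To make the sequential dispatch concrete, here is a rough stdlib-only sketch of the kind of greedy assignment this performs; all sizes, device names, and the helper function are made up for illustration, and the real logic in Accelerate handles many more details:

```python
# Hypothetical module sizes in GB, in the order they appear in the model.
module_sizes = {"wte": 2.0, "h.0": 1.5, "h.1": 1.5, "h.2": 1.5, "ln_f": 0.1}

# Available memory per device, tried in order: GPUs first, then CPU, then disk.
capacities = [("0", 4.0), ("1", 2.0), ("cpu", 1.6), ("disk", float("inf"))]

def greedy_device_map(sizes, caps):
    """Assign each module to the first device that still has room,
    never going back to an earlier device (sequential assignment)."""
    device_map = {}
    caps = list(caps)  # work on a copy
    idx = 0
    for name, size in sizes.items():
        while caps[idx][1] < size:
            idx += 1  # this device is full, move on to the next one
        device, free = caps[idx]
        caps[idx] = (device, free - size)
        device_map[name] = device
    return device_map

print(greedy_device_map(module_sizes, capacities))
```

Note how, exactly as described above, a module that does not fit on the current device pushes everything after it onto later devices, even if earlier devices later have leftover room.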
You can see the device_map that Accelerate picked by accessing the hf_device_map attribute of your model:
model.hf_device_map
{'transformer.wte': 0,
'transformer.drop': 0,
'transformer.h.0': 0,
'transformer.h.1': 0,
'transformer.h.2': 0,
'transformer.h.3': 0,
'transformer.h.4': 0,
'transformer.h.5': 0,
'transformer.h.6': 0,
'transformer.h.7': 0,
'transformer.h.8': 0,
'transformer.h.9': 0,
'transformer.h.10': 0,
'transformer.h.11': 0,
'transformer.h.12': 0,
'transformer.h.13': 0,
'transformer.h.14': 0,
'transformer.h.15': 0,
'transformer.h.16': 0,
'transformer.h.17': 0,
'transformer.h.18': 0,
'transformer.h.19': 0,
'transformer.h.20': 0,
'transformer.h.21': 0,
'transformer.h.22': 0,
'transformer.h.23': 0,
'transformer.h.24': 1,
'transformer.h.25': 1,
'transformer.h.26': 1,
'transformer.h.27': 1,
'transformer.ln_f': 1,
'lm_head': 1}

You can also design your device_map yourself, if you prefer to explicitly decide where each layer should be. In this case, the command above becomes:
model = load_checkpoint_and_dispatch(model, "sharded-gpt-j-6B", device_map=my_device_map)

Now that we have done this, our model lies across several devices, and maybe the hard drive. But it can still be used as a regular PyTorch model:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello, my name is", return_tensors="pt")
inputs = inputs.to(0)
output = model.generate(inputs["input_ids"])
tokenizer.decode(output[0].tolist())

Behind the scenes, Accelerate added hooks to the model, so that:

- at each layer, the inputs are put on the right device (so even if your model is spread across several GPUs, it works)
- the weights offloaded on the CPU are put on a GPU just before the forward pass and offloaded again just after
- the weights offloaded on the hard drive are loaded in RAM, put on a GPU just before the forward pass, and offloaded again just after
This way, your model can run for inference even if it doesn’t fit on one of the GPUs or the CPU RAM!
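Since hf_device_map is a plain Python dict, you can inspect it with standard tooling; this small stdlib sketch (using a made-up excerpt of a map like the one above) counts the modules placed on each device:

```python
from collections import Counter

# A small excerpt of a device map in the same shape as model.hf_device_map.
device_map = {
    "transformer.wte": 0,
    "transformer.h.0": 0,
    "transformer.h.1": 0,
    "transformer.h.2": 1,
    "transformer.ln_f": 1,
    "lm_head": 1,
}

# Count how many modules were placed on each device.
per_device = Counter(device_map.values())
print(dict(per_device))  # {0: 3, 1: 3}
```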
This only supports inference with your model, not training. Most of the computation happens behind torch.no_grad() context managers to avoid spending GPU memory on intermediate activations.
We are aware of the current limitations in the API:
- The device map computed when you ask for device_map="auto" in load_checkpoint_and_dispatch() tries to maximize the GPU and CPU RAM it sees available when you execute it. While PyTorch is very good at managing GPU RAM efficiently (and giving it back when not needed), the same is not entirely true of Python and CPU RAM. Therefore, an automatically computed device map might be too intense on the CPU. Move a few modules to the disk device if you get crashes due to a lack of RAM.
- The device map computed when you ask for device_map="auto" attributes devices sequentially (to avoid moving things back and forth), so if your first layer is bigger than the size of the GPU you have, everything will end up on the CPU/disk.

accelerate.cpu_offload

( model: Module execution_device: typing.Optional[torch.device] = None offload_buffers: bool = False state_dict: typing.Union[typing.Dict[str, torch.Tensor], NoneType] = None preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
model (torch.nn.Module) —
The model to offload.
execution_device (torch.device, optional) —
The device on which the forward pass of the model will be executed (should be a GPU). Will default to the model’s first parameter device.
offload_buffers (bool, optional, defaults to False) —
Whether or not to offload the buffers with the model parameters.
state_dict (Dict[str, torch.Tensor], optional) —
The state dict of the model that will be kept on CPU.
preload_module_classes (List[str], optional) —
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a dense linear layer is registered, but at forward,
dense.weight and dense.bias are used in some operations instead of calling dense directly.
Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that state dict and put on the execution device passed as they are needed, then offloaded again.
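The copy-in/copy-out pattern that full CPU offload relies on can be sketched with plain Python objects; the state dict contents and the run_layer helper below are made up for illustration, with list() standing in for a host-to-device transfer:

```python
# One master copy of the weights stays in CPU RAM.
cpu_state_dict = {"layer1.weight": [1.0, 2.0], "layer2.weight": [3.0, 4.0]}

def run_layer(name, inputs):
    """Simulate one layer's forward pass under full CPU offload."""
    # 1. Copy this layer's weights onto the execution device just in time.
    weights_on_device = list(cpu_state_dict[name + ".weight"])
    # 2. Run the computation (a stand-in elementwise product).
    outputs = [w * x for w, x in zip(weights_on_device, inputs)]
    # 3. Drop the on-device copy; only the CPU master copy remains.
    del weights_on_device
    return outputs

x = run_layer("layer1", [1.0, 1.0])
x = run_layer("layer2", x)
print(x)  # [3.0, 8.0]
```

At any point in time, only one layer's weights occupy the execution device, which is what keeps peak GPU memory low.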
accelerate.disk_offload

( model: Module offload_dir: typing.Union[str, os.PathLike] execution_device: typing.Optional[torch.device] = None offload_buffers: bool = False preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
model (torch.nn.Module) — The model to offload.
offload_dir (str or os.PathLike) —
The folder in which to offload the model weights (or where the model weights are already offloaded).
execution_device (torch.device, optional) —
The device on which the forward pass of the model will be executed (should be a GPU). Will default to the model’s first parameter device.
offload_buffers (bool, optional, defaults to False) —
Whether or not to offload the buffers with the model parameters.
preload_module_classes (List[str], optional) —
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a dense linear layer is registered, but at forward,
dense.weight and dense.bias are used in some operations instead of calling dense directly.
Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and put on the execution device passed as they are needed, then offloaded again.
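The idea of memory-mapping offloaded weights can be sketched with the standard library alone; the file name, dtype, and layout below are arbitrary, and Accelerate's actual on-disk format differs:

```python
import mmap
import os
import struct
import tempfile

# Write some "weights" to disk as raw little-endian float32 values.
weights = [0.5, -1.25, 3.0, 42.0]
offload_dir = tempfile.mkdtemp()
path = os.path.join(offload_dir, "layer1.weight.dat")
with open(path, "wb") as f:
    f.write(struct.pack(f"<{len(weights)}f", *weights))

# Later, map the file into memory: the OS pages data in on demand,
# so RAM usage stays low even for very large weight files.
with open(path, "rb") as f:
    with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        restored = list(struct.unpack(f"<{len(weights)}f", mm[:]))

print(restored)  # [0.5, -1.25, 3.0, 42.0]
```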
accelerate.dispatch_model

( model: Module device_map: typing.Dict[str, typing.Union[int, str, torch.device]] main_device: typing.Optional[torch.device] = None state_dict: typing.Union[typing.Dict[str, torch.Tensor], NoneType] = None offload_dir: typing.Union[str, os.PathLike] = None offload_buffers: bool = False preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
model (torch.nn.Module) —
The model to dispatch.
device_map (Dict[str, Union[str, int, torch.device]]) —
A dictionary mapping module names in the model’s state_dict to the device they should go to. Note that
"disk" is accepted even if it’s not a proper value for torch.device.
main_device (str, int or torch.device, optional) —
The main execution device. Will default to the first device in the device_map different from "cpu" or
"disk".
state_dict (Dict[str, torch.Tensor], optional) —
The state dict of the part of the model that will be kept on CPU.
offload_dir (str or os.PathLike) —
The folder in which to offload the model weights (or where the model weights are already offloaded).
offload_buffers (bool, optional, defaults to False) —
Whether or not to offload the buffers with the model parameters.
preload_module_classes (List[str], optional) —
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a dense linear layer is registered, but at forward,
dense.weight and dense.bias are used in some operations instead of calling dense directly.
Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on the CPU or even the disk.
accelerate.infer_auto_device_map

( model: Module max_memory: typing.Union[typing.Dict[typing.Union[int, str], typing.Union[int, str]], NoneType] = None no_split_module_classes: typing.Optional[typing.List[str]] = None dtype: typing.Union[str, torch.dtype, NoneType] = None )
Parameters
model (torch.nn.Module) — The model to analyze.
max_memory (Dict, optional) —
A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
no_split_module_classes (List[str], optional) —
A list of layer class names that should never be split across devices (for instance any layer that has a
residual connection).
dtype (str or torch.dtype, optional) —
If provided, the weights will be converted to that type when loaded.
Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, such that:

- we don’t exceed the memory available on any of the GPUs
- if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on the CPU that has the biggest size
- if offload to the CPU is needed, there is always room left on the CPU to put back the layer offloaded on disk that has the biggest size
All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
meta device (as it would if initialized within the init_empty_weights context manager).
accelerate.init_empty_weights

( include_buffers: bool = False )
A context manager under which models are initialized with all parameters on the meta device, therefore creating an empty model. Useful when just initializing the model would blow the available RAM.
Example:
import torch.nn as nn
from accelerate import init_empty_weights
# Initialize a model with 100 billion parameters in no time and without using any RAM.
with init_empty_weights():
tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])

Any model created under this context manager has no weights. As such you can’t do something like
model.to(some_device) with it. To load weights inside your empty model, see load_checkpoint_and_dispatch().
accelerate.load_checkpoint_and_dispatch

( model: Module checkpoint: typing.Union[str, os.PathLike] device_map: typing.Union[str, typing.Dict[str, typing.Union[int, str, torch.device]], NoneType] = None max_memory: typing.Union[typing.Dict[typing.Union[int, str], typing.Union[int, str]], NoneType] = None no_split_module_classes: typing.Optional[typing.List[str]] = None offload_folder: typing.Union[str, os.PathLike, NoneType] = None offload_buffers: bool = False dtype: typing.Union[str, torch.dtype, NoneType] = None offload_state_dict: bool = False preload_module_classes: typing.Optional[typing.List[str]] = None )
Parameters
model (torch.nn.Module) — The model in which we want to load a checkpoint.
checkpoint (str or os.PathLike) —
The folder checkpoint to load. It can be:
- a path to a file containing a whole model state dict
- a path to a .json file containing the index to a sharded checkpoint
- a path to a folder containing a unique .index.json file and the shards of a checkpoint
device_map (Dict[str, Union[int, str, torch.device]], optional) —
A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer
name; once a given module name is inside, every submodule of it will be sent to the same device.
To have Accelerate compute the most optimized device_map automatically, set device_map="auto".
max_memory (Dict, optional) —
A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
and the available CPU RAM if unset.
no_split_module_classes (List[str], optional) —
A list of layer class names that should never be split across devices (for instance any layer that has a
residual connection).
offload_folder (str or os.PathLike, optional) —
If the device_map contains any value "disk", the folder where we will offload weights.
offload_buffers (bool, optional, defaults to False) —
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
well as the parameters.
dtype (str or torch.dtype, optional) —
If provided, the weights will be converted to that type when loaded.
offload_state_dict (bool, optional, defaults to False) —
If True, will temporarily offload the CPU state dict on the hard drive to avoid running out of CPU RAM if
the weight of the CPU state dict plus the biggest shard does not fit.
preload_module_classes (List[str], optional) —
A list of classes whose instances should load all their weights (even in the submodules) at the beginning
of the forward. This should only be used for classes that have submodules which are registered but not
called directly during the forward, for instance if a dense linear layer is registered, but at forward,
dense.weight and dense.bias are used in some operations instead of calling dense directly.
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are loaded and adds the various hooks that will make this model run properly (even if split across devices).
accelerate.load_checkpoint_in_model

( model: Module checkpoint: typing.Union[str, os.PathLike] device_map: typing.Union[typing.Dict[str, typing.Union[int, str, torch.device]], NoneType] = None offload_folder: typing.Union[str, os.PathLike, NoneType] = None dtype: typing.Union[str, torch.dtype, NoneType] = None offload_state_dict: bool = False )
Parameters
model (torch.nn.Module) — The model in which we want to load a checkpoint.
checkpoint (str or os.PathLike) —
The folder checkpoint to load. It can be:
- a path to a file containing a whole model state dict
- a path to a .json file containing the index to a sharded checkpoint
- a path to a folder containing a unique .index.json file and the shards of a checkpoint
device_map (Dict[str, Union[int, str, torch.device]], optional) —
A map that specifies where each submodule should go. It doesn’t need to be refined to each parameter/buffer
name; once a given module name is inside, every submodule of it will be sent to the same device.
offload_folder (str or os.PathLike, optional) —
If the device_map contains any value "disk", the folder where we will offload weights.
dtype (str or torch.dtype, optional) —
If provided, the weights will be converted to that type when loaded.
offload_state_dict (bool, optional, defaults to False) —
If True, will temporarily offload the CPU state dict on the hard drive to avoid running out of CPU RAM if
the weight of the CPU state dict plus the biggest shard does not fit.
Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are loaded.
Once loaded across devices, you still need to call dispatch_model() on your model to make it able to run. To group the checkpoint loading and dispatch in one single call, use load_checkpoint_and_dispatch().