The way existing FSDP works with autograd:
* Existing FSDP all-gathers the ``flat_param``, which is the autograd leaf.
* It calls ``torch.split`` to get 1D views into the ``flat_param`` corresponding to its constituent original parameters.
* It calls ``torch.view`` on each 1D split to view back to ND.
* This means that in ``backward``, we end up with ``ViewBackward`` (ND -> 1D) and ``SplitWithSizesBackward`` (which is a concat). In particular, each individual gradient is computed as a separate allocation, and an explicit concat is needed to construct the reduce-scatter input buffer, which implies roughly a 2x buffer size for reduce-scatter at that peak-memory point.
In summary, ``backward`` incurs about 2x the reduce-scatter buffer size, plus any ``recordStream`` effects; the sketch below illustrates the split/view pattern.
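To make this concrete, here is a minimal standalone sketch (not FSDP's actual implementation; the two parameter shapes are made up) showing how splitting a 1D flat parameter and viewing the pieces back to ND produces ``ViewBackward`` and ``SplitWithSizesBackward`` nodes, with the gradient materialized as one concatenated 1D buffer:

.. code-block:: python

    import torch

    # Hypothetical flat parameter holding two original parameters of shapes (2, 3) and (4, 5).
    shapes = [(2, 3), (4, 5)]
    numels = [2 * 3, 4 * 5]
    flat_param = torch.randn(sum(numels), requires_grad=True)  # the autograd leaf

    splits = torch.split(flat_param, numels)                      # 1D views -> SplitWithSizesBackward
    params = [t.view(shape) for t, shape in zip(splits, shapes)]  # ND views -> ViewBackward

    loss = sum(p.sum() for p in params)
    loss.backward()

    # Each per-parameter gradient is computed as a separate allocation and then
    # effectively concatenated into flat_param.grad (one contiguous 1D buffer),
    # which is where the ~2x peak for the reduce-scatter input comes from.
    print(params[0].grad_fn)      # <ViewBackward0 ...>
    print(flat_param.grad.shape)  # torch.Size([26])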
Second, let's discuss the additional buffers:
Once the sharded parameters are all-gathered from all ranks, they require an additional buffer of ``total_transformer_block_params_in_B*dtype_bytes`` for the full parameters. Continuing the earlier example, if each transformer block has 1.6B parameters stored in fp32, that is a ``1.6*4=6.4GB`` buffer.
Two such buffers are needed, since one is currently being used while the next is being prefetched.
To summarize, we have:
1. Two communication buffers of ``total_transformer_block_params_in_B*dtype_bytes/num_gpus`` each
2. Two unsharded transformer block parameter buffers of ``total_transformer_block_params_in_B*dtype_bytes`` each
or if you have been following the example:
1. ``2*1.6*4/8=1.6GB``
2. ``2*1.6*4=12.8GB``
for a total of ``14.4GB`` (see the helper sketch below).
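As a quick sanity check of this arithmetic, here is a small helper (a sketch with made-up names, assuming the parameter count is given in billions so that bytes scale directly to GB):

.. code-block:: python

    def fsdp_extra_buffers_gb(params_in_billions: float, dtype_bytes: int, num_gpus: int) -> float:
        """Rough estimate of the extra FSDP buffers for one transformer block, in GB."""
        comm = 2 * params_in_billions * dtype_bytes / num_gpus  # 2 sharded communication buffers
        unsharded = 2 * params_in_billions * dtype_bytes        # current + prefetched full parameters
        return comm + unsharded

    # Running example: 1.6B-parameter transformer block, fp32 (4 bytes), 8 GPUs.
    print(fsdp_extra_buffers_gb(1.6, 4, 8))  # 1.6 + 12.8 = 14.4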
Now let's briefly discuss what happens to the embeddings, since we left them out of the calculations above:
Given the rule discussed above, in the note beginning with "the communication buffer size is determined as follows", we can analyze this as follows:
* Suppose we apply FSDP to the root module (e.g. the ``Transformer`` class). Suppose we further apply FSDP to each transformer block (e.g. the ``TransformerBlock`` class).
* Most commonly, the embedding and final linear projection are direct children of the root ``Transformer`` class.
* Following our rule, that means that the embedding and final linear projection are assigned to the root ``Transformer``'s flat parameter.
* We have _another_ special rule, which is that the root does not free its parameters after forward, because they will be all-gathered again immediately in backward anyway.
* Putting this together, this means that the root's flat parameter, including the embedding and final projection, is all-gathered to begin forward and kept in GPU memory until the end of backward.
* If the embedding and final linear are not weight-tied, then we _could_ further apply FSDP to the embedding and to the final linear; that would allow the embedding to be freed after its usage in forward and only all-gathered again toward the end of backward (see the wrapping sketch after this list). Weight-tied parameters, by contrast, must be part of the same flat parameter, or else they would be double-counted.
* Hopefully, this gives a better sense -- each FSDP module is assigned the parameters from its ``module.parameters`` except those already assigned to another nested FSDP module, and the FSDP module's ``forward`` defines the 'live' interval for its parameters. Thus, the nested ``nn.Module`` structure can affect the all-gather/free schedule and, in turn, the memory and throughput performance.
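To illustrate the wrapping described above, here is a minimal sketch (the module names ``Transformer``, ``blocks``, ``embedding``, and ``final_proj`` are illustrative; it assumes the process group is already initialized and that the embedding and final linear are not weight-tied):

.. code-block:: python

    import torch.nn as nn
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def wrap_transformer(model: nn.Module) -> FSDP:
        # Each transformer block gets its own flat parameter, so its parameters are
        # all-gathered just before its forward/backward and freed right after.
        for i, block in enumerate(model.blocks):
            model.blocks[i] = FSDP(block)
        # Wrapping the embedding and final projection separately lets them be freed
        # after their use in forward instead of living in the root's flat parameter.
        model.embedding = FSDP(model.embedding)
        model.final_proj = FSDP(model.final_proj)
        # The root FSDP module owns whatever parameters remain (e.g. a final norm);
        # recall that the root keeps its flat parameter in GPU memory between
        # forward and backward.
        return FSDP(model)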