AWS Machine Learning Blog

Amazon SageMaker model parallel library now accelerates PyTorch FSDP workloads by up to 20%

Large language model (LLM) training has surged in popularity over the last year with the release of several popular models such as Llama 2, Falcon, and Mistral. Customers are now pre-training and fine-tuning LLMs ranging from 1 billion to over 175 billion parameters to optimize model performance for applications across industries, from healthcare to finance and marketing.

Training performant models at this scale can be a challenge. Highly accurate LLMs can require terabytes of training data and thousands or even millions of hours of accelerator compute time to achieve target accuracy. To complete training and launch products in a timely manner, customers rely on parallelism techniques to distribute this enormous workload across up to thousands of accelerator devices. However, these parallelism techniques can be difficult to use: different techniques and libraries are only compatible with certain workloads or restricted to certain model architectures, training performance can be highly sensitive to obscure configurations, and the state of the art is quickly evolving. As a result, machine learning practitioners often spend weeks preparing to scale their LLM workloads to large clusters of GPUs.

In this post, we highlight new features of the Amazon SageMaker model parallel (SMP) library that simplify the large model training process and help you train LLMs faster. In particular, we cover the SMP library’s new simplified user experience that builds on open source PyTorch Fully Sharded Data Parallel (FSDP) APIs, expanded tensor parallel functionality that enables training models with hundreds of billions of parameters, and performance optimizations that reduce model training time and cost by up to 20%.

To learn more about the SageMaker model parallel library, refer to SageMaker model parallelism library v2 documentation. You can also refer to our example notebooks to get started.

New features that simplify and accelerate large model training

This post discusses the latest features included in the v2.0 release of the SageMaker model parallel library. These features improve the usability of the library, expand functionality, and accelerate training. In the following sections, we summarize the new features and discuss how you can use the library to accelerate your large model training.

Aligning SMP with open source PyTorch

Since its launch in 2020, SMP has enabled high-performance, large-scale training on SageMaker compute instances. With this latest major version release of SMP, the library simplifies the user experience by aligning its APIs with open source PyTorch.

PyTorch offers Fully Sharded Data Parallelism (FSDP) as its main method for supporting large training workloads across many compute devices. As demonstrated in the following code snippet, SMP's updated APIs for techniques such as sharded data parallelism mirror those of PyTorch. You import torch.sagemaker and call torch.sagemaker.init() at the start of your script, and the rest of your code keeps using the standard PyTorch FSDP APIs.

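# training_script.py
import torch.sagemaker as tsm
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Initialize SMP before building the model
tsm.init()

# Set up a PyTorch model
model = ...

# Wrap the PyTorch model using the PyTorch FSDP module
model = FSDP(
    model,
    ...
)

optimizer = ...
...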

With these updates to SMP’s APIs, you can now realize the performance benefits of SageMaker and the SMP library without overhauling your existing PyTorch FSDP training scripts. This paradigm also allows you to use the same code base when training on premises as on SageMaker, simplifying the user experience for customers who train in multiple environments.

For more information on how to enable SMP with your existing PyTorch FSDP training scripts, refer to Get started with SMP.

Integrating tensor parallelism to enable training on massive clusters

This release of SMP also expands PyTorch FSDP's capabilities to include tensor parallelism techniques. One problem with using sharded data parallelism alone is that you can encounter convergence problems as you scale up your cluster size. This is because each data parallel rank processes its own micro-batch, so adding ranks to shard parameters, gradients, and optimizer states across more devices also increases your global batch size. On large clusters, the global batch size can be pushed beyond the threshold above which the model no longer converges well. You need to incorporate an additional parallelism technique that doesn't require an increase in global batch size as you scale your cluster.

To mitigate this problem, SMP v2.0 introduces the ability to compose sharded data parallelism with tensor parallelism. Tensor parallelism allows the cluster size to increase without changing the global batch size or affecting model convergence. With this feature, you can safely increase training throughput by provisioning clusters with 256 nodes or more.
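To make the batch size argument concrete, here is a minimal Python sketch of the arithmetic, assuming no gradient accumulation and a fixed micro-batch per data parallel replica. The function name global_batch_size and the cluster sizes are illustrative only and are not part of the SMP API. The GPUs in one tensor parallel group all process the same micro-batch, so only the number of groups, not the number of GPUs, multiplies the batch.

```python
def global_batch_size(num_gpus: int, micro_batch_per_replica: int,
                      tensor_parallel_degree: int = 1) -> int:
    """Global batch size for sharded data parallelism, optionally combined
    with tensor parallelism (hypothetical helper, no gradient accumulation).

    Each group of `tensor_parallel_degree` GPUs works on the same micro-batch,
    so only num_gpus // tensor_parallel_degree groups act as data parallel
    replicas that each contribute a distinct micro-batch.
    """
    if num_gpus % tensor_parallel_degree != 0:
        raise ValueError("num_gpus must be divisible by tensor_parallel_degree")
    data_parallel_degree = num_gpus // tensor_parallel_degree
    return data_parallel_degree * micro_batch_per_replica


# 32 nodes x 8 GPUs, sharded data parallelism only, micro-batch of 2:
print(global_batch_size(256, 2))        # 512
# 256 nodes x 8 GPUs with tensor_parallel_degree=8 keeps the same global batch:
print(global_batch_size(2048, 2, 8))    # 512
```

Without tensor parallelism, the 2,048-GPU cluster would instead train with a global batch of 4,096.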

Today, tensor parallelism with PyTorch FSDP is only available with SMP v2. SMP v2 lets you enable this technique with a few lines of code and unlock stable training even on large clusters. SMP v2 integrates with Transformer Engine for its implementation of tensor parallelism and makes it compatible with the PyTorch FSDP APIs. You can enable PyTorch FSDP and SMP tensor parallelism simultaneously without making any changes to your PyTorch model definition or PyTorch FSDP configuration. The following code snippets show how to set up the SMP configuration dictionary in JSON format and how to add the SMP initialization function torch.sagemaker.init() to your training script; when you start the training job, torch.sagemaker.init() receives the configuration dictionary in the backend. The script then calls torch.sagemaker.transform() on the model to apply SMP tensor parallelism.

The SMP configuration is as follows:

{
"tensor_parallel_degree": 8,
"tensor_parallel_seed": 0
}

In your training script, use the following code:

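import torch.sagemaker as tsm
tsm.init()

from transformers import AutoModelForCausalLM
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Build the Hugging Face model as usual
model = AutoModelForCausalLM.from_config(...)

# Let SMP apply its tensor parallel implementation to the model
model = tsm.transform(model)

# Then wrap the model with PyTorch FSDP as usual
model = FSDP(model, ...)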
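Tensor parallelism can grow the cluster without growing the batch because the GPUs in a tensor parallel group split each layer's weights and all work on the same input. The following pure-Python sketch shows one common scheme, the column-wise split of a linear layer popularized by Megatron-LM. It illustrates the math only; it is not how SMP or Transformer Engine implement tensor parallelism, and the helper names are hypothetical.

```python
def matmul(x, w):
    """Plain matrix multiply: x is (batch, d_in), w is (d_in, d_out)."""
    return [[sum(x[i][k] * w[k][j] for k in range(len(w)))
             for j in range(len(w[0]))] for i in range(len(x))]


def split_columns(w, degree):
    """Split weight matrix w column-wise into `degree` equal shards."""
    d_out = len(w[0])
    if d_out % degree != 0:
        raise ValueError("output dimension must be divisible by degree")
    size = d_out // degree
    return [[row[r * size:(r + 1) * size] for row in w] for r in range(degree)]


def column_parallel_linear(x, w, degree):
    """Each simulated rank holds one column shard of w and computes its slice
    of the output on the *same* input batch x. Concatenating the slices
    (an all-gather in a real cluster) reproduces x @ w."""
    partial_outputs = [matmul(x, shard) for shard in split_columns(w, degree)]
    return [sum((p[i] for p in partial_outputs), []) for i in range(len(x))]


x = [[1, 2], [3, 4]]                    # one batch, shared by both "ranks"
w = [[1, 0, 2, 1], [0, 1, 1, 3]]        # full weight, split across 2 "ranks"
print(column_parallel_linear(x, w, 2))  # [[1, 2, 4, 7], [3, 4, 10, 15]]
```

Each rank stores only half of w, yet the batch each rank sees is unchanged, which is why adding tensor parallel ranks does not enlarge the global batch.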

To learn more about using tensor parallelism in SMP, refer to the tensor parallelism section of our documentation.

Use advanced features to accelerate model training by up to 20%

In addition to enabling distributed training on clusters with hundreds of instances, SMP also offers optimization techniques that can accelerate model training by up to 20%. In this section, we highlight a few of these optimizations. To learn more, refer to the core features section of our documentation.

Hybrid sharding

Sharded data parallelism is a memory-saving distributed training technique that splits the state of a model (model parameters, gradients, and optimizer states) across devices. This smaller memory footprint allows you to fit a larger model into your cluster or increase the batch size. However, sharded data parallelism also increases the communication requirements of your training job because the sharded model artifacts are frequently gathered from different devices during training. As a result, the degree of sharding is an important configuration that trades off memory consumption against communication overhead.
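A rough way to see the memory side of this trade-off is to estimate model-state memory per device. The sketch below assumes mixed-precision training with Adam at about 16 bytes per parameter (16-bit parameters and gradients plus fp32 master weights, momentum, and variance, as in the ZeRO analysis) and ignores activations and framework overhead. The function is hypothetical and only an approximation.

```python
# Assumed bytes per parameter for mixed-precision Adam:
# 16-bit params (2) + 16-bit grads (2) + fp32 master weights (4) + momentum (4) + variance (4)
BYTES_PER_PARAM = 2 + 2 + 4 + 4 + 4


def model_state_gib_per_device(num_params: float, shard_degree: int) -> float:
    """Approximate per-device memory (GiB) for parameters, gradients, and
    optimizer state sharded evenly across `shard_degree` devices.
    Activations, buffers, and framework overhead are ignored."""
    return num_params * BYTES_PER_PARAM / shard_degree / 2**30


# A hypothetical 70B-parameter model:
for degree in (8, 64):
    print(degree, round(model_state_gib_per_device(70e9, degree), 1))
# 8 130.4
# 64 16.3
```

Doubling the sharding degree halves this estimate, but each forward and backward step then has to gather parameters from twice as many devices, which is the communication cost described in the preceding paragraph.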

By default, PyTorch FSDP shards model artifacts across all of the accelerator devices in your cluster. Depending on your training job, this method of sharding could increase communication overhead and create a bottleneck. To help with this, the SMP library offers configurable hybrid sharded data parallelism on top of PyTorch FSDP. This feature allows you to set the degree of sharding that is optimal for your training workload. Simply specify the degree of sharding in a configuration JSON object and include it in your SMP training script.

The SMP configuration is as follows:

{ "hybrid_shard_degree": 16 }
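To illustrate what a hybrid_shard_degree of 16 means, the following sketch partitions global ranks into shard groups and replica groups. It assumes that contiguous ranks form each shard group, similar to how PyTorch FSDP's HYBRID_SHARD strategy shards within a group and replicates across groups; SMP's actual rank placement may differ, and the function is hypothetical.

```python
def hybrid_shard_groups(world_size: int, hybrid_shard_degree: int):
    """Return (shard_groups, replica_groups) as lists of global ranks.

    Assumption: contiguous ranks form a shard group. Model state is sharded
    within each group of `hybrid_shard_degree` ranks and replicated across
    groups; ranks that hold the same shard form a replica group.
    """
    if world_size % hybrid_shard_degree != 0:
        raise ValueError("world_size must be divisible by hybrid_shard_degree")
    num_groups = world_size // hybrid_shard_degree
    shard_groups = [list(range(g * hybrid_shard_degree, (g + 1) * hybrid_shard_degree))
                    for g in range(num_groups)]
    replica_groups = [list(range(r, world_size, hybrid_shard_degree))
                      for r in range(hybrid_shard_degree)]
    return shard_groups, replica_groups


# 8 nodes x 8 GPUs = 64 ranks with {"hybrid_shard_degree": 16}
shards, replicas = hybrid_shard_groups(64, 16)
print(len(shards), shards[1][:4])   # 4 [16, 17, 18, 19]
print(replicas[0])                  # [0, 16, 32, 48]
```

Under this layout, parameter AllGathers stay within a group of 16 ranks, while gradients are additionally synchronized across each replica group.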

To learn more about the advantages of hybrid sharded data parallelism, refer to Near-linear scaling of gigantic-model training on AWS. For more information on implementing hybrid sharding with your existing FSDP training script, see hybrid sharded data parallelism in our documentation.

Use the SMDDP collective communication operations optimized for AWS infrastructure

You can use the SMP library with the SageMaker distributed data parallelism (SMDDP) library to accelerate your distributed training workloads. SMDDP includes an optimized AllGather collective communication operation designed for best performance on SageMaker p4d and p4de accelerated instances. In distributed training, collective communication operations are used to synchronize information across GPU workers. AllGather is one of the core collective communication operations typically used in sharded data parallelism to materialize the layer parameters before forward and backward computation steps. For training jobs that are bottlenecked by communication, faster collective operations can reduce training time and cost with no side effects on convergence.
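The semantics of AllGather can be sketched in a few lines of pure Python. This simulation models only the result, not the communication: in a real job, FSDP issues the collective through torch.distributed (for example, torch.distributed.all_gather_into_tensor), and with the smddp backend SMDDP can supply an optimized implementation. The helper name all_gather is hypothetical.

```python
from typing import List, Sequence


def all_gather(shards_by_rank: Sequence[Sequence[float]]) -> List[List[float]]:
    """Simulate AllGather: every rank ends up with the concatenation of all
    ranks' shards, in rank order."""
    full = [value for shard in shards_by_rank for value in shard]
    return [list(full) for _ in shards_by_rank]


# A layer's 8 parameters sharded across 4 ranks (2 values each).
layer_shards = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
gathered = all_gather(layer_shards)
print(gathered[2])  # [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]
```

Because every rank must receive the shards of all other ranks before each layer's forward and backward computation, the speed of this operation directly affects step time in communication-bound jobs.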

To use the SMDDP library, you only need to add two lines of code to your training script:

import torch.distributed as dist

# Initialize with SMDDP: importing this module registers the "smddp" backend
import smdistributed.dataparallel.torch.torch_smddp
dist.init_process_group(backend="smddp")  # Replacing "nccl"

# Initialize with SMP
import torch.sagemaker as tsm
tsm.init()

In addition to SMP, SMDDP supports open source PyTorch FSDP and DeepSpeed. To learn more about the SMDDP library, see Run distributed training with the SageMaker distributed data parallelism library.

Activation offloading

Typically, the forward pass of model training computes activations at each layer and keeps them in GPU memory until the backward pass for the corresponding layer finishes. These stored activations can consume significant GPU memory during training. Activation offloading is a technique that instead moves these tensors to CPU memory after the forward pass and later fetches them back to GPU when they are needed. This approach can substantially reduce GPU memory usage during training.

Although PyTorch supports activation offloading, its implementation is inefficient and can cause GPUs to be idle while activations are fetched back from CPU during a backward pass. This can cause significant performance degradation when using activation offloading.

SMP v2 offers an optimized activation offloading algorithm that can improve training performance. SMP’s implementation pre-fetches activations before they are needed on the GPU, reducing idle time.
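The benefit of prefetching can be seen in a toy timing model of the backward pass. The sketch below assumes fixed per-layer fetch and compute times and a copy engine that moves one layer's activations at a time, and it interprets horizon as the maximum number of layers whose offloaded activations may be resident on the GPU at once (our reading of activation_loading_horizon). It is not SMP's algorithm, and the function is hypothetical. With horizon=1, each fetch starts only after the previous layer finishes, so the GPU idles during every copy.

```python
def simulated_backward_time(num_layers: int, fetch_time: float,
                            compute_time: float, horizon: int) -> float:
    """Toy model: total backward time when activations are fetched from CPU.

    Layers are processed in backward order k = 0..num_layers-1. A fetch may
    start once the copy engine is free and fewer than `horizon` layers'
    activations are resident on the GPU (layer k - horizon has finished).
    """
    if horizon < 1:
        raise ValueError("horizon must be >= 1")
    fetch_end, compute_end = [], []
    for k in range(num_layers):
        # One copy at a time, bounded by the number of resident layers.
        start = fetch_end[k - 1] if k >= 1 else 0.0
        if k >= horizon:
            start = max(start, compute_end[k - horizon])
        fetch_end.append(start + fetch_time)
        # Layer k's backward waits for layer k-1 and for its own activations.
        prev = compute_end[k - 1] if k >= 1 else 0.0
        compute_end.append(max(prev, fetch_end[k]) + compute_time)
    return compute_end[-1] if compute_end else 0.0


for horizon in (1, 2):
    print(horizon, simulated_backward_time(4, fetch_time=1, compute_time=2, horizon=horizon))
# 1 12.0
# 2 9.0
```

With these numbers, a horizon of 2 hides every copy except the first behind compute (1 + 4 × 2 = 9 time units instead of 4 × 3 = 12). A larger horizon lets more copies run ahead but keeps more activations on the GPU, which is the memory trade-off the parameter controls.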

Because SMP is built on top of PyTorch's APIs, enabling optimized activation offloading requires only a few lines of code changes. Add the associated configuration parameters (sm_activation_offloading and activation_loading_horizon) to your SMP configuration, and apply activation offloading in your training script.

The SMP configuration is as follows:

{
"activation_loading_horizon": 2,
"sm_activation_offloading": true
}

In the training script, use the following code:

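import torch.sagemaker as tsm
tsm.init()

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Native PyTorch module for activation offloading
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
    offload_wrapper,
)

# Hypothetical check_fn: checkpoint every transformer layer.
# TransformerLayer is a placeholder for your model's layer class.
def checkpoint_tformer_layers_policy(submodule):
    return isinstance(submodule, TransformerLayer)

model = FSDP(...)

# Activation offloading requires activation checkpointing.
apply_activation_checkpointing(
    model,
    check_fn=checkpoint_tformer_layers_policy,
)

# Offload the activations saved for the backward pass to CPU memory
model = offload_wrapper(model)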

To learn more about the open source PyTorch checkpoint tools for activation offloading, see the checkpoint_wrapper.py script in the PyTorch GitHub repository and Activation Checkpointing in the PyTorch blog post Scaling Multimodal Foundation Models in TorchMultimodal with PyTorch Distributed. To learn more about SMP's optimized implementation of activation offloading, see the activation offloading section of our documentation.

Beyond hybrid sharding, SMDDP, and activation offloading, SMP offers additional optimizations that can accelerate your large model training workload. These include optimized activation checkpointing, delayed parameter initialization, and others. To learn more, refer to the core features section of our documentation.

Conclusion

As datasets, model sizes, and training clusters continue to grow, efficient distributed training is increasingly critical for timely and affordable model and product delivery. The latest release of the SageMaker model parallel library helps you achieve this by aligning with PyTorch FSDP APIs to minimize code changes, enabling training on massive clusters through tensor parallelism, and offering optimizations that can reduce training time by up to 20%.

To get started with SMP v2, refer to our documentation and our sample notebooks.


About the Authors

Robert Van Dusen is a Senior Product Manager with Amazon SageMaker. He leads frameworks, compilers, and optimization techniques for deep learning training.

Luis Quintela is the Software Developer Manager for the AWS SageMaker model parallel library. In his spare time, he can be found riding his Harley in the SF Bay Area.

Gautam Kumar is a Software Engineer with AWS AI Deep Learning. He is passionate about building tools and systems for AI. In his spare time, he enjoys biking and reading books.

Rahul Huilgol is a Senior Software Development Engineer in Distributed Deep Learning at Amazon Web Services.