Distributed Training: Train BART/T5 for Summarization using 🤗 Transformers and Amazon SageMaker
In this tutorial, we will explore how to fine-tune a sequence-to-sequence (Seq2Seq) model with distributed training using the 🤗 Transformers library and Amazon SageMaker. We will focus on training a BART (Bidirectional and Auto-Regressive Transformers) model for abstractive text summarization on the SAMSum dataset.
Introduction
Amazon SageMaker is a fully managed service that provides every developer and data scientist with the ability to build, train, and deploy machine learning (ML) models quickly. SageMaker removes the heavy lifting from each step of the machine learning process to make it easier to develop high-quality models.
The 🤗 Transformers library is a popular open-source library for natural language processing (NLP) tasks, including text classification, sentiment analysis, and language modeling. It provides a wide range of pre-trained models and a simple interface for fine-tuning them on specific tasks.
Tutorial Overview
In this tutorial, we will:
- Set up a development environment and install the required dependencies.
- Choose a 🤗 Transformers example/script for fine-tuning a model on the summarization task.
- Configure distributed training and hyperparameters.
- Create a Hugging Face estimator and start training.
- Upload the fine-tuned model to huggingface.co and test it with the Hosted Inference widget.
Set up a Development Environment and Install Dependencies
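To start, we need to set up a development environment and install the required dependencies. We will use a SageMaker Notebook Instance to launch our training job; see the Amazon SageMaker documentation for how to set up a Notebook Instance. The following code creates a SageMaker session and retrieves the IAM execution role and the default S3 bucket used for training artifacts.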
import sagemaker
sess = sagemaker.Session()
role = sagemaker.get_execution_role()
print(f"IAM role arn used for running training: {role}")
print(f"S3 bucket used for storing artifacts: {sess.default_bucket()}")
Next, we need to install the required dependencies: the 🤗 Transformers and 🤗 Datasets libraries and the SageMaker Python SDK.
!pip install transformers "datasets[s3]" sagemaker --upgrade
We also need to install git-lfs for model upload.
!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash
!sudo yum install git-lfs -y
!git lfs install
Choose a 🤗 Transformers Example/Script
The 🤗 Transformers repository contains several examples/scripts for fine-tuning models on tasks ranging from language modeling to token classification. In our case, we use the run_summarization.py script from the examples/seq2seq directory. The git_config below tells SageMaker which repository and branch to clone the script from.
git_config = {
    'repo': 'https://github.com/huggingface/transformers.git',
    'branch': 'v4.4.2'
}
Configure Distributed Training and Hyperparameters
Next, we need to define our hyperparameters and configure our distributed training strategy. We can pass any Seq2SeqTrainingArguments as well as the arguments defined in run_summarization.py.
hyperparameters = {
    'per_device_train_batch_size': 4,
    'per_device_eval_batch_size': 4,
    'model_name_or_path': 'facebook/bart-large-cnn',
    'dataset_name': 'samsum',
    'do_train': True,
    'do_predict': True,
    'predict_with_generate': True,
    'output_dir': '/opt/ml/model',  # SageMaker uploads this directory to S3 after training
    'num_train_epochs': 3,
    'learning_rate': 5e-5,
    'seed': 7,
    'fp16': True,
}
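SageMaker does not hand this dictionary to run_summarization.py directly; the training toolkit passes each hyperparameter to the entry point as a command-line argument. The sketch below approximates that conversion so you can see roughly what the script receives. The helper name to_cli_args is hypothetical, and the exact ordering and value formatting SageMaker uses are assumptions.

```python
def to_cli_args(hyperparameters):
    """Flatten a hyperparameter dict into '--key value' command-line tokens (approximation)."""
    args = []
    for key, value in hyperparameters.items():
        # Values are passed as strings, e.g. True -> "True", 5e-5 -> "5e-05"
        args.extend([f"--{key}", str(value)])
    return args


example = {
    "model_name_or_path": "facebook/bart-large-cnn",
    "do_train": True,
    "learning_rate": 5e-5,
}
print(" ".join(["python", "run_summarization.py"] + to_cli_args(example)))
```

On the script side, the HfArgumentParser used by run_summarization.py parses these strings back into typed arguments, including strings such as "True" for boolean flags.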
We also need to configure the distribution parameter to run training with the SageMaker distributed data parallel library (smdistributed.dataparallel).
distribution = {
    'smdistributed': {
        'dataparallel': {
            'enabled': True
        }
    }
}
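With data parallelism, every GPU holds a full copy of the model and processes a different shard of the data; before each optimizer step the gradients are averaged across all workers with an all-reduce, so every replica applies the same update and the copies stay identical. The toy sketch below illustrates that idea in plain Python on a one-parameter linear model. It is a conceptual simulation, not the smdistributed API, and all function names in it are hypothetical.

```python
def local_gradient(w, xs, ys):
    """Gradient of the mean of 0.5 * (w * x - y)^2 over one worker's shard."""
    return sum((w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Simulated all-reduce: every worker ends up holding the mean of all values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def data_parallel_step(weights, shards, lr):
    """One synchronous step: local gradients, all-reduce, identical update on every replica."""
    grads = [local_gradient(w, xs, ys) for w, (xs, ys) in zip(weights, shards)]
    synced = all_reduce_mean(grads)
    return [w - lr * g for w, g in zip(weights, synced)]


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # generated by y = 2x
shards = [(xs[:2], ys[:2]), (xs[2:], ys[2:])]  # two workers, equal-sized shards
weights = [0.0, 0.0]  # both replicas start from the same initialization
for _ in range(50):
    weights = data_parallel_step(weights, shards, lr=0.05)
print(weights)  # both replicas converge to w close to 2
```

Because the shards are the same size, the averaged shard gradients equal the full-batch gradient, which is why data-parallel training behaves like training with one larger batch.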
Create a Hugging Face Estimator and Start Training
The last step before training is creating a Hugging Face estimator. The Estimator handles the end-to-end Amazon SageMaker training.
from sagemaker.huggingface import HuggingFace
huggingface_estimator = HuggingFace(
    entry_point='run_summarization.py',  # script to run, relative to source_dir
    source_dir='./examples/seq2seq',     # directory inside the cloned repository
    git_config=git_config,
    instance_type='ml.p3dn.24xlarge',
    instance_count=2,
    transformers_version='4.4.2',
    pytorch_version='1.6.0',
    py_version='py36',
    role=role,
    hyperparameters=hyperparameters,
    distribution=distribution
)
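Because each GPU processes its own per_device_train_batch_size examples, the number of examples consumed per optimizer step grows with the number of GPUs. An ml.p3dn.24xlarge instance has 8 NVIDIA V100 GPUs, so two instances give 16 data-parallel workers. The small calculation below assumes one training process per GPU and no gradient accumulation (the default); the helper and its lookup table are hypothetical.

```python
# GPUs per instance; only the instance type used in this tutorial is listed
GPUS_PER_INSTANCE = {"ml.p3dn.24xlarge": 8}


def global_batch_size(per_device_batch_size, instance_type, instance_count, gradient_accumulation_steps=1):
    """Examples per optimizer step, assuming one data-parallel process per GPU."""
    world_size = GPUS_PER_INSTANCE[instance_type] * instance_count
    return per_device_batch_size * world_size * gradient_accumulation_steps


print(global_batch_size(4, "ml.p3dn.24xlarge", 2))  # 64
```

If you change instance_count, the global batch size changes with it, so you may want to revisit learning_rate as well.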
To start our training, we call the .fit() method.
huggingface_estimator.fit()
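Once fit() returns, huggingface_estimator.model_data holds the S3 URI of the model.tar.gz archive that SageMaker created from /opt/ml/model. The sketch below downloads it with sagemaker.s3.S3Downloader and extracts it into ./model, the directory the upload step reads from. The helper names are hypothetical; sagemaker is imported inside the download function so the extraction helper can also be used on its own.

```python
import tarfile
from pathlib import Path


def extract_model_archive(archive_path, local_dir="./model"):
    """Extract a model.tar.gz archive into local_dir and return the sorted top-level entry names."""
    target = Path(local_dir)
    target.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "r:gz") as tar:
        tar.extractall(path=target)
    return sorted(p.name for p in target.iterdir())


def download_and_extract_model(estimator, session, local_dir="./model"):
    """Download the trained model artifacts of a finished estimator from S3 and extract them."""
    # Imported lazily so extract_model_archive works without the SageMaker SDK installed
    from sagemaker.s3 import S3Downloader

    # estimator.model_data is the S3 URI of the model.tar.gz written at the end of training
    S3Downloader.download(
        s3_uri=estimator.model_data,
        local_path=".",
        sagemaker_session=session,
    )
    return extract_model_archive("model.tar.gz", local_dir)
```

In the notebook, this would be called as download_and_extract_model(huggingface_estimator, sess).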
Upload the Fine-Tuned Model to huggingface.co
After training, SageMaker packages the contents of /opt/ml/model (our output_dir) as model.tar.gz in S3. Once this archive has been downloaded and extracted to ./model, we can push the model to a repository on huggingface.co.
from getpass import getpass
from huggingface_hub import HfApi
hf_username = "philschmid"  # replace with your Hugging Face username
repository_name = "bart-large-cnn-samsum"
repo_id = f"{hf_username}/{repository_name}"
# Password login is no longer supported; use an access token with write permission
token = getpass("Enter your Hugging Face access token:")
api = HfApi(token=token)
# Create the model repository (no error if it already exists)
api.create_repo(repo_id=repo_id, exist_ok=True)
# Upload the extracted model files from ./model to the repository
api.upload_folder(folder_path="./model", repo_id=repo_id)
Test Inference
After uploading the model, we can test it with the "Hosted Inference API" widget on its model page. The following line prints the URL of the model on huggingface.co:
print(f"https://huggingface.co/{hf_username}/{repository_name}")
Open this link to try the model in the widget.
Conclusion
In this tutorial, we used the 🤗 Transformers library and Amazon SageMaker to fine-tune a BART model for abstractive text summarization on the SAMSum dataset with distributed data-parallel training. We then uploaded the model to huggingface.co, where it can be tested with the "Hosted Inference API" widget.
Future Work
In the future, we can explore other distributed training strategies, such as model parallelism with SageMaker's distributed model parallel library. We can also experiment with different hyperparameters and architectures to improve the performance of the model.
Code
The code for this tutorial is available on GitHub: https://github.com/philschmid/transformers-sagemaker-tutorial
Source: https://huggingface.co/blog/sagemaker-distributed-training-seq2seq




