Serving OPT-175B, BLOOM-176B and CodeGen-16B using Alpa
This tutorial shows how to set up a serving system to serve one of the largest available pretrained language models, OPT-175B. Instructions for other models (BLOOM and CodeGen) are listed at the end.
Try a live demo at Alpa-OPT Demo.
Overview
As a serving system, Alpa offers the following unique advantages:
Designed for large models: Cannot fit the model into a single GPU? Not a problem, Alpa is designed for training and serving big models like GPT-3.
Support commodity hardware: With Alpa, you can serve OPT-175B using your in-house GPU cluster, without needing the latest-generation A100 80GB GPUs or fancy InfiniBand connections. No hardware constraints!
Flexible parallelism strategies: Alpa will automatically figure out the appropriate model-parallel strategies based on your cluster setup and your model architecture.
In this example, we use Alpa to serve the open-source OPT model, supporting all sizes ranging from 125M to 175B. Specifically, Alpa provides:
A distributed backend to perform efficient model-parallel inference for the large OPT models.
A web frontend to collect and batch inference requests from users.
Note
The pre-trained OPT model weights can be obtained from Metaseq, subject to their license.
Note
You will need at least 350GB GPU memory on your entire cluster to serve the OPT-175B model. For example, you can use 4 x AWS p3.16xlarge instances, which provide 4 (instance) x 8 (GPU/instance) x 16 (GB/GPU) = 512 GB memory.
You can also follow this guide to set up a serving system to serve smaller versions of OPT, such as OPT-66B, OPT-30B, etc. Pick an appropriate size from the OPT weight downloading page based on your available resources.
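If you are unsure which model size fits your cluster, a quick back-of-the-envelope estimate is parameter count x 2 bytes (fp16) for the weights alone. The sketch below only estimates weight memory; actual usage is higher because of activations and other runtime buffers:
# Rough fp16 weight-memory estimate; real usage is higher (activations, buffers).
def fp16_weight_memory_gb(num_params):
    return num_params * 2 / 1e9  # 2 bytes per parameter in fp16

for name, params in [("opt-30b", 30e9), ("opt-66b", 66e9), ("opt-175b", 175e9)]:
    print(f"{name}: ~{fp16_weight_memory_gb(params):.0f} GB for weights alone")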
Demo
The code below shows how to use the huggingface/transformers interface and the Alpa distributed backend for large model inference.
from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model
# Load the tokenizer. All OPT models with different sizes share the same tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False
# Load the model. Alpa automatically downloads the weights to the specified path
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")
# Generate
prompt = "Paris is the capital city of"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)
print(generated_string)
Requirements
Install Alpa following the installation guide. You can either install by python wheel or build from source.
Install additional requirements for llm_serving:
pip3 install "transformers<=4.23.1" fastapi uvicorn omegaconf jinja2

# Install torch corresponding to your CUDA version, e.g., for CUDA 11.3:
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
Clone the alpa repo. If you installed Alpa by python wheel, you need to clone the repo to get the examples. If you installed from source, you have already done this step.
git clone git@github.com:alpa-projects/alpa.git
Install the llm_serving package. Go to the examples folder and install the package.
cd alpa/examples
pip3 install -e .
Convert Weights Format
The weights of the OPT 125M–66B models are publicly available. Huggingface hosts copies of these weights. For OPT 125M–66B, you do not need to download or convert the weights manually: Alpa will automatically download the weights from huggingface to the given path if it cannot find cached weights locally.
The weights of OPT-175B can be obtained from Meta by filling out a request form. You then need to manually convert the obtained weights into the Alpa format.
Convert OPT-175B weights into Alpa formats
We provide detailed instructions below on how to convert the original OPT-175B weights into Alpa-compatible formats. You can skip this section if you only want to run smaller models.
Note
The procedures below for converting OPT-175B weights will take about 1 hour.
- Download and verify the original weights
First, download Metaseq's original OPT-175B weights in 992 shards, verify the MD5 of each shard (a helper sketch follows below), and put the shards under a folder, say, PATH_TO_992_SHARDS/.
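If you prefer to verify the shards programmatically, the following is a minimal sketch that prints the MD5 digest of every file under PATH_TO_992_SHARDS/; compare the output against the MD5 list published alongside the Metaseq weights:
import hashlib
import os

shard_dir = "PATH_TO_992_SHARDS"  # replace with your shard folder

def md5sum(path, chunk_size=1 << 20):
    # Stream the file so large shards do not need to fit in memory.
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

for fname in sorted(os.listdir(shard_dir)):
    full_path = os.path.join(shard_dir, fname)
    if os.path.isfile(full_path):
        print(md5sum(full_path), fname)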
- Consolidate the weights from 992 shards into one single checkpoint
Use the script step_2_consolidate_992_shards_to_singleton.py as:
python3 step_2_consolidate_992_shards_to_singleton.py --read-prefix [PATH_TO_992_SHARDS]/checkpoint_last --save-prefix [PATH_TO_SAVE_CHECKPOINT]
The consolidated checkpoint will be saved at PATH_TO_SAVE_CHECKPOINT as specified in the command.
Note
The above script requires peak memory (RAM) usage as large as twice the model size. For example, consolidating the 175B model has a peak memory usage of approximately 175B x 2 bytes x 2 = 700GB. Please make sure your RAM is sufficient to run the script without throwing an OOM exception.
Note
The above script saves the model weights as a single consolidated checkpoint at PATH_TO_SAVE_CHECKPOINT, hence it requires at least 350GB of free disk space.
- Convert the single checkpoint into Alpa-compatible formats
Alpa ingests weights simply from numpy formats. Use the script step_3_convert_to_numpy_weights.py to convert the single checkpoint into numpy formats:
python3 step_3_convert_to_numpy_weights.py --ckpt-path PATH_TO_SAVE_CHECKPOINT --output-folder OUTPUT_PATH
The weights will be saved in the folder OUTPUT_PATH as specified in the command.
Note
The above script also requires 350GB free disk space to write the numpy-formatted weights.
Converted weights for other models
You do not need to download the weights manually for OPT 125M–66B. However, if you have trouble with the automatic downloading or with huggingface, we also provide the converted weights for the following models.
Copy Weights to Multiple Nodes
If you want to run the model on multiple nodes, you can use one of the following methods to copy the weights to all nodes.
Put the weights under a shared network file system, so all nodes can access it.
Run the script first on a driver node. The driver node will download the weights to its local disk, but the script will fail later because the worker nodes cannot access the weights. You can then manually copy all downloaded weights under path from the driver node to all worker nodes.
Run Generation in the Command Line
The code of this tutorial is under examples/llm_serving.
Run generation using the 125M model with PyTorch/HuggingFace backend on a single GPU:
python3 textgen.py --model facebook/opt-125m
Run generation using the 125M model with JAX backend on a single GPU:
python3 textgen.py --model jax/opt-125m
Run model-parallel generation using the 2.7B model with Alpa on multiple GPUs:
# Start ray on the node
ray start --head

python3 textgen.py --model alpa/opt-2.7b
Run distributed generation using the 175B model with Alpa on a cluster of GPU nodes. Note you will need >350GB total GPU memory in the entire cluster to successfully run the inference.
Before running the command below, start Ray on the cluster following this guide. You can check the cluster status with ray status; you should be able to see all GPUs and all nodes in the output.
python3 textgen.py --model alpa/opt-175b
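Besides ray status, you can also verify from Python that Ray sees every node and GPU before launching the 175B run. This is a small convenience check using the standard Ray API:
import ray

# Connect to the running cluster (started earlier with `ray start`).
ray.init(address="auto")

resources = ray.cluster_resources()
print("GPUs visible to Ray:", resources.get("GPU", 0))
print("Nodes in the cluster:", len(ray.nodes()))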
Launch a Web Server to Serve the OPT Models
We need to run two scripts: one for the web server and another for the model serving worker. They use two ports: the port of the website is specified on the command line, and the port of the worker is defined in service/constants.py.
# Launch the model worker
python3 launch_model_worker.py --model alpa/opt-175b
# Launch the website (in a new terminal)
uvicorn launch_website:app --host 0.0.0.0 --port 8001
Then open http://[IP-ADDRESS]:8001 in your browser to try out the model!
There is also a client library which can be used to query the model worker via a python script. Please check test_completions.py for the usage.
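For reference, a completion request is essentially an HTTP POST carrying the prompt and generation parameters. The sketch below is hypothetical: the port, endpoint path, and payload field names are assumptions for illustration only, and the authoritative interface is whatever test_completions.py and the client library actually use.
import requests

# NOTE: the port, the "/completions" route, and the field names below are
# hypothetical placeholders; check test_completions.py for the real interface.
WORKER_URL = "http://[IP-ADDRESS]:PORT"  # the worker port is defined in service/constants.py

resp = requests.post(
    f"{WORKER_URL}/completions",
    json={"prompt": "Paris is the capital city of", "max_tokens": 64},
)
print(resp.json())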
Improving Generation Speed
Here are some tips for improving the generation speed.
- Batching. Single-sequence generation cannot fully utilize the GPU power. Applying batching can greatly boost the performance. See textgen.py for the usage (and the sketch after this list).
- Tune the encoder_chunk_sizes argument of get_model. Alpa compiles multiple executables and uses them to encode a prompt chunk by chunk. This argument controls the possible chunk sizes. Depending on the length of your prompt, you can try different combinations. For example, if your prompt lengths are around 1000-1500, a good combination is [1, 256, 1024].
- Tune the parallelization strategy. If you are familiar with Alpa, you can tune the method argument of alpa.parallelize and try different parallelization methods.
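As a concrete illustration of the first two tips, the sketch below batches several prompts into a single generate call and passes the encoder_chunk_sizes argument to get_model. It assumes the wrapper accepts a padded batch of input_ids in the same way as the single-prompt demo above; textgen.py is the maintained reference for batched usage.
from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False

# encoder_chunk_sizes values here are illustrative; tune them to your prompt lengths.
model = get_model(model_name="alpa/opt-2.7b",
                  path="~/opt_weights/",
                  encoder_chunk_sizes=[1, 64, 256])

prompts = [
    "Paris is the capital city of",
    "Computer science is the study of",
    "The largest animal on Earth is",
]
# Batch the prompts together instead of generating one by one.
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
output = model.generate(input_ids=inputs.input_ids, max_length=64, do_sample=True)
print(tokenizer.batch_decode(output, skip_special_tokens=True))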
If you find the generation speed too slow and want to accelerate it, please join Alpa slack and tell us your use cases. We are actively working on improving the performance.
OPT License
The use of the OPT pretrained weights is subject to the Model License by Metaseq.
Other Models (BLOOM)
Alpa also supports BLOOM. You can use commands similar to OPT but with a different model name.
# Huggingface/pytorch backend
python3 textgen.py --model bigscience/bloom-560m

# Jax backend
python3 textgen.py --model jax/bloom-560m

# Alpa backend
python3 textgen.py --model alpa/bloom-560m
Other Models (CodeGen)
Alpa also supports CodeGen. You can use commands similar to OPT but with a different model name.
# Huggingface/pytorch backend
python3 codegen.py --model Salesforce/codegen-2B-mono

# Alpa backend
python3 codegen.py --model alpa/codegen-2B-mono