Getting Started with Alpa
Alpa is a library that automatically parallelizes jax functions and runs them on a distributed cluster. Alpa analyzes the jax computational graph and generates a distributed execution plan tailored for the computational graph and target cluster. Alpa is specifically designed for training large-scale neural networks. The generated execution plan can combine state-of-the-art distributed training techniques including data parallelism, operator parallelism, and pipeline parallelism.
Alpa provides a simple API, @parallelize, and automatically generates the best execution plan by solving optimization problems. Therefore, you can efficiently scale your jax computation on a distributed cluster without any expertise in distributed computing. In this tutorial, we show the usage of Alpa with an MLP example.
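For a first impression, here is a toy sketch of the workflow (scaled_sum is a made-up example; the full training-step example follows in this tutorial). The decorated function is called like an ordinary jax function, and Alpa decides how to run it on the available devices. Depending on your Alpa version and devices, the first call may emit warnings while it compiles.

# A toy sketch of the @parallelize workflow, assuming a working Alpa
# installation as in the rest of this tutorial.
import alpa
import jax.numpy as jnp

@alpa.parallelize
def scaled_sum(x, y):
    return jnp.sum(2.0 * x + y)

result = scaled_sum(jnp.ones(1024), jnp.ones(1024))  # compiled and parallelized on the first call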
Import Libraries
We first import the required libraries. Flax and optax are libraries on top of jax for training neural networks. Although we use these libraries in this example, Alpa works on jax’s and XLA’s internal intermediate representations and does not depend on any specific high-level libraries.
from functools import partial
import alpa
from alpa.testing import assert_allclose
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from jax import random
import numpy as np
import optax
Train an MLP on a Single Device
To begin with, we implement the model and training loop on a single device. We will parallelize it later. We train an MLP to learn a function y = Wx + b.
class MLPModel(nn.Module):
    hidden_dim: int
    num_layers: int

    @nn.compact
    def __call__(self, x):
        for i in range(self.num_layers):
            if i % 2 == 0:
                x = nn.Dense(features=self.hidden_dim * 4)(x)
            else:
                x = nn.Dense(features=self.hidden_dim)(x)
            x = nn.relu(x)
        return x
dim = 2048
batch_size = 2048
num_layers = 10
# Generate ground truth W and b
rngkey = jax.random.PRNGKey(0)
k1, k2 = random.split(rngkey)
W = random.normal(k1, (dim, dim))
b = random.normal(k2, (dim,))
# Generate the training data
ksample, knoise = random.split(k1)
x = random.normal(ksample, (batch_size, dim))
y = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))
# Initialize a train state, which includes the model parameters and optimizer state.
model = MLPModel(hidden_dim=dim, num_layers=num_layers)
params = model.init(rngkey, x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
# Define the training function and execute one step
def train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state
batch = {"x": x, "y": y}
expected_state = train_step(state, batch)
Out:
/python3.8-env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:1866: UserWarning: Explicitly requested dtype <class 'numpy.float64'> requested in asarray is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
lax_internal._check_user_dtype_supported(dtype, "asarray")
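As an optional sanity check (not part of the original script), we can run a few extra single-device steps on a copy of the state and watch the loss decrease.

# Optional sanity check: a small training loop reusing ``train_step`` and ``batch``.
def compute_loss(state, batch):
    out = state.apply_fn(state.params, batch["x"])
    return jnp.mean((out - batch["y"])**2)

tmp_state = state
for i in range(3):
    print(f"step {i}, loss = {compute_loss(tmp_state, batch):.4f}")
    tmp_state = train_step(tmp_state, batch)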
Auto-parallelization with @parallelize
Alpa provides a transformation @alpa.parallelize to parallelize a jax function. @alpa.parallelize is similar to @jax.jit: @jax.jit compiles a jax function for a single device, while @alpa.parallelize compiles a jax function for a distributed device cluster.
You may know that jax has some built-in transformations for parallelization, such as pmap, pjit, and xmap. However, these transformations are not fully automatic, because they require users to manually specify parallelization strategies such as parallelization axes and device mapping schemes. You also need to manually call communication primitives such as lax.pmean and lax.all_gather, which is nontrivial if you want to do advanced model parallelization.
Unlike these transformations, @alpa.parallelize can do all of this automatically for you. @alpa.parallelize finds the best parallelization strategy for the given jax function and does the code transformation. You only need to write the code as if you are writing for a single device.
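For illustration only, a hand-written data-parallel version of the training step with jax.pmap might look like the following sketch. It is not part of this tutorial's workflow; it only shows the manual work (replicating the state, sharding the batch, and averaging gradients with lax.pmean) that @alpa.parallelize removes.

# Illustrative only: manual data parallelism with jax.pmap.
# The user must replicate ``state`` and shard ``batch`` across devices
# before calling this function; Alpa needs none of this.
from jax import lax

@partial(jax.pmap, axis_name="batch")
def pmap_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        return jnp.mean((out - batch["y"])**2)

    grads = jax.grad(loss_func)(state.params)
    # Average gradients across devices by hand. Alpa inserts the required
    # communication automatically.
    grads = lax.pmean(grads, axis_name="batch")
    return state.apply_gradients(grads=grads)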
# Define the training step. The body of this function is the same as the
# ``train_step`` above. The only difference is to decorate it with
# ``@alpa.parallelize``.
@alpa.parallelize
def alpa_train_step(state, batch):

    def loss_func(params):
        out = model.apply(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state
# Test correctness
actual_state = alpa_train_step(state, batch)
assert_allclose(expected_state.params, actual_state.params, atol=5e-3)
After being decorated by @parallelize, the function can still take numpy arrays or jax arrays as inputs. The function first distributes the input arrays to the correct devices according to the parallelization strategy and then executes the computation in a distributed fashion. The returned result arrays are also stored in a distributed fashion across the devices.
print("Input parameter type:", type(state.params["params"]["Dense_0"]["kernel"]))
print("Output parameter type:", type(actual_state.params["params"]["Dense_0"]["kernel"]))
# Fetch one copy of a parameter back to the host as a numpy array
kernel_np = np.array(actual_state.params["params"]["Dense_0"]["kernel"])
Out:
Input parameter type: <class 'jaxlib.xla_extension.DeviceArray'>
Output parameter type: <class 'jaxlib.xla_extension.pmap_lib.ShardedDeviceArray'>
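To pull the whole distributed parameter tree back to the host, we can map np.array over it. This is an illustrative snippet, not part of the original script; it copies data to the host, so use it sparingly for large models.

# Fetch every parameter back to the host as a numpy array.
host_params = jax.tree_util.tree_map(np.array, actual_state.params)
print(type(host_params["params"]["Dense_0"]["kernel"]))  # <class 'numpy.ndarray'>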
Execution Speed Comparison
By parallelizing a jax function, we can accelerate the computation and reduce the memory usage per GPU, so we can train large models faster. We benchmark the execution speed of @jax.jit and @alpa.parallelize on an 8-GPU machine.
state = actual_state # We need this assignment because the original `state` is "donated" and freed.
from alpa.util import benchmark_func
# Benchmark serial execution with jax.jit
jit_train_step = jax.jit(train_step, donate_argnums=(0,))
def sync_func():
    jax.local_devices()[0].synchronize_all_activity()


def serial_execution():
    global state
    state = jit_train_step(state, batch)
costs = benchmark_func(serial_execution, sync_func, warmup=5, number=10, repeat=5) * 1e3
print(f"Serial execution time. Mean: {np.mean(costs):.2f} ms, Std: {np.std(costs):.2f} ms")
# Benchmark parallel execution
# We distribute the arguments in advance for benchmarking purposes
state, batch = alpa_train_step.preshard_dynamic_args(state, batch)
def parallel_execution():
    global state
    state = alpa_train_step(state, batch)
costs = benchmark_func(parallel_execution, sync_func, warmup=5, number=10, repeat=5) * 1e3
print(f"Parallel execution time. Mean: {np.mean(costs):.2f} ms, Std: {np.std(costs):.2f} ms")
Out:
Serial execution time. Mean: 211.57 ms, Std: 8.99 ms
Parallel execution time. Mean: 71.98 ms, Std: 0.76 ms
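As a rough sanity check on these numbers (copied from the output above; they will differ from run to run), the parallel version is about 2.9x faster on this 8-GPU machine. The gap from a perfect 8x speedup likely comes from communication and other parallelization overheads.

# Quick arithmetic on the measured means reported above.
serial_ms, parallel_ms = 211.57, 71.98
print(f"Speedup with @alpa.parallelize: {serial_ms / parallel_ms:.1f}x")  # ~2.9x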
Memory Usage Comparison
We can also compare the memory usage per GPU.
GB = 1024 ** 3
executable = jit_train_step.lower(state, batch).compile().runtime_executable()
print(f"Serial execution per GPU memory usage: {executable.total_allocation_size() / GB:.2f} GB")
executable = alpa_train_step.get_executable(state, batch)
print(f"Parallel execution per GPU memory usage: {executable.get_total_allocation_size() / GB:.2f} GB")
Out:
Serial execution per GPU memory usage: 2.78 GB
Parallel execution per GPU memory usage: 0.50 GB
Total running time of the script: ( 0 minutes 52.369 seconds)