Alpa Compiler Walk-Through
This document provides a walk-through of the compiler part of Alpa.
Note
This document is based on the workflow as of this commit. While some specific details might differ from the latest version, the general idea remains the same.
Starting from an arbitrary JAX function (i.e., computational graph) of a neural network training step, Alpa’s overall workflow includes the following steps:
Layer construction: Cluster different operators in the computational graph into a sequential list of pipeline layers.
Stage construction: Cluster the pipeline layers into pipeline stages and assign each stage a subset of devices for pipeline execution (i.e., inter-operator parallelism).
Auto sharding: Figure out how to shard each operator within each pipeline stage on its corresponding devices with SPMD parallelism (i.e., intra-operator parallelism).
Let’s start with the following code snippet:
class ManualPipelineMLPModel(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        # Use this boundary marker to separate the model into two stages.
        alpa.mark_pipeline_boundary()
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        return x


@alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16,
                                                layer_option="manual"))
def manual_pipeline_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    # Use `alpa.grad` here to slice the forward/backward stages and the
    # gradient update stage.
    grads = alpa.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state
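For context, here is a minimal usage sketch of how this snippet is typically driven. The cluster setup, dataset shapes, and optimizer below are illustrative assumptions, not part of the walk-through:

import alpa
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

alpa.init(cluster="ray")  # pipeline parallelism runs on top of a Ray cluster

model = ManualPipelineMLPModel(hidden_dim=1024)
x = jnp.ones((2048, 1024))
y = jnp.ones((2048, 1024))
params = model.init(jax.random.PRNGKey(0), x)
state = train_state.TrainState.create(
    apply_fn=model.apply, params=params, tx=optax.sgd(learning_rate=1e-3))

# The first call triggers the compilation pipeline described below.
state = manual_pipeline_train_step(state, {"x": x, "y": y})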
Compared to the original JAX/Flax code, this snippet additionally calls alpa.mark_pipeline_boundary, alpa.parallelize, and alpa.grad. Below, we will show how Alpa uses these functions and decorators to compile the original single-device computational graph into a distributed version.
Layer Construction
The first transformation we perform is in alpa.grad (link) for layer construction. It is a thin wrapper around the original jax.grad in JAX, which additionally performs the following tasks:
Process pipeline markers to form forward pipeline layers.
Call the original jax.grad. We directly use JAX's autograd to map the forward layers to the backward layers.
Mark all the gradients with a special marker so that we can perform gradient accumulation for them.
Mark all the operators after the gradient computation as the gradient update phase.
We form the pipeline layers by inserting pipeline markers into the JAXPR, either automatically or manually with user annotations. layer_option="manual" in the code example above indicates that we insert the markers manually.
The definition of the pipeline markers can be found in primitive_def.py. We define a new JAX primitive pipeline_p and an XLA custom call pipeline_marker. All these markers behave exactly like an identity function that returns all of its input arguments.
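As a rough illustration of what such a marker looks like, here is a minimal sketch of an identity-like JAX primitive using jax.core. The names and parameters are illustrative; the actual definition lives in primitive_def.py:

from jax import core

pipeline_p = core.Primitive("pipeline")
pipeline_p.multiple_results = True

@pipeline_p.def_impl
def _pipeline_impl(*args, **params):
    # The marker is an identity: it simply returns all of its inputs.
    return args

@pipeline_p.def_abstract_eval
def _pipeline_abstract_eval(*args, **params):
    # Shapes and dtypes pass through unchanged.
    return args

def mark_pipeline(*args, name, mark_type):
    # mark_type is "start" or "end"; both behave as identities.
    return pipeline_p.bind(*args, name=name, mark_type=mark_type)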
We distinguish between start and end markers. The start marker captures all the inputs of a pipeline layer, and the end marker captures the outputs. To preserve the forward/backward stage mapping, we set the gradient of a start marker to be an end marker, and the gradient of an end marker to be a start marker.
A complete pipeline layer has the following structure:
marked_inputs = pipeline_marker[type="start"] layer_inputs
...
layer_outputs = some_jax_operator marked_inputs
...
marked_outputs = pipeline_marker[type="end"] layer_outputs
Note that every JAX operator within the pipeline layer should only take as inputs the marked inputs or intermediate results produced within the layer. All the outputs of the layer are marked by the end marker.
In the manual case, we provide a simpler API that does not require two markers per stage, and the user does not need to specify the input and output variables. Instead, the user only needs to call alpa.mark_pipeline_boundary at the boundary between two pipeline layers. The layer_level_jaxpr_transformation function (link) then transforms the program into the form above.
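For comparison, the automatic path is selected by passing an automatic layer option instead of "manual". A hedged sketch (the option names follow the Alpa API referenced by this walk-through; check the current documentation for the exact spelling):

method = alpa.PipeshardParallel(
    num_micro_batches=16,
    # Let Alpa cluster the operators into 2 layers automatically,
    # instead of relying on alpa.mark_pipeline_boundary calls.
    layer_option=alpa.AutoLayerOption(layer_num=2))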
Note: Alpa can also perform rematerialization (i.e., gradient checkpointing) at these pipeline stage boundaries. See these functions: link.
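For background, rematerialization in plain JAX is exposed as jax.checkpoint (a.k.a. jax.remat). The sketch below only illustrates the concept and is not Alpa's code path:

import jax
import jax.numpy as jnp

@jax.checkpoint
def block(x, w):
    # Activations inside this block are recomputed during the backward pass
    # instead of being stored during the forward pass.
    return jnp.tanh(x @ w)

grad_fn = jax.grad(lambda x, w: block(x, w).sum())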
Stage Construction
The transformed function with layer markers is then transformed by @alpa.parallelize. The most important option of @alpa.parallelize is method, which specifies the type of parallelism to use. Here we set it to alpa.PipeshardParallel, indicating that we use both pipeline parallelism (inter-operator parallelism) and SPMD sharding parallelism (intra-operator parallelism).
@alpa.parallelize transforms the original function into a ParallelizedFunc. ParallelizedFunc is a Python class that behaves like the original function but provides some additional methods.
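The flattening mentioned below is standard pytree flattening. A small illustration, reusing state and batch from the usage sketch above:

import jax

# The pytree arguments (state, batch) become a flat list of arrays plus a
# tree definition that records how to reassemble them.
flat_args, in_tree = jax.tree_util.tree_flatten((state, {"x": x, "y": y}))
print(len(flat_args), in_tree)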
ParallelizedFunc flattens the input arguments and compiles the JAX function according to the method. In our case, it eventually calls compile_pipeshard_executable() here, which transforms the input as follows:
compile_pipeshard_executable first traces the original function to a JAXPR. Note that we trace the function with both the full batch size and the smaller micro-batch size used for gradient accumulation. Then we call into compile_pipeshard_executable_internal.

split_compute_grad_and_apply_grad splits the apply_grad part from the rest of the function. There is a special transformation for the case where a single parameter x is used in multiple pipeline layers l1(x), l2(x), ... For example, in language models with a tied embedding layer, the embedding matrix is used by both the first and the last stage. In this case, JAX's backward pass generates some equations, not captured by pipeline markers, that compute the gradient of x: grad_x = grad_l1_x + grad_l2_x. We move these kinds of equations to the apply_grad part and let each layer perform gradient accumulation separately.

compute_grad_to_accumulate_grad transforms the original compute_grad JAXPR, which only computes gradients, into an accumulate_grad JAXPR that performs gradient accumulation (see the single-device sketch after this list). More specifically, the structure of accumulate_grad is shown in the following pseudo-code:

def accumulate_grad(compute_grad_inputs, accumulated_grad):
    grad = compute_grad(compute_grad_inputs)
    accumulated_grad += grad
    return accumulated_grad

Note that the += above is only correct when the gradients can be summed up. When the output is per input data (e.g., inference outputs), we use concat instead of +=. The analysis of which operator to use is done in _get_full_batch_apply_grad by comparing the full-batch and micro-batch versions of the code.

slice_closed_jaxpr_by_full_pipeline_marks slices the accumulate_grad JAXPR into many pipeline layers.

mark_missing_vars_in_backward_computation_pipeline_marks: when JAX derives the backward JAXPR, a backward layer directly uses the intermediate results of the forward layer instead of adding them to the backward layer's start pipeline marker. This function fixes that issue. In addition, it removes all Literals in start markers and all DropVars in end markers.

cluster_layers_and_slice_mesh performs stage construction. It clusters the pipeline layers into pipeline stages, slices the compute cluster (represented as a 2D device mesh) into many submeshes, and assigns each stage a submesh. Currently, a forward layer and its corresponding backward layer are always placed on the same submesh. See the Alpa paper for the full automatic algorithm.

process_apply_gradient splits the single apply_grad JAXPR into #submeshes parts; each part processes the gradient updates and optimizer states related to the variables on a specific submesh.

create_donation_mapping and split_donate_invars process the donated invars for each pipeline stage and also add donation variables for gradient accumulation.
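To make the accumulate_grad structure above concrete, here is a single-device sketch of the gradient-accumulation semantics it encodes. This is plain JAX written for illustration, not Alpa's generated code; num_micro_batches matches the example above:

import jax
import jax.numpy as jnp

def train_step_with_accumulation(state, batch, num_micro_batches=16):
    # Split the full batch into micro-batches along the leading axis.
    micro_batches = jax.tree_util.tree_map(
        lambda a: a.reshape((num_micro_batches, -1) + a.shape[1:]), batch)
    # accumulated_grad starts at zero.
    grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)
    for i in range(num_micro_batches):
        micro = jax.tree_util.tree_map(lambda a: a[i], micro_batches)

        def loss_func(params):
            out = state.apply_fn(params, micro["x"])
            return jnp.mean((out - micro["y"]) ** 2)

        # accumulated_grad += grad
        grads = jax.tree_util.tree_map(
            jnp.add, grads, jax.grad(loss_func)(state.params))
    # Average so the result matches the gradient of the full-batch mean loss,
    # then run the apply_grad phase once.
    grads = jax.tree_util.tree_map(lambda g: g / num_micro_batches, grads)
    return state.apply_gradients(grads=grads)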
Auto Sharding
Then, in shard_each_stage, we run the auto-sharding pass for each pipeline stage. Because we compile different stages in a distributed fashion to accelerate compilation, the code here is nested. Specifically, the following two functions are the most important ones:
In generate_sharded_xla_computations_arguments (code), we concatenate the JAXPRs of all stages on a submesh (which typically include the forward, backward, and update computation of a single stage) and compile the result into an HLOModule.
Then we call run_auto_sharding_pass (code), which eventually calls the RunAutoShardingPass we wrote in XLA (code). This XLA function:
First runs a subset of XLA passes that come before the SPMD partitioner.
Then runs the Alpa AutoSharding pass (code), which automatically annotates the graph with GSPMD annotations.
Finally runs the SliceAutoShardedStages pass (code), which slices the concatenated stages back into individual stages and returns them to Python.
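The GSPMD annotations mentioned above are ordinary XLA sharding annotations. The hand-written example below uses upstream JAX's sharding API (not Alpa's internal interface) to show the kind of per-tensor annotation that the AutoSharding pass chooses automatically by solving its ILP:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# A 2D logical device mesh: one data-parallel axis, one model-parallel axis.
mesh = Mesh(mesh_utils.create_device_mesh((1, jax.device_count())),
            axis_names=("dp", "mp"))

@jax.jit
def dense(x, w):
    y = x @ w
    # A manually written sharding constraint: shard the output's last
    # dimension over the "mp" mesh axis. AutoSharding inserts annotations
    # like this for every tensor instead of relying on manual choices.
    return jax.lax.with_sharding_constraint(
        y, NamedSharding(mesh, PartitionSpec(None, "mp")))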
The result of shard_each_stage is a list of SPMD-sharded pipeline stages. The whole pipeline and sharding execution schedule is then summarized and organized by a PipelineInstEmitter (code). The resulting pipeshard_config is sent to the runtime to be executed.
Note
To debug and visualize each step, you can simply add print statements for the JAXPR in Python or for the HLO in XLA.
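For example, one way to see the pipeline markers on the Python side is to trace an undecorated copy of the forward computation and print its JAXPR. This sketch reuses model, params, x, and y from the usage sketch above; depending on the Alpa version, the marker primitive may appear under a slightly different name in the output:

import jax

def forward_loss(params, batch):
    out = model.apply(params, batch["x"])
    return jnp.mean((out - batch["y"]) ** 2)

# The printed jaxpr contains an equation for the pipeline boundary marker.
print(jax.make_jaxpr(forward_loss)(params, {"x": x, "y": y}))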