Distributed Training with Both Shard and Pipeline Parallelism

Alpa can automatically parallelize JAX functions with both shard parallelism (a.k.a. intra-operator parallelism) and pipeline parallelism (a.k.a. inter-operator parallelism). Shard parallelism includes data parallelism, operator parallelism, and their combinations. The quick start focuses on using Alpa for shard parallelism.

In this tutorial, we show how to use Alpa with both shard and pipeline parallelism. First, we show how to manually assign stages for pipeline parallelism. Then we show how to let Alpa automate this process.

Import Libraries and Initialize Environment

First, import the required libraries.

import alpa
from alpa.testing import assert_allclose
import copy
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from jax import random
import optax
import ray

alpa.util.disable_tqdm_globally()

Connect to a Ray Cluster

Alpa uses the distributed framework Ray to manage the cluster and distributed workers. We initialize Ray and Alpa.

ray.init()
alpa.init(cluster="ray")

# Alternatively, you can use the following command to connect to an existing
# ray cluster.
# ray.init(address="auto")
#
# Note: `alpa.init(cluster="ray")` uses the GPU resources of the whole ray
# cluster. To configure Alpa to use only a subset of the GPU resources, one can
# specify the number of nodes and the number of GPUs per node.
# For example, to use only 2 GPUs when 8 GPUs are available:
# alpa.init('ray', devices_per_node=2, num_nodes=1)

Out:

2022-09-08 01:32:25,858 INFO services.py:1333 -- View the Ray dashboard at http://127.0.0.1:8266

Train an MLP on a Single Device

In this tutorial, we use a toy dataset to train an MLP model. Specifically, we use the model to fit the function \(y = Wx + b\). Note that the model is executed on the CPU here because we force the driver process to use the CPU.

class MLPModel(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        return x


dim = 2048
batch_size = 2048

# Generate ground truth W and b
rngkey = jax.random.PRNGKey(0)
k1, k2 = random.split(rngkey)
W = random.normal(k1, (dim, dim))
b = random.normal(k2, (dim,))

# Generate the training data
ksample, knoise = random.split(k1)
x = random.normal(ksample, (batch_size, dim))
y = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))

# Initialize a train state, which includes the model parameter and optimizer
# state.
model = MLPModel(hidden_dim=dim)
params = model.init(rngkey, x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)


# Define the training step
def train_step(state, batch):

    def loss_func(params):
        out = model.apply(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


batch = {"x": x, "y": y}
expected_state = train_step(state, batch)

Pipeline Parallelism with Manual Assignment

Pipeline parallelism requires partitioning the model into several pipeline stages. To manually assign stages, we can use alpa.mark_pipeline_boundary to mark the boundary of each pipeline stage in the forward function. Note that each pipeline stage is also automatically parallelized by the shard parallel pass.

# Define an MLP model with manual stage boundaries.
class ManualPipelineMLPModel(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        # Use this boundary marker to separate the model into two stages.
        alpa.mark_pipeline_boundary()
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        return x


# Initialize the train state with the same parameters as the single-device
# model.
manual_pipeline_model = ManualPipelineMLPModel(hidden_dim=dim)
manual_pipeline_state = TrainState.create(apply_fn=manual_pipeline_model.apply,
                                          params=copy.deepcopy(params),
                                          tx=tx)


# Define the training step.
# We use `alpa.PipeshardParallel` to let alpa use both
# pipeline parallelism and shard parallelism. To make pipeline parallelism
# efficient, we need to fill the pipeline with many micro batches,
# so `num_micro_batches` should be specified.
@alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16,
                                                layer_option="manual"))
def manual_pipeline_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    # We use `alpa.grad` here to separate the apply-gradient stage from the
    # forward/backward stages in the pipeline. This is necessary to ensure that
    # the gradient accumulation is correct.
    grads = alpa.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


manual_pipeline_actual_state = manual_pipeline_train_step(
    manual_pipeline_state, batch)
assert_allclose(expected_state.params,
                manual_pipeline_actual_state.params,
                atol=5e-3)

alpa.shutdown()

Note

In addition, Alpa supports more flexible manual assignments of pipeline parallelism strategies. In the above example, each partitioned stage is assigned an equal number of devices to run on. If you want to control the device assignment of each stage, you can use the more advanced stage_option=alpa.ManualStageOption, as sketched below.
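
For illustration only, here is a minimal sketch of such a manual assignment. The field names follow alpa.ManualStageOption, while the concrete values (a hypothetical single-node cluster with 4 GPUs, split into two 1x2 submeshes) are assumptions and are not taken from the run above.

# A sketch of manual device assignment (hypothetical single-node, 4-GPU setup).
method = alpa.PipeshardParallel(
    num_micro_batches=16,
    layer_option="manual",
    stage_option=alpa.ManualStageOption(
        # Which manual layers (delimited by alpa.mark_pipeline_boundary)
        # form each pipeline stage.
        forward_stage_layer_ids=[[0], [1]],
        # Physical submesh shape (num_hosts, num_devices_per_host) per stage.
        submesh_physical_shapes=[(1, 2), (1, 2)],
        # Logical mesh shape used by the shard-parallel pass within each stage.
        submesh_logical_shapes=[(1, 2), (1, 2)],
        # Per-stage auto-sharding options; empty dicts use the defaults.
        submesh_autosharding_option_dicts=[{}, {}]))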

Pipeline Parallelism with Automatic Assignment

Alpa also supports automatically partitioning the model into multiple pipeline stages and assigning each pipeline stage a device mesh such that the total execution latency is minimized. Specifically, the automatic partitioning algorithm consists of the following steps:

  1. Layer Construction: In this step, the operators in the model are clustered into “layers” based on a graph clustering algorithm. The user needs to specify the total number of layers (i.e. clusters) as a hyperparameter.

  2. Stage Construction and Mesh Slicing: In this step, we partition the device cluster (device mesh) into multiple submeshes and assign layers to submeshes to form pipeline stages, so as to minimize the total pipeline execution latency.

alpa.init(cluster="ray")

# Define the parallel method.
# `alpa.AutoLayerOption(layer_num=2)` means we use the auto layer construction
# algorithm to cluster primitive operators into two layers.
# `stage_option="auto"` means we enable the auto stage construction algorithm.
method = alpa.PipeshardParallel(num_micro_batches=16,
                                layer_option=alpa.AutoLayerOption(layer_num=2),
                                stage_option="auto")


# Define the training step. The function body is the same as the above one.
@alpa.parallelize(method=method)
def auto_pipeline_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    # Again, we use `alpa.grad` here to separate the apply-gradient stage from
    # the forward/backward stages in the pipeline.
    grads = alpa.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


# The first call triggers alpa's compilation.
# The compilation first profiles the costs of candidate stage-mesh assignments
# and solves an optimization problem to find the optimal pipeline assignment.
auto_pipeline_actual_state = auto_pipeline_train_step(state, batch)
assert_allclose(expected_state.params,
                auto_pipeline_actual_state.params,
                atol=5e-3)

alpa.shutdown()

Out:

-------------------- Automatic stage clustering --------------------
submesh_choices: ((1, 1), (1, 2), (1, 4), (1, 8))
- Profiling for submesh 3 (1, 8):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 1, 0]=0.004, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=1.001GB, intermediate=0.000GB, init=0.063GB, as_config=((8, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 1]=0.003, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.501GB, intermediate=0.000GB, init=0.109GB, as_config=((4, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 2]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.251GB, intermediate=0.000GB, init=0.078GB, as_config=((2, 4), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 3]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.098GB, intermediate=0.000GB, init=0.063GB, as_config=((1, 8), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 4]=0.001, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.098GB, intermediate=0.000GB, init=0.063GB, as_config=((8, 1), {})
Profiling for submesh 3 (1, 8) takes 32.90 seconds
Profiled costs are: [[[       inf        inf        inf        inf        inf]
  [0.00438674 0.00278576 0.00198701 0.00153695 0.00149446]]

 [[       inf        inf        inf        inf        inf]
  [       inf        inf        inf        inf        inf]]]
Profiled max_n_succ_stages are: [[[  -1   -1   -1   -1   -1]
  [4096 4096 4096 4096 4096]]

 [[  -1   -1   -1   -1   -1]
  [  -1   -1   -1   -1   -1]]]
--------------------------------------------------
- Profiling for submesh 2 (1, 4):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 0, 1]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.252GB, intermediate=0.001GB, init=0.063GB, as_config=((2, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 2]=0.001, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.099GB, intermediate=0.001GB, init=0.063GB, as_config=((1, 4), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 3]=0.001, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.099GB, intermediate=0.001GB, init=0.063GB, as_config=((4, 1), {})
cost[0, 0, 0]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.501GB, intermediate=0.001GB, init=0.063GB, as_config=((4, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 1]=0.003, max_n_succ_stage=3873, Mem: avail=13.664GB, peak=0.501GB, intermediate=0.003GB, init=0.156GB, as_config=((2, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 0]=0.004, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=1.001GB, intermediate=0.003GB, init=0.125GB, as_config=((4, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 2]=0.002, max_n_succ_stage=3363, Mem: avail=13.664GB, peak=0.193GB, intermediate=0.004GB, init=0.125GB, as_config=((1, 4), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 3]=0.002, max_n_succ_stage=3363, Mem: avail=13.664GB, peak=0.193GB, intermediate=0.004GB, init=0.125GB, as_config=((4, 1), {})
cost[1, 1, 0]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.501GB, intermediate=0.002GB, init=0.063GB, as_config=((4, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 2]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.097GB, intermediate=0.002GB, init=0.063GB, as_config=((1, 4), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 1]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.251GB, intermediate=0.002GB, init=0.094GB, as_config=((2, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 3]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.097GB, intermediate=0.002GB, init=0.063GB, as_config=((4, 1), {})
Profiling for submesh 2 (1, 4) takes 20.67 seconds
Profiled costs are: [[[0.00232795 0.0019203  0.00115792 0.00105005        inf]
  [0.00445016 0.00313377 0.00202494 0.00203413        inf]]

 [[       inf        inf        inf        inf        inf]
  [0.00245696 0.0019856  0.00156432 0.00159595        inf]]]
Profiled max_n_succ_stages are: [[[4096 4096 4096 4096   -1]
  [4096 3873 3363 3363   -1]]

 [[  -1   -1   -1   -1   -1]
  [4096 4096 4096 4096   -1]]]
--------------------------------------------------
- Profiling for submesh 1 (1, 2):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 0, 1]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.194GB, intermediate=0.003GB, init=0.125GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 2]=0.001, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.194GB, intermediate=0.003GB, init=0.125GB, as_config=((2, 1), {})
cost[0, 1, 1]=0.003, max_n_succ_stage=1939, Mem: avail=13.664GB, peak=0.383GB, intermediate=0.007GB, init=0.250GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 2]=0.003, max_n_succ_stage=1939, Mem: avail=13.664GB, peak=0.383GB, intermediate=0.007GB, init=0.250GB, as_config=((2, 1), {})
cost[1, 1, 1]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.192GB, intermediate=0.003GB, init=0.125GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 0]=0.003, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.502GB, intermediate=0.003GB, init=0.125GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 2]=0.002, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.192GB, intermediate=0.003GB, init=0.125GB, as_config=((2, 1), {})
cost[0, 1, 0]=0.005, max_n_succ_stage=2032, Mem: avail=13.664GB, peak=1.001GB, intermediate=0.006GB, init=0.250GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 0]=0.003, max_n_succ_stage=4096, Mem: avail=13.664GB, peak=0.502GB, intermediate=0.003GB, init=0.125GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
Profiling for submesh 1 (1, 2) takes 15.02 seconds
Profiled costs are: [[[0.00256801 0.00162239 0.00148469        inf        inf]
  [0.00510118 0.00306678 0.00297202        inf        inf]]

 [[       inf        inf        inf        inf        inf]
  [0.00279755 0.00179476 0.00178525        inf        inf]]]
Profiled max_n_succ_stages are: [[[4096 4096 4096   -1   -1]
  [2032 1939 1939   -1   -1]]

 [[  -1   -1   -1   -1   -1]
  [4096 4096 4096   -1   -1]]]
--------------------------------------------------
- Profiling for submesh 0 (1, 1):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 0, 1]=0.002, max_n_succ_stage=2540, Mem: avail=13.664GB, peak=0.384GB, intermediate=0.005GB, init=0.250GB, as_config=((1, 1), {})
cost[0, 0, 0]=0.002, max_n_succ_stage=2540, Mem: avail=13.664GB, peak=0.384GB, intermediate=0.005GB, init=0.250GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 1]=0.003, max_n_succ_stage=2134, Mem: avail=13.664GB, peak=0.383GB, intermediate=0.006GB, init=0.250GB, as_config=((1, 1), {})
cost[1, 1, 0]=0.003, max_n_succ_stage=2134, Mem: avail=13.664GB, peak=0.383GB, intermediate=0.006GB, init=0.250GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 1]=0.005, max_n_succ_stage=1014, Mem: avail=13.664GB, peak=0.763GB, intermediate=0.012GB, init=0.500GB, as_config=((1, 1), {})
cost[0, 1, 0]=0.005, max_n_succ_stage=1014, Mem: avail=13.664GB, peak=0.763GB, intermediate=0.012GB, init=0.500GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
Profiling for submesh 0 (1, 1) takes 13.31 seconds
Profiled costs are: [[[0.00239089 0.00238567        inf        inf        inf]
  [0.00477897 0.00477968        inf        inf        inf]]

 [[       inf        inf        inf        inf        inf]
  [0.00275159 0.00274867        inf        inf        inf]]]
Profiled max_n_succ_stages are: [[[2540 2540   -1   -1   -1]
  [1014 1014   -1   -1   -1]]

 [[  -1   -1   -1   -1   -1]
  [2134 2134   -1   -1   -1]]]
--------------------------------------------------
Compute cost saved to: compute-cost-2022-09-08-01-34-24.npy
----------------------------------------------------------------------
Result forward_stage_layer_ids: [[0, 1]]
Result mesh_shapes: [(1, 8)]
Result logical_mesh_shapes: [(8, 1)]
Result autosharding_option_dicts: [{}]

Total running time of the script: ( 2 minutes 25.280 seconds)
