Distributed Training with Both Shard and Pipeline Parallelism

Alpa can automatically parallelize JAX functions with both shard parallelism (a.k.a. intra-operator parallelism) and pipeline parallelism (a.k.a. inter-operator parallelism). Shard parallelism includes data parallelism, operator parallelism, and their combinations. The quick start focuses on using Alpa for shard parallelism.
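To make "data parallelism" and "operator parallelism" concrete, here is a minimal NumPy sketch that simulates two devices on one host. It does not use Alpa. It only shows that a dense layer `x @ w` can be sharded either along the batch dimension (data parallelism) or along the weight's output dimension (operator parallelism), and that both give the same result as the unsharded computation.

```python
import numpy as np

# Toy dense layer: out = x @ w, x has shape (batch, d_in), w has shape (d_in, d_out).
rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))
w = rng.standard_normal((4, 6))
reference = x @ w

num_devices = 2  # simulated devices

# Data parallelism: split the batch (rows of x); each "device" holds a full copy of w.
x_shards = np.split(x, num_devices, axis=0)
data_parallel = np.concatenate([xs @ w for xs in x_shards], axis=0)

# Operator parallelism: split w along its output dimension; each "device" sees
# the full batch but computes only a slice of the output features.
w_shards = np.split(w, num_devices, axis=1)
operator_parallel = np.concatenate([x @ ws for ws in w_shards], axis=1)

print(np.allclose(data_parallel, reference), np.allclose(operator_parallel, reference))
```

Alpa's shard-parallel pass chooses among such shardings (and combinations of them) automatically; the choice here is hard-coded purely for illustration.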

In this tutorial, we show how to use Alpa with both shard and pipeline parallelism. First, we show how to use Alpa to manually assign stages for pipeline parallelism. Then we show how to use Alpa to automate this process.

Import Libraries and Initialize Environment

First, import the required libraries.

import alpa
from alpa.testing import assert_allclose
import copy
from flax import linen as nn
from flax.training.train_state import TrainState
import jax
import jax.numpy as jnp
from jax import random
import optax
import ray

alpa.util.disable_tqdm_globally()

Connect to a Ray Cluster

Alpa uses the distributed framework Ray to manage the cluster and distributed workers. We initialize Ray and Alpa.

ray.init()
alpa.init(cluster="ray")

# Alternatively, you can use the following command to connect to an existing
# Ray cluster.
# ray.init(address="auto")
#
# Note: `alpa.init(cluster="ray")` uses the GPU resources of the whole Ray
# cluster. To configure Alpa to use only a subset of the GPU resources, you can
# specify the number of nodes and the number of GPUs per node.
# For example, to use only 2 GPUs when 8 GPUs are available:
# alpa.init('ray', devices_per_node=2, num_nodes=1)

Out:

2023-01-14 09:50:50,122 INFO worker.py:1525 -- Started a local Ray instance. View the dashboard at http://127.0.0.1:8265

Train an MLP on a Single Device

In this tutorial, we use a toy dataset to train an MLP model. Specifically, we use the model to fit the function \(y = Wx + b\) with a small amount of added Gaussian noise. Note that this model runs on the CPU, because we force the driver process to use the CPU.

class MLPModel(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        return x


dim = 2048
batch_size = 2048

# Generate ground truth W and b
rngkey = jax.random.PRNGKey(0)
k1, k2 = random.split(rngkey)
W = random.normal(k1, (dim, dim))
b = random.normal(k2, (dim,))

# Generate the training data
ksample, knoise = random.split(k1)
x = random.normal(ksample, (batch_size, dim))
y = (x @ W + b) + 0.1 * random.normal(knoise, (batch_size, dim))

# Initialize a train state, which includes the model parameters and the
# optimizer state.
model = MLPModel(hidden_dim=dim)
params = model.init(rngkey, x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)


# Define the training step
def train_step(state, batch):

    def loss_func(params):
        out = model.apply(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    grads = jax.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


batch = {"x": x, "y": y}
expected_state = train_step(state, batch)
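As a rough sanity check on model size (useful when reading the memory figures in the profiling log later), the sketch below counts the parameters of `MLPModel` with `dim = 2048`. It assumes each `nn.Dense` layer has a weight matrix and a bias vector and that parameters are stored in float32.

```python
# Parameter count of the 4-layer MLPModel above, assuming dim = 2048 and float32.
dim = 2048

def dense_params(d_in, d_out):
    """Weights (d_in * d_out) plus bias (d_out) of one nn.Dense layer."""
    return d_in * d_out + d_out

# (input features, output features) of the four Dense layers in MLPModel.
layer_shapes = [(dim, 4 * dim), (4 * dim, dim), (dim, 4 * dim), (4 * dim, dim)]
num_params = sum(dense_params(d_in, d_out) for d_in, d_out in layer_shapes)
param_bytes = num_params * 4  # 4 bytes per float32 value

print(num_params, param_bytes / 2**30)  # parameter count, size in GiB
```

This gives about 67 million parameters, roughly 0.25 GiB. Since `optax.adam` keeps two moment estimates per parameter, the optimizer state adds about twice that amount again.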

Pipeline Parallelism with Manual Assignment

Pipeline parallelism requires partitioning the model into several pipeline stages. To assign stages manually, we can use alpa.mark_pipeline_boundary to mark the boundary of each pipeline stage in the forward function. Note that each pipeline stage is also automatically parallelized by the shard parallel pass.
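Conceptually, a boundary marker cuts a sequential forward computation into consecutive groups of operators. The pure-Python sketch below illustrates that idea with scalar operations; `BOUNDARY` and `split_into_stages` are hypothetical names for illustration only and are not part of Alpa, which works on traced JAX programs rather than Python lists.

```python
BOUNDARY = object()  # hypothetical stand-in for alpa.mark_pipeline_boundary()

def split_into_stages(ops):
    """Group a flat list of ops into stages, cutting at every BOUNDARY."""
    stages, current = [], []
    for op in ops:
        if op is BOUNDARY:
            stages.append(current)
            current = []
        else:
            current.append(op)
    stages.append(current)
    return stages

def run(ops, value):
    """Apply ops one after another."""
    for op in ops:
        value = op(value)
    return value

def relu(v):
    return max(v, 0.0)

ops = [
    lambda v: 2 * v, relu, lambda v: v - 3, relu,  # becomes stage 0
    BOUNDARY,
    lambda v: 4 * v, relu, lambda v: v + 1, relu,  # becomes stage 1
]

stages = split_into_stages(ops)

# Feeding each stage's output into the next reproduces the unsplit computation.
value = 5.0
for stage in stages:
    value = run(stage, value)
print(len(stages), value)
```

The split does not change what is computed, only where: in Alpa, each resulting stage can then run on its own group of devices.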

# Define an MLP model with manual stage boundaries.
class ManualPipelineMLPModel(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        # Use this boundary marker to separate the model into two stages.
        alpa.mark_pipeline_boundary()
        x = nn.Dense(features=self.hidden_dim * 4)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.hidden_dim)(x)
        x = nn.relu(x)
        return x


# Initialize the train state with the same parameters as the single-device
# model.
manual_pipeline_model = ManualPipelineMLPModel(hidden_dim=dim)
manual_pipeline_state = TrainState.create(apply_fn=manual_pipeline_model.apply,
                                          params=copy.deepcopy(params),
                                          tx=tx)


# Define the training step.
# We use the "alpa.PipeshardParallel" option to let Alpa use both
# pipeline parallelism and shard parallelism. `layer_option="manual"` tells
# Alpa to use the stage boundaries marked in the model. To make pipeline
# parallelism efficient, we need to fill the pipeline with many micro-batches,
# so we set `num_micro_batches`.
@alpa.parallelize(method=alpa.PipeshardParallel(num_micro_batches=16,
                                                layer_option="manual"))
def manual_pipeline_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    # We use `alpa.grad` here to separate the gradient-application stage from
    # the forward/backward stages in the pipeline. This is necessary to ensure
    # that gradient accumulation across micro-batches is correct.
    grads = alpa.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


manual_pipeline_actual_state = manual_pipeline_train_step(
    manual_pipeline_state, batch)
assert_allclose(expected_state.params,
                manual_pipeline_actual_state.params,
                atol=5e-3)

alpa.shutdown()
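Why gradient accumulation matters: with micro-batching, each micro-batch produces its own gradient, and the optimizer should apply their combination once rather than updating after every micro-batch. The plain NumPy sketch below (not Alpa's implementation) checks, for a linear model with a mean-squared-error loss and equally sized micro-batches, that averaging the 16 per-micro-batch gradients reproduces the full-batch gradient.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, d_in, d_out, num_micro_batches = 64, 5, 3, 16
inputs = rng.standard_normal((batch_size, d_in))
targets = rng.standard_normal((batch_size, d_out))
weights = rng.standard_normal((d_in, d_out))

def mse_grad(w, x, y):
    """Analytic gradient of mean((x @ w - y) ** 2) with respect to w."""
    residual = x @ w - y
    return 2.0 * x.T @ residual / residual.size

# One gradient over the whole batch, as in the single-device train_step.
full_grad = mse_grad(weights, inputs, targets)

# Pipeline-style: one gradient per equally sized micro-batch, accumulated and
# then averaged; an optimizer would apply this result once.
accumulated = np.zeros_like(weights)
for x_mb, y_mb in zip(np.split(inputs, num_micro_batches),
                      np.split(targets, num_micro_batches)):
    accumulated += mse_grad(weights, x_mb, y_mb)
micro_batch_grad = accumulated / num_micro_batches

print(np.allclose(full_grad, micro_batch_grad))
```

The equality relies on the loss being a mean and on all micro-batches having the same size; with unequal micro-batches the per-micro-batch gradients would need size-weighted averaging.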

Note

In addition, Alpa supports more flexible manual assignment of pipeline parallelism strategies. In the above example, each partitioned stage is assigned an equal number of devices. If you want to control the device assignment of each stage, you can use the more advanced stage_option=alpa.ManualStageOption.
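The equal split described above amounts to simple arithmetic. The sketch uses a hypothetical helper, `equal_device_assignment`, that is not an Alpa API. It assumes 8 GPUs (the count implied by the profiling log of the automatic example below) and a device count divisible by the number of stages.

```python
def equal_device_assignment(num_devices, num_stages):
    """Split device ids 0..num_devices-1 into num_stages contiguous, equal groups.

    Hypothetical helper for illustration; not part of Alpa.
    """
    if num_devices % num_stages != 0:
        raise ValueError("num_devices must be divisible by num_stages")
    per_stage = num_devices // num_stages
    return [list(range(s * per_stage, (s + 1) * per_stage))
            for s in range(num_stages)]

print(equal_device_assignment(8, 2))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With two stages on 8 GPUs, each stage gets 4 devices, on which its shard-parallel plan then runs.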

Pipeline Parallelism with Automatic Assignment

Alpa also supports automatically partitioning the model into multiple pipeline stages and assigning each pipeline stage a device mesh such that the total execution latency is minimized. Specifically, the automatic partitioning algorithm consists of the following steps:

  1. Layer Construction: In this step, the operators in the model are clustered into “layers” based on a graph clustering algorithm. The user needs to specify the total number of layers (i.e., clusters) as a hyperparameter.

  2. Stage Construction and Mesh Slicing: In this step, we partition the device cluster (device mesh) into multiple submeshes and assign layers to submeshes to form pipeline stages that minimize the total pipeline execution latency.
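To illustrate the objective of stage construction, the sketch below uses a standard latency model for synchronous (GPipe-style) pipelines: the first micro-batch passes through every stage, after which the slowest stage gates each remaining micro-batch. It brute-forces contiguous layer partitions under this model. This is a simplified assumption for illustration: it ignores communication costs and submesh shapes, and it is not Alpa's actual search, which also chooses a submesh for each stage. The layer costs are made-up numbers.

```python
from itertools import combinations

def pipeline_latency(stage_costs, num_micro_batches):
    """Sum of stage latencies plus (B - 1) times the slowest stage."""
    return sum(stage_costs) + (num_micro_batches - 1) * max(stage_costs)

def best_contiguous_partition(layer_costs, num_stages, num_micro_batches):
    """Try every split of the layers into contiguous stages; keep the fastest.

    Assumes a stage's cost is the sum of its layers' costs.
    """
    n = len(layer_costs)
    best = None
    for cuts in combinations(range(1, n), num_stages - 1):
        bounds = (0, *cuts, n)
        stages = [layer_costs[bounds[i]:bounds[i + 1]] for i in range(num_stages)]
        latency = pipeline_latency([sum(s) for s in stages], num_micro_batches)
        if best is None or latency < best[0]:
            best = (latency, stages)
    return best

# Hypothetical per-layer costs for 4 layers, split into 2 stages, 16 micro-batches.
latency, stages = best_contiguous_partition([1.0, 2.0, 3.0, 2.0],
                                            num_stages=2, num_micro_batches=16)
print(latency, stages)
```

Because the `(B - 1) * max` term dominates when there are many micro-batches, the search favors balanced stages, which is also why filling the pipeline with many micro-batches makes pipelining efficient.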

alpa.init(cluster="ray")

# Define the parallel method.
# `alpa.AutoLayerOption(layer_num=2)` means we use the auto layer construction
# algorithm to cluster primitive operators into two layers.
# `stage_option="auto"` means we enable the auto stage construction algorithm.
method = alpa.PipeshardParallel(num_micro_batches=16,
                                layer_option=alpa.AutoLayerOption(layer_num=2),
                                stage_option="auto")


# Define the training step. The function body is the same as the above one.
@alpa.parallelize(method=method)
def auto_pipeline_train_step(state, batch):

    def loss_func(params):
        out = state.apply_fn(params, batch["x"])
        loss = jnp.mean((out - batch["y"])**2)
        return loss

    # Again, we use `alpa.grad` here to separate the gradient-application stage
    # from the forward/backward stages in the pipeline.
    grads = alpa.grad(loss_func)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state


# In the first call, Alpa triggers compilation.
# The compilation first profiles the costs of candidate stages on candidate
# submeshes and then solves an optimization problem to find the best pipeline
# assignment.
auto_pipeline_actual_state = auto_pipeline_train_step(state, batch)
assert_allclose(expected_state.params,
                auto_pipeline_actual_state.params,
                atol=5e-3)

alpa.shutdown()

Out:

-------------------- Automatic stage clustering --------------------
submesh_choices: ((1, 1), (1, 2), (1, 4), (1, 8))
- Profiling for submesh 3 (1, 8):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
result[(0, 1, 3, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.252 GB, invar_size=0.250 GB, outvar_size=0.002 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 1, 3, 0), 1] = ModuleProfileResult(compute_cost=0.003, peak_memory=0.689 GB, invar_size=0.439 GB, outvar_size=0.250 GB, temp_buffer_size=0.250 GB, available_memory=13.664 GB)
result[(0, 1, 3, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.127 GB, invar_size=0.125 GB, outvar_size=0.002 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 1, 3, 1), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.346 GB, invar_size=0.221 GB, outvar_size=0.125 GB, temp_buffer_size=0.125 GB, available_memory=13.664 GB)
result[(0, 1, 3, 2), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.066 GB, invar_size=0.063 GB, outvar_size=0.002 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 1, 3, 2), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.174 GB, invar_size=0.112 GB, outvar_size=0.063 GB, temp_buffer_size=0.063 GB, available_memory=13.664 GB)
result[(0, 1, 3, 3), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.036 GB, invar_size=0.032 GB, outvar_size=0.003 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 1, 3, 3), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.060 GB, invar_size=0.058 GB, outvar_size=0.031 GB, temp_buffer_size=0.002 GB, available_memory=13.664 GB)
result[(0, 1, 3, 4), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.036 GB, invar_size=0.032 GB, outvar_size=0.003 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 1, 3, 4), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.060 GB, invar_size=0.058 GB, outvar_size=0.031 GB, temp_buffer_size=0.002 GB, available_memory=13.664 GB)
Profiling for submesh 3 (1, 8) takes 38.71 seconds
--------------------------------------------------
- Profiling for submesh 2 (1, 4):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
result[(0, 0, 2, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.127 GB, invar_size=0.125 GB, outvar_size=0.002 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 2, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.065 GB, invar_size=0.063 GB, outvar_size=0.002 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 2, 1), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.159 GB, invar_size=0.096 GB, outvar_size=0.063 GB, temp_buffer_size=0.063 GB, available_memory=13.664 GB)
result[(0, 0, 2, 0), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.314 GB, invar_size=0.189 GB, outvar_size=0.125 GB, temp_buffer_size=0.125 GB, available_memory=13.664 GB)
result[(0, 0, 2, 2), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.052 GB, invar_size=0.050 GB, outvar_size=0.031 GB, temp_buffer_size=0.002 GB, available_memory=13.664 GB)
result[(0, 0, 2, 3), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.035 GB, invar_size=0.032 GB, outvar_size=0.002 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 2, 2), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.035 GB, invar_size=0.032 GB, outvar_size=0.002 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 2, 3), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.052 GB, invar_size=0.050 GB, outvar_size=0.031 GB, temp_buffer_size=0.002 GB, available_memory=13.664 GB)
result[(0, 1, 2, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.254 GB, invar_size=0.251 GB, outvar_size=0.003 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 1, 2, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.130 GB, invar_size=0.126 GB, outvar_size=0.003 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 1, 2, 1), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.348 GB, invar_size=0.223 GB, outvar_size=0.125 GB, temp_buffer_size=0.125 GB, available_memory=13.664 GB)
result[(0, 1, 2, 2), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.069 GB, invar_size=0.064 GB, outvar_size=0.004 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 1, 2, 0), 1] = ModuleProfileResult(compute_cost=0.003, peak_memory=0.691 GB, invar_size=0.441 GB, outvar_size=0.250 GB, temp_buffer_size=0.250 GB, available_memory=13.664 GB)
result[(0, 1, 2, 3), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.069 GB, invar_size=0.064 GB, outvar_size=0.004 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 1, 2, 2), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.116 GB, invar_size=0.114 GB, outvar_size=0.063 GB, temp_buffer_size=0.002 GB, available_memory=13.664 GB)
result[(1, 1, 2, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.127 GB, invar_size=0.126 GB, outvar_size=0.002 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 1, 2, 3), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.116 GB, invar_size=0.114 GB, outvar_size=0.063 GB, temp_buffer_size=0.002 GB, available_memory=13.664 GB)
result[(1, 1, 2, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.066 GB, invar_size=0.063 GB, outvar_size=0.002 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(1, 1, 2, 1), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.190 GB, invar_size=0.127 GB, outvar_size=0.063 GB, temp_buffer_size=0.063 GB, available_memory=13.664 GB)
result[(1, 1, 2, 2), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.035 GB, invar_size=0.032 GB, outvar_size=0.002 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(1, 1, 2, 2), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.067 GB, invar_size=0.065 GB, outvar_size=0.032 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(1, 1, 2, 3), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.035 GB, invar_size=0.032 GB, outvar_size=0.002 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(1, 1, 2, 3), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.067 GB, invar_size=0.065 GB, outvar_size=0.032 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(1, 1, 2, 0), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.377 GB, invar_size=0.252 GB, outvar_size=0.125 GB, temp_buffer_size=0.125 GB, available_memory=13.664 GB)
Profiling for submesh 2 (1, 4) takes 29.73 seconds
--------------------------------------------------
- Profiling for submesh 1 (1, 2):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
result[(0, 0, 1, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.129 GB, invar_size=0.126 GB, outvar_size=0.003 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 1, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.067 GB, invar_size=0.063 GB, outvar_size=0.004 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 1, 1), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.102 GB, invar_size=0.098 GB, outvar_size=0.063 GB, temp_buffer_size=0.004 GB, available_memory=13.664 GB)
result[(0, 0, 1, 2), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.067 GB, invar_size=0.063 GB, outvar_size=0.004 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 1, 1, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.257 GB, invar_size=0.251 GB, outvar_size=0.006 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 1, 2), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.102 GB, invar_size=0.098 GB, outvar_size=0.063 GB, temp_buffer_size=0.004 GB, available_memory=13.664 GB)
result[(0, 1, 1, 1), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.230 GB, invar_size=0.226 GB, outvar_size=0.125 GB, temp_buffer_size=0.004 GB, available_memory=13.664 GB)
result[(0, 1, 1, 2), 0] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.134 GB, invar_size=0.127 GB, outvar_size=0.007 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 1, 1, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.134 GB, invar_size=0.127 GB, outvar_size=0.007 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(0, 0, 1, 0), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.316 GB, invar_size=0.191 GB, outvar_size=0.125 GB, temp_buffer_size=0.125 GB, available_memory=13.664 GB)
result[(1, 1, 1, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.129 GB, invar_size=0.126 GB, outvar_size=0.003 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 1, 1, 2), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.230 GB, invar_size=0.226 GB, outvar_size=0.125 GB, temp_buffer_size=0.004 GB, available_memory=13.664 GB)
result[(1, 1, 1, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.068 GB, invar_size=0.064 GB, outvar_size=0.003 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(1, 1, 1, 1), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.134 GB, invar_size=0.129 GB, outvar_size=0.063 GB, temp_buffer_size=0.004 GB, available_memory=13.664 GB)
result[(1, 1, 1, 2), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.068 GB, invar_size=0.064 GB, outvar_size=0.003 GB, temp_buffer_size=0.001 GB, available_memory=13.664 GB)
result[(1, 1, 1, 2), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.134 GB, invar_size=0.129 GB, outvar_size=0.063 GB, temp_buffer_size=0.004 GB, available_memory=13.664 GB)
result[(1, 1, 1, 0), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.379 GB, invar_size=0.254 GB, outvar_size=0.126 GB, temp_buffer_size=0.125 GB, available_memory=13.664 GB)
result[(0, 1, 1, 0), 1] = ModuleProfileResult(compute_cost=0.003, peak_memory=0.694 GB, invar_size=0.444 GB, outvar_size=0.250 GB, temp_buffer_size=0.250 GB, available_memory=13.664 GB)
Profiling for submesh 1 (1, 2) takes 25.08 seconds
--------------------------------------------------
- Profiling for submesh 0 (1, 1):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
result[(0, 1, 0, 1), 0] = ModuleProfileResult(compute_cost=0.006, peak_memory=0.264 GB, invar_size=0.252 GB, outvar_size=0.012 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(1, 1, 0, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.133 GB, invar_size=0.127 GB, outvar_size=0.006 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 0, 1), 1] = ModuleProfileResult(compute_cost=0.003, peak_memory=0.202 GB, invar_size=0.195 GB, outvar_size=0.125 GB, temp_buffer_size=0.008 GB, available_memory=13.664 GB)
result[(1, 1, 0, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.133 GB, invar_size=0.127 GB, outvar_size=0.006 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(1, 1, 0, 1), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.266 GB, invar_size=0.257 GB, outvar_size=0.126 GB, temp_buffer_size=0.008 GB, available_memory=13.664 GB)
result[(1, 1, 0, 0), 1] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.266 GB, invar_size=0.257 GB, outvar_size=0.126 GB, temp_buffer_size=0.008 GB, available_memory=13.664 GB)
result[(0, 1, 0, 0), 0] = ModuleProfileResult(compute_cost=0.002, peak_memory=0.264 GB, invar_size=0.252 GB, outvar_size=0.012 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 0, 1), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.132 GB, invar_size=0.126 GB, outvar_size=0.006 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 0, 0), 0] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.132 GB, invar_size=0.126 GB, outvar_size=0.006 GB, temp_buffer_size=0.000 GB, available_memory=13.664 GB)
result[(0, 0, 0, 0), 1] = ModuleProfileResult(compute_cost=0.001, peak_memory=0.202 GB, invar_size=0.195 GB, outvar_size=0.125 GB, temp_buffer_size=0.008 GB, available_memory=13.664 GB)
result[(0, 1, 0, 1), 1] = ModuleProfileResult(compute_cost=0.003, peak_memory=0.459 GB, invar_size=0.451 GB, outvar_size=0.250 GB, temp_buffer_size=0.008 GB, available_memory=13.664 GB)
result[(0, 1, 0, 0), 1] = ModuleProfileResult(compute_cost=0.003, peak_memory=0.459 GB, invar_size=0.451 GB, outvar_size=0.250 GB, temp_buffer_size=0.008 GB, available_memory=13.664 GB)
Profiling for submesh 0 (1, 1) takes 27.47 seconds
--------------------------------------------------
Profile result saved to: profile-results-2023-01-14-09-53-30.npy
----------------------------------------------------------------------
Result forward_stage_layer_ids: [[0], [1]]
Result mesh_shapes: [(1, 4), (1, 4)]
Result logical_mesh_shapes: [(4, 1), (1, 4)]
Result autosharding_option_dicts: [{}, {'force_batch_dim_to_mesh_dim': 0}]

Total running time of the script: ( 3 minutes 9.712 seconds)
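The `submesh_choices` line in the log lists the candidate submesh shapes (nodes, devices per node) that were profiled, and `Result mesh_shapes: [(1, 4), (1, 4)]` indicates that the two stages were each placed on a 4-GPU submesh. The sketch below reproduces that candidate list for one node with 8 GPUs under an assumed rule: a candidate is either a power-of-two slice of a single node, (1, 2**k), or a group of whole nodes, (n, devices_per_node). It is an illustration, not Alpa's code.

```python
def submesh_choices(num_nodes, devices_per_node):
    """Enumerate candidate submesh shapes under the assumed rule above."""
    shapes = []
    # Power-of-two slices of a single node.
    k = 1
    while k <= devices_per_node:
        shapes.append((1, k))
        k *= 2
    # Groups of whole nodes.
    for n in range(1, num_nodes + 1):
        shape = (n, devices_per_node)
        if shape not in shapes:
            shapes.append(shape)
    return tuple(shapes)

print(submesh_choices(num_nodes=1, devices_per_node=8))
# ((1, 1), (1, 2), (1, 4), (1, 8))
```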
