Install Alpa

This page provides instructions to install Alpa from Python wheels or from source. The minimum supported Python version is 3.7.

Prerequisites

Regardless of installing from wheels or from source, there are a few prerequisite packages:

  1. CUDA toolkit:

Follow the official guides to install CUDA and cuDNN. Alpa requires CUDA >= 11.1 and cuDNN >= 8.0.5.

  2. Update pip and install CuPy:

# Update pip
pip3 install --upgrade pip

# Install cupy
pip3 install cupy-cuda11x

Then, check whether your system already has NCCL installed.

python3 -c "from cupy.cuda import nccl"

If it prints nothing, then NCCL has already been installed. Otherwise, follow the printed instructions to install NCCL.
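Both prerequisite checks above can also be scripted. The Python sketch below compares version strings against the stated minimums (CUDA >= 11.1, cuDNN >= 8.0.5) and repeats the NCCL import test. It assumes you already know the CUDA and cuDNN version strings (for example, CUDA from `nvcc --version`), and it assumes recent CuPy releases set an `available` flag on `cupy.cuda.nccl` instead of failing the import. The helper names are hypothetical and not part of Alpa.

```python
import importlib

# Minimums stated on this page.
MIN_CUDA = (11, 1)
MIN_CUDNN = (8, 0, 5)


def parse_version(text: str) -> tuple:
    """Turn a dotted version such as "11.3" or "8.2.0" into a tuple of ints."""
    return tuple(int(part) for part in text.strip().split("."))


def at_least(found: str, minimum: tuple) -> bool:
    """Return True if `found` is the same as or newer than `minimum`."""
    version = parse_version(found)
    width = max(len(version), len(minimum))
    # Pad with zeros so that "11.1" compares equal to (11, 1, 0).
    version += (0,) * (width - len(version))
    minimum += (0,) * (width - len(minimum))
    return version >= minimum


def nccl_available() -> bool:
    """Mirror `python3 -c "from cupy.cuda import nccl"` without a traceback."""
    try:
        nccl = importlib.import_module("cupy.cuda.nccl")
    except ImportError:
        return False  # CuPy itself is missing or could not be imported
    # Assumption: recent CuPy releases import this module even when NCCL is
    # missing and set `available = False`; a missing attribute counts as success.
    return bool(getattr(nccl, "available", True))


def check_prerequisites(cuda: str, cudnn: str) -> dict:
    """Report which prerequisites are satisfied."""
    return {
        "cuda": at_least(cuda, MIN_CUDA),
        "cudnn": at_least(cudnn, MIN_CUDNN),
        "nccl": nccl_available(),
    }


if __name__ == "__main__":
    # Replace these with the versions installed on your machine.
    print(check_prerequisites("11.3", "8.2.0"))
```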

Methods

Choose one of the methods below.

Method 1: Install from Python Wheels

Alpa provides wheels for the following CUDA (cuDNN) and Python versions:

  • CUDA (cuDNN): 11.1 (8.0.5), 11.2 (8.1.0), 11.3 (8.2.0)

  • Python: 3.7, 3.8, 3.9

If you need other CUDA, cuDNN, or Python versions, follow Method 2 below to install from source.
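To decide between the two methods in a script, the supported combinations listed above can be encoded as a small lookup. This is a minimal Python sketch: the table is copied from this page and may lag behind the wheels actually published, and `wheel_available` is a hypothetical helper, not part of Alpa.

```python
import sys

# Copied from the lists above; check the wheel index for newer additions.
WHEEL_CUDA_CUDNN = {"11.1": "8.0.5", "11.2": "8.1.0", "11.3": "8.2.0"}
WHEEL_PYTHON = {(3, 7), (3, 8), (3, 9)}


def wheel_available(cuda: str, python=None) -> bool:
    """True if a prebuilt jaxlib wheel is listed for this CUDA and Python version."""
    if python is None:
        python = tuple(sys.version_info[:2])
    return cuda in WHEEL_CUDA_CUDNN and tuple(python) in WHEEL_PYTHON


if __name__ == "__main__":
    cuda = "11.3"  # replace with your CUDA version
    method = "Method 1 (wheels)" if wheel_available(cuda) else "Method 2 (source)"
    print(f"CUDA {cuda}, Python {sys.version_info[0]}.{sys.version_info[1]}: use {method}")
```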

  1. Install the Alpa Python package.

pip3 install alpa

  2. Install the Alpa-modified Jaxlib. Make sure that the jaxlib version matches the existing CUDA and cuDNN installation you want to use. You can explicitly specify the CUDA and cuDNN versions for jaxlib via:

pip3 install jaxlib==0.3.22+cuda{cuda_version}.cudnn{cudnn_version} -f https://alpa-projects.github.io/wheels.html

For example, to install the wheel compatible with CUDA >= 11.1 and cuDNN >= 8.0.5, use the following command:

pip3 install jaxlib==0.3.22+cuda111.cudnn805 -f https://alpa-projects.github.io/wheels.html

All available wheel versions are listed in our PyPI index.
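The requirement string follows a simple pattern: the dots are dropped from the CUDA and cuDNN versions, so 11.1 becomes 111 and 8.0.5 becomes 805. The Python sketch below builds the pip command from dotted versions. The pattern is inferred from the example above, so confirm that the resulting tag exists in the wheel index before relying on it; the function names are hypothetical.

```python
JAXLIB_BASE = "0.3.22"
WHEEL_INDEX = "https://alpa-projects.github.io/wheels.html"


def jaxlib_requirement(cuda: str, cudnn: str) -> str:
    """Turn dotted versions into the local tag, e.g. 11.1 / 8.0.5 -> cuda111.cudnn805."""
    tag = f"cuda{cuda.replace('.', '')}.cudnn{cudnn.replace('.', '')}"
    return f"jaxlib=={JAXLIB_BASE}+{tag}"


def pip_command(cuda: str, cudnn: str) -> list:
    """Return the pip invocation as an argument list."""
    return ["pip3", "install", jaxlib_requirement(cuda, cudnn), "-f", WHEEL_INDEX]


if __name__ == "__main__":
    # Reproduces the example command above.
    print(" ".join(pip_command("11.1", "8.0.5")))
```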

Note

Currently, the Alpa-modified jaxlib is based on jaxlib==0.3.22. Alpa regularly rebases onto the official jaxlib repository to keep up with upstream.

Method 2: Install from Source

  1. Clone the repository, including its submodules:

git clone --recursive https://github.com/alpa-projects/alpa.git

  2. Install the Alpa Python package.

cd alpa
pip3 install -e ".[dev]"  # Note that the suffix `[dev]` is required to build custom modules.

  3. Build and install the Alpa-modified Jaxlib, which contains Alpa’s C++ code.

cd build_jaxlib
python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa
cd dist

pip3 install -e .
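In the build step, `$(pwd)/../third_party/tensorflow-alpa` is evaluated inside build_jaxlib, so it simply points at the third_party/tensorflow-alpa directory of the Alpa checkout. The Python sketch below assembles the same build command as an argument list, for example for a CI script. `build_jaxlib_command` is a hypothetical helper, and the command still has to run with build_jaxlib as the working directory (for instance via `subprocess.run(cmd, cwd=..., check=True)`).

```python
import os


def build_jaxlib_command(alpa_root: str) -> list:
    """Return the jaxlib build command for an Alpa checkout at `alpa_root`.

    Run it with <alpa_root>/build_jaxlib as the working directory, just like
    the shell version, because build/build.py is a relative path.
    """
    tf_dir = os.path.join(os.path.abspath(alpa_root), "third_party", "tensorflow-alpa")
    return [
        "python3",
        "build/build.py",
        "--enable_cuda",
        "--dev_install",
        f"--bazel_options=--override_repository=org_tensorflow={tf_dir}",
    ]


if __name__ == "__main__":
    # Print the command for a hypothetical checkout location.
    print(" ".join(build_jaxlib_command("/home/user/alpa")))
```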

Note

Building the latest Alpa-modified jaxlib requires C++17 support. Some compiler versions, such as gcc==7.3 and gcc==9.4, are known to fail to compile the jaxlib code correctly. See this thread for the known issues.

If you encounter compilation errors, install our recommended version, gcc==7.5 (newer gcc versions might also work). Then clean the Bazel cache (rm -rf ~/.cache/bazel) and try to build jaxlib again.
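To see which of the compiler versions above you have, you can parse the first line of `gcc --version`. In this Python sketch, only 7.3 and 9.4 are treated as known-bad and 7.5 as recommended, as stated in this note; any other version is reported as untested rather than broken. The helper names are hypothetical.

```python
import re
import shutil
import subprocess

KNOWN_BAD = {"7.3", "9.4"}
RECOMMENDED = "7.5"


def gcc_major_minor(version_output: str):
    """Extract "X.Y" from the first line of `gcc --version` output, or None."""
    lines = version_output.splitlines()
    match = re.search(r"(\d+)\.(\d+)\.\d+", lines[0]) if lines else None
    return f"{match.group(1)}.{match.group(2)}" if match else None


def classify(version_output: str) -> str:
    """Label the compiler as recommended, known-bad, untested, or unknown."""
    version = gcc_major_minor(version_output)
    if version is None:
        return "unknown"
    if version in KNOWN_BAD:
        return "known-bad"
    if version == RECOMMENDED:
        return "recommended"
    return "untested"


if __name__ == "__main__":
    if shutil.which("gcc"):
        out = subprocess.run(["gcc", "--version"], capture_output=True, text=True).stdout
        print(classify(out))
    else:
        print("gcc not found on PATH")
```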

Note

All installations above are in development mode, so changes to the Python code take effect immediately. If you modify the C++ code in TensorFlow, you only need to rerun the following command from step 3 (inside build_jaxlib) to recompile jaxlib:

python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa

Note

The Alpa Python package and the Alpa-modified Jaxlib are two separate libraries. If you only want to develop the Python source code, you can install the Alpa Python package from source and the Alpa-modified Jaxlib from wheels.

Check Installation

You can check the installation by running the following commands.

ray start --head
python3 -m alpa.test_install
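To run this check from a script, the Python sketch below executes the two commands in order and stops at the first failure. It uses `sys.executable` instead of `python3` so the test runs in the current interpreter's environment. `check_installation` and its `runner` parameter are hypothetical; the parameter exists only so the sequencing can be exercised without a GPU.

```python
import subprocess
import sys

CHECK_COMMANDS = [
    ["ray", "start", "--head"],
    [sys.executable, "-m", "alpa.test_install"],
]


def check_installation(runner=subprocess.run):
    """Run each check command in order.

    Returns None if all commands exit with status 0, otherwise the first
    command that failed or whose executable could not be found.
    """
    for cmd in CHECK_COMMANDS:
        try:
            result = runner(cmd)
        except FileNotFoundError:
            return cmd  # e.g. `ray` is not on PATH
        if result.returncode != 0:
            return cmd
    return None
```

Note that `ray start --head` leaves a Ray head node running; stop it with `ray stop` when you are done.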

[Optional] PyTorch Frontend

While Alpa is mainly designed for Jax, Alpa also provides an experimental PyTorch frontend. Alpa supports PyTorch models that meet the following requirements:

  1. No input-dependent control flow

  2. No weight sharing

To enable Alpa for PyTorch, install the following dependencies:

# Install torch and torchdistx
pip3 uninstall -y torch torchdistx
pip3 install --extra-index-url https://download.pytorch.org/whl/cpu torch==1.12 torchdistx

# Build functorch from source
git clone https://github.com/pytorch/functorch
cd functorch/
git checkout 76976db8412b60d322c680a5822116ba6f2f762a
python3 setup.py install

Please look at tests/torch_frontend/test_simple.py for usage examples.

Troubleshooting

Unhandled CUDA Error

If you see errors like cupy_backends.cuda.libs.nccl.NcclError: NCCL_ERROR_UNHANDLED_CUDA_ERROR: unhandled cuda error, they are usually caused by incompatible versions of CUDA, NCCL, and the GPU driver. Please double-check these versions and see Issue #496 for more details.
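To gather the versions mentioned above, CuPy can report the CUDA runtime, CUDA driver, and NCCL versions. The Python sketch below collects them and flags one common mismatch: a driver that supports an older CUDA version than the runtime in use. It assumes a CuPy release that provides `cupy.cuda.runtime.runtimeGetVersion`, `cupy.cuda.runtime.driverGetVersion`, and `cupy.cuda.nccl.get_version`; the helper names are hypothetical, and a passing check does not rule out other incompatibilities.

```python
def decode_cuda(code: int) -> str:
    """CUDA encodes versions as 1000 * major + 10 * minor, e.g. 11030 -> "11.3"."""
    return f"{code // 1000}.{(code % 1000) // 10}"


def collect_versions():
    """Return raw CUDA runtime, CUDA driver, and NCCL version codes, or None."""
    try:
        from cupy.cuda import nccl, runtime

        return {
            "runtime": runtime.runtimeGetVersion(),
            "driver": runtime.driverGetVersion(),
            "nccl": nccl.get_version(),
        }
    except Exception:  # CuPy or NCCL missing, or no usable GPU driver
        return None


def driver_older_than_runtime(versions: dict) -> bool:
    """The driver reports the newest CUDA version it supports."""
    return versions["driver"] < versions["runtime"]


if __name__ == "__main__":
    found = collect_versions()
    if found is None:
        print("Could not query versions through CuPy")
    else:
        print("CUDA runtime", decode_cuda(found["runtime"]),
              "| driver supports CUDA", decode_cuda(found["driver"]),
              "| NCCL code", found["nccl"])
        if driver_older_than_runtime(found):
            print("Driver is older than the CUDA runtime; consider upgrading it.")
```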

Using Alpa on Slurm

Since Alpa relies on Ray to manage the cluster nodes, Alpa can run on a Slurm cluster as long as Ray can run on it. If you have trouble running Alpa on a Slurm cluster, we recommend following this guide to set up Ray on Slurm and making sure that simple Ray examples run without problems before you install and run Alpa in the same environment.

Common issues of running Alpa on Slurm include:

  • The Slurm cluster has additional networking proxies installed, so XLA client connections time out. Example errors can be found in this thread. Slurm cluster users might need to check and fix these proxies and make sure that processes spawned by Alpa can reach each other.

  • When launching a Slurm job with srun, users do not request enough CPU threads or GPU resources for Ray to spawn many actors. Adjust the value of the --cpus-per-task argument passed to srun when launching Alpa. See the Slurm documentation for more information.
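One way to keep Ray's view of the node consistent with the Slurm allocation is to pass the granted CPU count to `ray start` explicitly. The Python sketch below reads `SLURM_CPUS_PER_TASK`, which Slurm sets when `--cpus-per-task` is requested, and fails loudly if it is unset. `ray_head_command` is a hypothetical helper, and GPU settings are left to Ray's auto-detection.

```python
import os


def ray_head_command(env=None) -> list:
    """Build `ray start --head` with --num-cpus taken from the Slurm allocation."""
    env = os.environ if env is None else env
    cpus = env.get("SLURM_CPUS_PER_TASK")
    if cpus is None:
        # Slurm only sets this variable when --cpus-per-task is requested.
        raise RuntimeError("SLURM_CPUS_PER_TASK is unset; pass --cpus-per-task to srun")
    return ["ray", "start", "--head", f"--num-cpus={int(cpus)}"]
```

Inside an srun job launched with --cpus-per-task=32, the function returns `["ray", "start", "--head", "--num-cpus=32"]`.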

You might also find the discussion under Issue #452 helpful.

Jaxlib, Jax, Flax Version Problems

Alpa is only tested against specific versions of Jax and Flax. The recommended Jax and Flax versions are specified by install_require_list in setup.py. (If you are not using the latest HEAD, check out setup.py at the corresponding version tag.)

If you see version errors like the one below:

>>> import alpa
  ......
  RuntimeError: jaxlib version 0.3.7 is newer than and incompatible with jax version 0.3.5. Please update your jax and/or jaxlib packages

Make sure your Jax, Flax, and Optax/Chex versions are compatible with the versions specified in Alpa’s setup.py. Then reinstall the Alpa-modified Jaxlib, either from our prebuilt wheels or by installing from source, to overwrite the default Jaxlib.
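To find which package drifted from the pins, you can compare the installed versions against exact `name==version` requirements. The Python sketch below assumes Python 3.8+ (for `importlib.metadata`) and handles only `==` pins. It ignores local version labels such as `+cuda111.cudnn805`, so an Alpa-modified jaxlib is compared by its base version. Copy the pins from setup.py at the version you installed; the helper names are hypothetical.

```python
from importlib import metadata


def installed_version(name: str):
    """Return the installed version of `name`, or None if it is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def mismatched(pins, lookup=installed_version):
    """Return (name, wanted, found) for each "name==version" pin not satisfied."""
    problems = []
    for pin in pins:
        name, sep, wanted = pin.partition("==")
        if not sep:
            raise ValueError(f"only exact pins are supported: {pin!r}")
        name, wanted = name.strip(), wanted.strip()
        found = lookup(name)
        # Drop local labels such as "+cuda111.cudnn805" before comparing.
        base = found.split("+")[0] if found else None
        if base != wanted:
            problems.append((name, wanted, found))
    return problems
```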

Numpy Version Problems

If you start with a clean Python virtual environment and strictly follow the procedures in this guide, you should not see Numpy version problems.

However, installing other Python packages can sometimes silently install a different version of numpy before jaxlib is compiled. You might then see numpy version errors similar to the following when launching Alpa after installing from source:

>>> python3 tests/test_install.py
  ......
  RuntimeError: module compiled against API version 0xf but this version of numpy is 0xd
  ImportError: numpy.core._multiarray_umath failed to import
  ImportError: numpy.core.umath failed to import
  2022-05-20 21:57:35.710782: F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:83] Check failed: tensorflow::RegisterNumpyBfloat16()
  Aborted (core dumped)

This happens because jaxlib was compiled against a newer version of numpy than the one later used to run Alpa.

To address the problem, first downgrade numpy in your Python environment to numpy==1.20 via pip install numpy==1.20, then follow the install-from-source procedure to rebuild and reinstall jaxlib. Afterwards, you can optionally switch back to a newer numpy (numpy>=1.20) to run Alpa and your other applications, thanks to numpy’s backward compatibility.
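The rule behind this fix is an ordering check: the numpy used at run time should be at least as new as the numpy that jaxlib was compiled against. The Python sketch below compares major.minor versions only; it is a rough heuristic, and the helper names are hypothetical.

```python
def major_minor(version: str) -> tuple:
    """Return (major, minor) from a version string such as "1.21.6" or "1.22.0rc1"."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def runtime_numpy_ok(build_version: str, runtime_version: str) -> bool:
    """True if the run-time numpy is at least as new as the build-time numpy."""
    return major_minor(runtime_version) >= major_minor(build_version)
```

For instance, `runtime_numpy_ok("1.20.0", "1.22.4")` is True, while swapping the arguments gives False, which corresponds to the failure in the log above.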

See Issue #461 for more discussion.

Tests Hang with no Errors on Multi-GPU Nodes

This could indicate that IO virtualization (VT-d or IOMMU) is interfering with the NCCL library. On multi-GPU systems, IO virtualization can redirect PCI point-to-point traffic through the CPU, reducing performance or causing programs to hang. These settings can typically be disabled in the BIOS, or sometimes from the OS. You can find more information in NVIDIA’s NCCL troubleshooting guide. Note that disabling IO virtualization can introduce security vulnerabilities, since peripherals then have read/write access to DRAM through DMA (Direct Memory Access).
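To check whether an IOMMU is active on a Linux node, you can look at the IOMMU groups the kernel exposes and at the IOMMU-related boot parameters. The Python sketch below is a heuristic: a non-empty /sys/kernel/iommu_groups usually means an IOMMU is enabled, and the boot flags are reported without interpretation. The helper names are hypothetical.

```python
import os

IOMMU_PREFIXES = ("intel_iommu=", "amd_iommu=", "iommu=")


def iommu_group_count(path: str = "/sys/kernel/iommu_groups") -> int:
    """Number of IOMMU groups exposed by the kernel (0 if none or unreadable)."""
    try:
        return len(os.listdir(path))
    except OSError:
        return 0


def iommu_boot_flags(cmdline: str) -> list:
    """Return kernel command-line tokens that configure the IOMMU."""
    return [token for token in cmdline.split() if token.startswith(IOMMU_PREFIXES)]


if __name__ == "__main__":
    try:
        with open("/proc/cmdline") as f:
            flags = iommu_boot_flags(f.read())
    except OSError:
        flags = []
    print(f"IOMMU groups: {iommu_group_count()}; boot flags: {flags or 'none'}")
```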