Install Alpa
This page provides instructions to install Alpa from Python wheels or from source. The minimum supported Python version is 3.7.
Prerequisites
Regardless of whether you install from wheels or from source, there are a few prerequisite packages:
CUDA toolkit: install CUDA and cuDNN first (the supported versions are listed below).
Update pip version and install cupy:
# Update pip
pip3 install --upgrade pip
# Install cupy
pip3 install cupy-cuda11x
Then, check whether your system already has NCCL installed.
python3 -c "from cupy.cuda import nccl"
If it prints nothing, then NCCL has already been installed. Otherwise, follow the printed instructions to install NCCL.
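If NCCL is available, you can also print the NCCL version that CuPy detects. This is an optional sanity check run in a Python interpreter, not part of the official instructions:
# Optional sanity check: print the NCCL version that CuPy sees
from cupy.cuda import nccl
print(nccl.get_version())  # integer encoding of the NCCL version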
Methods
Choose one of the methods below.
Method 1: Install from Python Wheels
Alpa provides wheels for the following CUDA (cuDNN) and Python versions:
CUDA (cuDNN): 11.1 (8.0.5), 11.2 (8.1.0), 11.3 (8.2.0)
Python: 3.7, 3.8, 3.9
If you need to use other CUDA, cuDNN, or Python versions, please follow the next section to install from source.
Install the Alpa Python package.
pip3 install alpa
Install Alpa-modified Jaxlib. Make sure that the jaxlib version corresponds to the version of the existing CUDA and cuDNN installation you want to use. You can specify a particular CUDA and cuDNN version for jaxlib explicitly via:
pip3 install jaxlib==0.3.22+cuda{cuda_version}.cudnn{cudnn_version} -f https://alpa-projects.github.io/wheels.html
For example, to install the wheel compatible with CUDA >= 11.1 and cuDNN >= 8.0.5, use the following command:
pip3 install jaxlib==0.3.22+cuda111.cudnn805 -f https://alpa-projects.github.io/wheels.html
You can see all the wheel versions we provide at our PyPI index.
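If you are unsure which wheel to pick, one quick way to check the Python and CUDA runtime versions in your environment is via CuPy. This is only a rough helper, not part of the official instructions; the CUDA runtime version is reported as an integer (for example, 11030 for CUDA 11.3):
# Rough check of the Python and CUDA runtime versions (assumes cupy is installed)
import sys
import cupy
print(sys.version_info[:2])                   # e.g. (3, 8)
print(cupy.cuda.runtime.runtimeGetVersion())  # e.g. 11030 for CUDA 11.3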
Note
As of now, Alpa's modifications are based on jaxlib==0.3.22. Alpa regularly rebases against the official jaxlib repository to catch up with upstream.
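After installing, you can verify that the Alpa-modified jaxlib and a compatible jax are the versions actually imported (a minimal check):
# Check which jax / jaxlib versions are actually imported
import jax
from jaxlib import version as jaxlib_version
print(jax.__version__, jaxlib_version.__version__)  # jaxlib should report 0.3.22 here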
Method 2: Install from Source
Clone repos
git clone --recursive https://github.com/alpa-projects/alpa.git
Install the Alpa Python package.
cd alpa
pip3 install -e ".[dev]"  # Note that the suffix `[dev]` is required to build custom modules.
Build and install the Alpa-modified jaxlib, which contains the C++ code of Alpa.
cd build_jaxlib
python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa
cd dist
pip3 install -e .
Note
Building the latest Alpa-modified jaxlib requires the C++17 standard. It is known that some compiler versions, such as gcc==7.3 or gcc==9.4, cannot correctly compile the jaxlib code. See this thread about the known issues. If you encounter compilation errors, please install our recommended gcc version, gcc==7.5; newer gcc versions might also work. Then clean the bazel cache (rm -rf ~/.cache/bazel) and try to build jaxlib again.
Note
All installations are in development mode, so you can modify the Python code and the changes take effect immediately. If you modify the C++ code in tensorflow, you only need to re-run the command below from step 3 to recompile jaxlib:
python3 build/build.py --enable_cuda --dev_install --bazel_options=--override_repository=org_tensorflow=$(pwd)/../third_party/tensorflow-alpa
Note
The Alpa Python package and the Alpa-modified jaxlib are two separate libraries. If you only want to develop the Python source code, you can install the Alpa Python package from source and the Alpa-modified jaxlib from wheels.
Check Installation
You can check the installation by running the following commands.
ray start --head
python3 -m alpa.test_install
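Beyond the built-in test, a tiny end-to-end check is to parallelize a toy training step with Alpa. This is only a sketch modeled on Alpa's init/parallelize API; the loss function and array sizes here are arbitrary, and it assumes the Ray head started above is still running. If anything fails here, rely on alpa.test_install above.
# Minimal smoke test: parallelize a toy training step (sketch only)
import alpa
import jax
import jax.numpy as jnp

alpa.init(cluster="ray")            # connect to the Ray cluster started above

def loss_fn(w, batch):
    x, y = batch
    return jnp.mean((x @ w - y) ** 2)

@alpa.parallelize
def train_step(w, batch):
    grads = jax.grad(loss_fn)(w, batch)
    return w - 0.1 * grads

w = jnp.ones((128, 128))
batch = (jnp.ones((1024, 128)), jnp.ones((1024, 128)))
w = train_step(w, batch)
print(w.shape)                      # expect (128, 128)

alpa.shutdown()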
[Optional] PyTorch Frontend
While Alpa is mainly designed for Jax, it also provides an experimental PyTorch frontend. Alpa supports PyTorch models that meet the following requirements:
No input-dependent control flow
No weight sharing
To enable Alpa for PyTorch, install the following dependencies:
# Install torch and torchdistx
pip3 uninstall -y torch torchdistx
pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==1.12 torchdistx
# Build functorch from source
git clone https://github.com/pytorch/functorch
cd functorch/
git checkout 76976db8412b60d322c680a5822116ba6f2f762a
python3 setup.py install
Please look at tests/torch_frontend/test_simple.py for usage examples.
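For reference, a model along the following lines (a hypothetical two-layer MLP in plain PyTorch) satisfies both requirements: its forward pass contains no input-dependent control flow, and no parameter tensor is shared between layers.
import torch
import torch.nn as nn

class ToyMLP(nn.Module):
    # Each layer owns its own parameters: no weight sharing.
    def __init__(self, hidden=128):
        super().__init__()
        self.fc1 = nn.Linear(hidden, hidden)
        self.fc2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        # Straight-line computation: no if/for branching on input values.
        return self.fc2(torch.relu(self.fc1(x)))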
Troubleshooting
Unhandled CUDA Error
If you see errors like cupy_backends.cuda.libs.nccl.NcclError: NCCL_ERROR_UNHANDLED_CUDA_ERROR: unhandled cuda error, it is usually caused by compatibility issues between the CUDA, NCCL, and GPU driver versions. Please double-check these versions and see Issue #496 for more details.
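To gather the relevant versions in one place before comparing them, a small diagnostic sketch using CuPy is:
# Print driver, CUDA runtime, and NCCL versions for a compatibility check
from cupy.cuda import nccl, runtime
print("driver :", runtime.driverGetVersion())
print("runtime:", runtime.runtimeGetVersion())
print("nccl   :", nccl.get_version())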
Using Alpa on Slurm
Since Alpa relies on Ray to manage the cluster nodes, Alpa can run on a Slurm cluster as long as Ray can. If you have trouble running Alpa on a Slurm cluster, we recommend following this guide to set up Ray on Slurm, making sure simple Ray examples run without any problem, and then installing and running Alpa in the same environment.
Common issues of running Alpa on Slurm include:
The Slurm cluster has additional networking proxies installed, so XLA client connections time out. Example errors can be found in this thread. Slurm cluster users might need to check and fix those proxies and make sure the processes spawned by Alpa can see each other.
When launching a Slurm job using srun, the users do not request enough CPU threads or GPU resources for Ray to spawn its actors. Adjust the value of the --cpus-per-task argument passed to srun when launching Alpa (see the sketch after this list for a quick way to inspect the resources Ray actually sees). See the Slurm documentation for more information.
You might also find the discussion under Issue #452 helpful.
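Inside a Slurm allocation, one quick way to confirm that Ray actually received the CPUs and GPUs you requested is to query the cluster resources after connecting to the running Ray instance (a minimal sketch; it assumes ray start --head has already been run in the allocation):
# If the CPU/GPU counts here are lower than expected, adjust the resources
# requested from Slurm (e.g. --cpus-per-task) and restart Ray.
import ray
ray.init(address="auto")
print(ray.cluster_resources())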
Jaxlib, Jax, Flax Version Problems
Alpa is only tested against specific versions of Jax and Flax.
The recommended Jax and Flax versions are specified by install_require_list in setup.py. (You can check out the file at a specific version tag if you are not using the latest HEAD.)
If you see version errors like the one below:
>>> import alpa
......
RuntimeError: jaxlib version 0.3.7 is newer than and incompatible with jax version 0.3.5. Please update your jax and/or jaxlib packages
Make sure your Jax, Flax, and Optax/Chex versions are compatible with the versions specified in Alpa's setup.py.
Make sure you re-install the Alpa-modified jaxlib, either from our prebuilt wheels or by following Install from Source, to overwrite the default jaxlib.
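To compare the versions in your environment against the ones pinned in Alpa's setup.py, a quick check is the following (a minimal sketch; drop optax or chex if you do not have them installed):
# Print the versions that will actually be imported at runtime
import jax, flax, optax, chex
from jaxlib import version as jaxlib_version
print("jax   ", jax.__version__)
print("jaxlib", jaxlib_version.__version__)
print("flax  ", flax.__version__)
print("optax ", optax.__version__)
print("chex  ", chex.__version__)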
Numpy Version Problems
If you start with a clean Python virtual environment and have followed the procedures in this guide strictly, you should not see problems about Numpy versions.
However, sometimes the installation of other Python packages silently pulls in a different version of numpy before jaxlib is compiled, and you might see numpy version errors similar to the following when launching Alpa after installing from source:
>>> python3 tests/test_install.py
......
RuntimeError: module compiled against API version 0xf but this version of numpy is 0xd
ImportError: numpy.core._multiarray_umath failed to import
ImportError: numpy.core.umath failed to import
2022-05-20 21:57:35.710782: F external/org_tensorflow/tensorflow/compiler/xla/python/xla.cc:83] Check failed: tensorflow::RegisterNumpyBfloat16()
Aborted (core dumped)
This is because a newer version of numpy was used when compiling jaxlib, but an older version of numpy is used when running Alpa. To address the problem, first downgrade the numpy in your Python environment to numpy==1.20 via pip install numpy==1.20, then follow the procedures in Install from Source to rebuild and reinstall jaxlib. Optionally, you can then switch back to the newer version of numpy (numpy>=1.20) to run Alpa and your other applications, thanks to numpy's backward compatibility. See Issue #461 for more discussion.
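A quick way to confirm which numpy is active, both before rebuilding jaxlib and before running Alpa, is the following minimal check; the numpy used at run time must not be older than the one used at build time:
# Print the numpy version active in the current environment
import numpy
print(numpy.__version__)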
Tests Hang with no Errors on Multi-GPU Nodes
This could be an indication that IO virtualization (VT-d, or IOMMU) is interfering with the NCCL library. On multi-GPU systems, PCI point-to-point traffic can be redirected to the CPU by these mechanisms, causing performance degradation or hangs. These settings can typically be disabled from the BIOS, and sometimes from the OS. You can find more information in NVIDIA's NCCL troubleshooting guide here. Note that disabling IO virtualization can introduce security vulnerabilities, since peripherals gain read/write access to DRAM through DMA (Direct Memory Access).
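On Linux, a rough heuristic for whether an IOMMU is active is to look at the kernel's sysfs entries; this is only a hint, and the BIOS settings remain the authoritative place to check:
# Non-empty listings usually indicate an active IOMMU (VT-d)
import os
for path in ("/sys/class/iommu", "/sys/kernel/iommu_groups"):
    try:
        entries = os.listdir(path)
        print(path, "->", entries if entries else "empty")
    except FileNotFoundError:
        print(path, "-> not present")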