File size: 1,703 Bytes
14ebd2e 2734e3f 14ebd2e 2734e3f 14ebd2e 92d800a 14ebd2e 2734e3f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
# Global build arguments. They are declared before the first FROM so they can
# be used to select the base image tag; they are not visible inside a stage
# unless redeclared there.
ARG CUDA_VERSION="11.8.0"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"

# ---------------------------------------------------------------------------
# Stage: base-builder
# CUDA + cuDNN development image with Miniconda, a dedicated conda env for the
# requested Python version, and PyTorch installed from the matching CUDA wheel
# index. Every later stage derives from this one.
# ---------------------------------------------------------------------------
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder

# Put the Miniconda binaries (installed below) on PATH.
ENV PATH="/root/miniconda3/bin:${PATH}"

ARG PYTHON_VERSION="3.9"
ARG PYTORCH="2.0.0"
ARG CUDA="cu118"

# Persist the Python version into the image environment so it is still
# available after the build (ARG values are build-time only).
ENV PYTHON_VERSION=$PYTHON_VERSION

# `apt-get update` and `apt-get install` must share one layer: a separately
# cached `update` layer would otherwise be reused with stale package indexes
# whenever the install line changes.
RUN apt-get update \
    && apt-get install -y \
        build-essential \
        git \
        git-lfs \
        ninja-build \
        wget \
    && rm -rf /var/lib/apt/lists/*

# Download Miniconda, install it in batch mode (-b), and delete the installer
# in the same layer so it does not remain in the image.
RUN wget \
    https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh

# -y answers the confirmation prompt explicitly; a Docker build has no
# interactive stdin to answer it.
RUN conda create -y -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"

# Make the new env's interpreter and pip take precedence over the conda base.
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

# Install the pinned torch build; $CUDA selects the PyTorch wheel index
# (e.g. cu118) used as an extra package index.
RUN python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA
# ---------------------------------------------------------------------------
# Stage: flash-attn-builder
# Clones flash-attention and builds a wheel into
# /workspace/flash-attention/dist for the final stage to copy.
# ---------------------------------------------------------------------------
FROM base-builder AS flash-attn-builder

WORKDIR /workspace
RUN git clone https://github.com/HazyResearch/flash-attention.git

# Use WORKDIR instead of `cd` inside RUN so the build directory is explicit.
WORKDIR /workspace/flash-attention

# Declared right before the build step so that changing the architecture list
# does not invalidate the cached clone layer. The ARG is exposed to the RUN
# below as an environment variable.
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
RUN python3 setup.py bdist_wheel
# ---------------------------------------------------------------------------
# Stage: deepspeed-builder
# Clones DeepSpeed and builds a wheel into /workspace/DeepSpeed/dist for the
# final stage to copy.
# ---------------------------------------------------------------------------
FROM base-builder AS deepspeed-builder

WORKDIR /workspace
RUN git clone https://github.com/microsoft/DeepSpeed.git

# Use WORKDIR instead of `cd` inside RUN so the build directory is explicit.
WORKDIR /workspace/DeepSpeed

# The build options are passed inline so they apply to this command only and
# are not persisted in the image environment.
# NOTE(review): DS_BUILD_OPS=1 / DS_BUILD_SPARSE_ATTN=0 presumably pre-build
# DeepSpeed's ops while skipping sparse attention - confirm against the
# DeepSpeed installation docs.
RUN MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel
# ---------------------------------------------------------------------------
# Final stage
# Starts from base-builder (working directory /workspace) and installs the
# wheels produced by the two builder stages, leaving their source trees and
# build artefacts behind.
# ---------------------------------------------------------------------------
FROM base-builder

RUN mkdir /workspace/wheels

# Wildcard sources require a directory destination, written with a trailing
# slash. Absolute paths are used so the copies do not depend on the inherited
# working directory.
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl /workspace/wheels/
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl /workspace/wheels/

# `install` is the pip subcommand (the previous `install.sh` is not a pip
# command and made this step fail). --no-cache-dir matches the pip invocation
# in the base stage and keeps the download cache out of the layer.
RUN pip3 install --no-cache-dir /workspace/wheels/deepspeed-*.whl /workspace/wheels/flash_attn-*.whl
|