forked from apple/axlearn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Dockerfile
117 lines (88 loc) · 4.42 KB
/
Dockerfile
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# syntax=docker/dockerfile:1
# Name of the stage that the `final` stage is built from (default: `base`).
ARG TARGET=base
ARG BASE_IMAGE=python:3.10-slim
FROM ${BASE_IMAGE} AS base
# Install system packages in ONE layer: `apt-get update` always runs in the same layer as the
# `apt-get install` calls, so a cached (stale) package index is never paired with a rebuilt
# install step. The apt package lists are deliberately kept so that stages built on top of
# `base` can install additional packages.
#  - apt-transport-https, ca-certificates, gnupg, curl: used below to add the gcloud apt repo.
#  - gcc, g++, git.
#  - google-cloud-cli: gcloud. https://cloud.google.com/sdk/docs/install
#  - jq, screen: utils for the launch script.
# The signing key is downloaded to a file with `curl -f` (instead of piped into gpg) so an HTTP
# or network error fails the build regardless of whether the shell supports `pipefail`.
RUN apt-get update && \
    apt-get install -y apt-transport-https ca-certificates gnupg curl gcc g++ git && \
    echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
    curl -fsSL -o /tmp/cloud.google.gpg https://packages.cloud.google.com/apt/doc/apt-key.gpg && \
    gpg --dearmor -o /usr/share/keyrings/cloud.google.gpg /tmp/cloud.google.gpg && \
    rm -f /tmp/cloud.google.gpg && \
    apt-get update -y && \
    apt-get install -y google-cloud-cli jq screen
# Setup. WORKDIR creates the directory if it does not already exist.
WORKDIR /root
# Introduce the minimum set of files for install.
COPY README.md pyproject.toml ./
RUN mkdir axlearn && touch axlearn/__init__.py
# Setup venv to suppress pip warnings.
ENV VIRTUAL_ENV=/opt/venv
RUN python -m venv $VIRTUAL_ENV
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# Install dependencies. Upgrade pip first so that flit is installed by the upgraded pip.
RUN pip install --upgrade pip && pip install flit
################################################################################
# CI container spec.                                                           #
################################################################################
# Leverage multi-stage build for unit tests.
FROM base AS ci
# TODO(markblee): Remove gcp,vertexai_tensorboard from CI.
# The requirement is quoted so the shell does not treat `[...]` as a glob bracket expression.
RUN pip install ".[core,dev,grain,gcp,vertexai_tensorboard]"
COPY . .
# Defaults to an empty string, i.e. run pytest against all files.
ARG PYTEST_FILES=''
# Defaults to an empty string, i.e. do NOT skip precommit.
ARG SKIP_PRECOMMIT=''
# A non-zero exit status from run_tests.sh (e.g. `exit 1`) fails the build.
# SKIP_PRECOMMIT is left unquoted so that an empty value passes no argument at all.
RUN ./run_tests.sh $SKIP_PRECOMMIT "${PYTEST_FILES}"
################################################################################
# Bastion container spec.                                                      #
################################################################################
FROM base AS bastion
# TODO(markblee): Consider copying large directories separately, to cache more aggressively.
# TODO(markblee): Is there a way to skip the "production" deps?
# The full source is copied BEFORE `pip install`, so the complete package is installed rather
# than only the placeholder `axlearn/__init__.py` created in `base`.
COPY . /root/
# The requirement is quoted so the shell does not treat `[...]` as a glob bracket expression.
RUN pip install ".[core,gcp,vertexai_tensorboard]"
################################################################################
# Dataflow container spec.                                                     #
################################################################################
FROM base AS dataflow
# Beam workers default to creating a new virtual environment on startup. Instead, we want them to
# pick up the venv set up in `base`. An alternative is to install into the global environment.
ENV RUN_PYTHON_SDK_IN_DEFAULT_ENVIRONMENT=1
# The requirement is quoted so the shell does not treat `[...]` as a glob bracket expression.
RUN pip install ".[core,gcp,dataflow]"
COPY . .
# Dataflow workers can't start properly if the entrypoint is not set.
# See: https://cloud.google.com/dataflow/docs/guides/build-container-image#use_a_custom_base_image
# NOTE(review): the SDK image's Python version (3.10) presumably has to match BASE_IMAGE's Python
# version, and its Beam version (2.52.0) the one installed via the `dataflow` extra — keep in sync.
COPY --from=apache/beam_python3.10_sdk:2.52.0 /opt/apache/beam /opt/apache/beam
ENTRYPOINT ["/opt/apache/beam/boot"]
################################################################################
# TPU container spec.                                                          #
################################################################################
FROM base AS tpu
# Refresh the package index in the same layer as the install; the index inherited from `base`
# may come from a stale cached layer. Lists are removed afterwards to keep this layer small.
RUN apt-get update && apt-get install -y google-perftools && rm -rf /var/lib/apt/lists/*
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Ensure we install the TPU version, even if building locally.
# Jax will fallback to CPU when run on a machine without TPU.
RUN pip install ".[core,tpu]"
# Optional comma-separated pip extras, e.g. `--build-arg EXTRAS=gcp`.
# Declared here, right before its only use: every RUN after an ARG declaration receives it as an
# environment variable, so declaring it earlier would let a changed EXTRAS value invalidate the
# cached apt and core pip layers above.
ARG EXTRAS=
RUN if [ -n "$EXTRAS" ]; then pip install ".[$EXTRAS]"; fi
COPY . .
################################################################################
# GPU container spec.                                                          #
################################################################################
FROM base AS gpu
# Refresh the package index in the same layer as the install; the index inherited from `base`
# may come from a stale cached layer. Lists are removed afterwards to keep this layer small.
RUN apt-get update && apt-get install -y google-perftools && rm -rf /var/lib/apt/lists/*
ENV PIP_FIND_LINKS=https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# The requirement is quoted so the shell does not treat `[...]` as a glob bracket expression.
RUN pip install ".[core,gpu]"
# Optional comma-separated pip extras, e.g. `--build-arg EXTRAS=gcp` (no-op when empty).
# Declared right before its only use so a changed value does not invalidate the layers above.
ARG EXTRAS=
RUN if [ -n "$EXTRAS" ]; then pip install ".[$EXTRAS]"; fi
COPY . .
################################################################################
# Final target spec.                                                           #
################################################################################
# The global build arg TARGET (declared before the first FROM, default `base`) picks which of
# the stages above is published, e.g. `docker build --build-arg TARGET=tpu .`.
FROM $TARGET AS final