feat(docker): enhance CUDA support in Dockerfile and pyproject.toml

- Updated Dockerfile to conditionally install PyTorch with CPU or CUDA support based on build arguments.
- Added optional dependencies for CPU and CUDA versions of PyTorch in pyproject.toml.
- Configured uv.lock to manage dependencies and conflicts between CPU and CUDA installations.
This commit is contained in:
Anish Sarkar 2026-06-05 21:46:09 +05:30
parent da8cb32e77
commit 6972356c86
3 changed files with 731 additions and 155 deletions

View file

@ -61,15 +61,25 @@ COPY pyproject.toml uv.lock ./
# Exporting the lock to requirements.txt and feeding it to `uv pip install`
# pins every transitive package to the exact version captured in uv.lock.
#
# Note on torch/CUDA: we do NOT install torch from a separate cu* index here.
# PyPI's torch wheels for Linux x86_64 already ship CUDA-enabled and pull
# nvidia-cudnn-cu13, nvidia-nccl-cu13, triton, etc. as install deps (all
# captured in uv.lock). If a specific CUDA version is needed, wire it through
# [tool.uv.sources] in pyproject.toml so the lock stays the source of truth.
# Note on torch/CUDA: the export must always select either the cpu or CUDA
# extra declared in pyproject.toml. A no-extra export would resolve torch from
# PyPI on Linux, which currently pulls CUDA-enabled wheels and nvidia-* deps.
# Keep CUDA version selection in [tool.uv.sources] so uv.lock remains the
# source of truth. The install step also needs the matching PyTorch index,
# because requirements.txt preserves the +cpu/+cu wheel pins but not uv's
# package source metadata.
ARG USE_CUDA=false
ARG CUDA_EXTRA=cu128
RUN pip install --no-cache-dir uv && \
if [ "$USE_CUDA" = "true" ]; then EXTRA="$CUDA_EXTRA"; else EXTRA="cpu"; fi && \
TORCH_INDEX="https://download.pytorch.org/whl/${EXTRA}" && \
uv export --frozen --no-dev --no-hashes --no-emit-project \
--extra "$EXTRA" \
--format requirements-txt -o /tmp/requirements.txt && \
uv pip install --system --no-cache-dir -r /tmp/requirements.txt && \
uv pip install --system --no-cache-dir \
--index "$TORCH_INDEX" \
--index-strategy unsafe-best-match \
-r /tmp/requirements.txt && \
rm /tmp/requirements.txt

View file

@ -92,6 +92,10 @@ dependencies = [
"croniter>=2.0.0",
]
[project.optional-dependencies]
cpu = ["torch==2.11.0", "torchvision==0.26.0"]
cu128 = ["torch==2.11.0", "torchvision==0.26.0"]
[dependency-groups]
dev = [
"ruff>=0.12.5",
@ -101,6 +105,29 @@ dev = [
"httpx>=0.28.1",
]
[tool.uv]
conflicts = [[{ extra = "cpu" }, { extra = "cu128" }]]
[tool.uv.sources]
torch = [
{ index = "pytorch-cpu", extra = "cpu", marker = "sys_platform == 'linux'" },
{ index = "pytorch-cu128", extra = "cu128", marker = "sys_platform == 'linux'" },
]
torchvision = [
{ index = "pytorch-cpu", extra = "cpu", marker = "sys_platform == 'linux'" },
{ index = "pytorch-cu128", extra = "cu128", marker = "sys_platform == 'linux'" },
]
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true
[tool.ruff]
# Exclude a variety of commonly ignored directories.
exclude = [

File diff suppressed because it is too large Load diff