Last active
February 5, 2026 14:28
-
-
Save fblissjr/38c39bdf64a6c10447152845d3af4d53 to your computer and use it in GitHub Desktop.
rtx 4090 sm89 build script for various torch source compiles
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
#
# Build a TorchAO wheel from source, targeting the RTX 4090 (sm_89).
#
# Run from the root of the source checkout. The script:
#   1. asks which PyTorch channel to build against (stable or nightly, cu130),
#   2. deletes and recreates a throwaway virtualenv in ./.build_venv via uv,
#   3. installs the build dependencies and PyTorch into it,
#   4. compares PyTorch's CUDA version with the local nvcc and, on mismatch,
#      rewrites the mismatch check in torch's cpp_extension.py so that it
#      prints instead of raising,
#   5. builds the wheel into ./dist.
#
# Requires: bash, uv, nproc, awk, GNU sed (for `sed -i`); nvcc is used when
# it is on PATH.

set -euo pipefail

# Print a message to stderr and abort the script.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

# --- Configuration ---
export TORCH_CUDA_ARCH_LIST="8.9"
export USE_CPP=1
# Assign first, export second: `export VAR=$(cmd)` would hide a failing cmd.
MAX_JOBS=$(nproc)
export MAX_JOBS

# Force CUDA 13.0 path if it exists, otherwise use default
if [[ -d "/usr/local/cuda-13.0" ]]; then
  export CUDA_HOME="/usr/local/cuda-13.0"
  export PATH="$CUDA_HOME/bin:$PATH"
  echo "-> Forced CUDA_HOME to $CUDA_HOME"
fi

echo "=== TorchAO Builder for RTX 4090 (Fixed) ==="

# Fail early, before anything is deleted or downloaded.
command -v uv >/dev/null 2>&1 || die "uv is required but was not found on PATH."
[[ -f setup.py || -f pyproject.toml ]] \
  || die "No setup.py or pyproject.toml in $PWD; run this script from the source checkout root."

# --- 1. Prompts ---
echo "Select PyTorch Version:"
echo " 1) Stable (2.10 + cu130)"
echo " 2) Nightly (Preview + cu130)"
pt_choice=""
# `|| true`: on EOF (non-interactive stdin) fall through to the default
# instead of letting `set -e` abort the script.
read -r -p "Enter choice [1]: " pt_choice || true
pt_choice=${pt_choice:-1}

case "$pt_choice" in
  1)
    PT_INDEX_URL="https://download.pytorch.org/whl/cu130"
    pt_prerelease=0
    ;;
  2)
    PT_INDEX_URL="https://download.pytorch.org/whl/nightly/cu130"
    pt_prerelease=1
    ;;
  *)
    die "Invalid choice '$pt_choice' (expected 1 or 2)."
    ;;
esac

# --- 2. Clean Slate ---
# We MUST delete the old venv to fix the 'Audited 2 packages' issue
printf '\n\033[0;33m[!] Removing old build environment to force fresh install...\033[0m\n'
rm -rf -- .build_venv build dist torchao.egg-info
echo "Creating fresh venv..."
uv venv .build_venv
# The generated activate script is third-party code and is not guaranteed to
# be safe under `set -u`, so relax nounset only while sourcing it.
set +u
# shellcheck disable=SC1091
source .build_venv/bin/activate
set -u

# --- 3. Install Dependencies ---
echo "Installing Build Dependencies..."
uv pip install ninja cmake setuptools wheel packaging numpy

echo "Installing PyTorch from $PT_INDEX_URL..."
# Build the command as an array so the optional --pre flag needs no
# unquoted expansion.
torch_install=(uv pip install)
if (( pt_prerelease )); then
  torch_install+=(--pre)
fi
# uv gives --extra-index-url priority over --index-url, so the PyTorch wheel
# index must be the *extra* index; otherwise torch is resolved from PyPI and
# the cu130 index is never consulted for it.
# --no-cache-dir ensures we don't grab a cached cu128 wheel by mistake
torch_install+=(
  torch torchvision
  --extra-index-url "$PT_INDEX_URL"
  --index-url https://pypi.org/simple
  --no-cache-dir
)
"${torch_install[@]}"

# --- 4. VERIFY & PATCH ---
echo "Verifying PyTorch CUDA version..."
PT_CUDA=$(python -c "import torch; print(torch.version.cuda)")
if command -v nvcc >/dev/null 2>&1; then
  # Take the 5th field of the line containing "release" and strip everything
  # from the first comma on.
  SYS_CUDA=$(nvcc --version | awk '/release/ { sub(/,.*/, "", $5); print $5 }')
else
  printf '[WARNING] nvcc not found on PATH; system CUDA version is unknown.\n' >&2
  SYS_CUDA=""
fi
echo "System CUDA: $SYS_CUDA"
echo "PyTorch CUDA: $PT_CUDA"

if [[ "$PT_CUDA" != "$SYS_CUDA" ]]; then
  printf '\n\033[0;31m[WARNING] Version mismatch detected! (%s vs %s)\033[0m\n' \
    "$SYS_CUDA" "$PT_CUDA" >&2
  echo "Attempting to patch PyTorch to allow compilation..."
  # Locate cpp_extension.py
  CPP_EXT_PATH=$(python -c "import torch.utils.cpp_extension as c; print(c.__file__)")
  # Turn the error-raising line into a print so the build can continue.
  # This only touches the copy of torch inside the freshly created venv.
  if grep -q 'raise RuntimeError(CUDA_MISMATCH_MESSAGE' -- "$CPP_EXT_PATH"; then
    sed -i 's/raise RuntimeError(CUDA_MISMATCH_MESSAGE/print(" [PATCHED] Ignoring CUDA mismatch: " + CUDA_MISMATCH_MESSAGE/g' "$CPP_EXT_PATH"
    echo "Patch applied. Crossing fingers."
  else
    printf '[WARNING] Expected mismatch check not found in %s; nothing was patched.\n' \
      "$CPP_EXT_PATH" >&2
  fi
fi

# --- 5. Build ---
echo "Starting Optimized Compilation..."
uv pip install build
python -m build --wheel --no-isolation
echo "================================="
ls -lh dist/*.whl
echo "Done. Install with: uv pip install dist/*.whl"
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment