[feature] Add Support for NVIDIA H100 GPU (Compute Capability 9.0) #5160

Open
wants to merge 2 commits into
base: main
9 changes: 6 additions & 3 deletions op_builder/utils.py
@@ -168,6 +168,7 @@ def set_cuda_arch_list(cuda_dir):
             "2. Volta (compute capability 7.0)\n"
             "3. Turing (compute capability 7.5),\n"
             "4. Ampere (compute capability 8.0, 8.6)if the CUDA version is >= 11.0\n"
+            "5. Hopper (compute capability 9.0) if the CUDA version is >= 11.5\n"
             "\nIf you wish to cross-compile for a single specific architecture,\n"
             'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n'
         )
@@ -177,12 +178,14 @@ def set_cuda_arch_list(cuda_dir):

             arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]

-            if int(bare_metal_major) == 11:
+            if int(bare_metal_major) >= 11:
                 if int(bare_metal_minor) == 0:
                     arch_list.append("8.0")
                 else:
-                    arch_list.append("8.0")
-                    arch_list.append("8.6")
+                    arch_list.extend(["8.0", "8.6"])
+            if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 5 or int(bare_metal_major) >= 12:
+                arch_list.append("9.0")
+

             arch_list_str = ";".join(arch_list)
             os.environ["TORCH_CUDA_ARCH_LIST"] = arch_list_str
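
The version checks in this hunk compare the major and minor numbers separately, which makes the branches easy to get subtly wrong. As written, a CUDA 12.0 toolkit takes the `bare_metal_minor == 0` branch and so receives 8.0 but not 8.6. A minimal sketch of the same selection using tuple comparison is shown below. The helper name `build_arch_list` and the table `_MIN_CUDA_FOR_ARCH` are hypothetical and not part of this PR. The 11.8 floor for compute capability 9.0 is an assumption based on NVIDIA's CUDA 11.8 release notes, which introduced Hopper support, and it differs from the 11.5 threshold used in the diff.

```python
from typing import List, Tuple

# Hypothetical table: minimum CUDA toolkit (major, minor) assumed to be
# required before each extra architecture is added to the list.
# The (11, 8) floor for "9.0" is an assumption; the diff above uses 11.5.
_MIN_CUDA_FOR_ARCH: List[Tuple[Tuple[int, int], str]] = [
    ((11, 0), "8.0"),  # Ampere
    ((11, 1), "8.6"),  # Ampere (the original code adds 8.6 when minor != 0)
    ((11, 8), "9.0"),  # Hopper
]


def build_arch_list(major: int, minor: int) -> List[str]:
    """Return the cross-compile architecture list for a CUDA toolkit version."""
    # Pascal, Volta and Turing are always included, as in the original code.
    arch_list = ["6.0", "6.1", "6.2", "7.0", "7.5"]
    version = (major, minor)
    for min_version, arch in _MIN_CUDA_FOR_ARCH:
        # Tuples compare element by element, so (12, 0) >= (11, 8) holds.
        if version >= min_version:
            arch_list.append(arch)
    return arch_list


if __name__ == "__main__":
    for cuda_version in [(10, 2), (11, 0), (11, 5), (11, 8), (12, 0)]:
        print(cuda_version, ";".join(build_arch_list(*cuda_version)))
```

Keeping the thresholds in one table means a new architecture needs only one new row, and no `and`/`or` precedence has to be reasoned about.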