Support Moore Threads MUSA arch #3018

Open
wants to merge 26 commits into base: main
Commits (26)
21841f4
for musa by [email protected]
hanhaowen-mt Dec 8, 2023
e7b486a
.
hanhaowen-mt Dec 11, 2023
bf6cad5
.
hanhaowen-mt Dec 11, 2023
475ed1a
.
hanhaowen-mt Dec 12, 2023
049b472
add musa in some py file, ongoing, add MUSA_install.sh for install mm…
hanhaowen-mt Jan 4, 2024
7dd271c
still working for musa
hanhaowen-mt Jan 5, 2024
b8bc909
comment upfirdn2d_op since s3000's shared memory is too small
hanhaowen-mt Jan 5, 2024
df8d613
comment carafe_backward_musa for the same reason
hanhaowen-mt Jan 5, 2024
67cbed0
comment carafe_forward_musa for the same reason
hanhaowen-mt Jan 5, 2024
2a773d4
comment chamfer_distance_forward_musa for the same reason
hanhaowen-mt Jan 5, 2024
d9a8d23
continue to port to musa
hanhaowen-mt Jan 5, 2024
9f57468
merge to origin/main
hanhaowen-mt Jan 12, 2024
9fc0d43
Revert "comment chamfer_distance_forward_musa for the same reason"
hanhaowen-mt Jan 12, 2024
50bb086
support CONDITIONAL MACRO for chamfer distance
hanhaowen-mt Jan 12, 2024
dcbff7d
Revert "comment carafe_forward_musa for the same reason"
hanhaowen-mt Jan 12, 2024
060f83c
Revert "comment carafe_backward_musa for the same reason"
hanhaowen-mt Jan 12, 2024
ea008ab
support CONDITIONAL MACRO for carafe_backward_musa and carafe_forward…
hanhaowen-mt Jan 12, 2024
c1240bb
Revert "comment upfirdn2d_op since s3000's shared memory is too small"
hanhaowen-mt Jan 12, 2024
42f4424
support CONDITIONAL MACRO for upfirdn2d
hanhaowen-mt Jan 12, 2024
7088185
Update MUSA_ARCH macro
hanhaowen-mt Jan 12, 2024
b2953be
set musa_arch from 210 to 21 and auto set it, so we can install mmc…
hanhaowen-mt Jan 12, 2024
b5fda52
fix some bugs for get_indice_pairs_backward_musa
hanhaowen-mt Jan 12, 2024
3e944d4
fix some bugs in ut for musa
hanhaowen-mt Jan 15, 2024
768090f
support new musaExtension
hanhaowen-mt Jan 29, 2024
cfe91db
Merge branch 'main' into musa_mmcv_main
hanhaowen-mt Feb 6, 2024
ebdabe7
Merge branch 'main' into musa_mmcv_main
hanhaowen-mt Mar 13, 2024
5 changes: 5 additions & 0 deletions mmcv/__init__.py
@@ -7,6 +7,11 @@
from .video import *
from .visualization import *

try:
import torch
import torch_musa
except:
pass
# The following modules are not imported to this level, so mmcv may be used
# without PyTorch.
# - op
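The import block added above is how MUSA support is wired in at package import time: importing `torch_musa` is what registers the MUSA backend with PyTorch, and the surrounding `except` keeps `import mmcv` working on machines where the package is absent. A minimal availability-check sketch built on the same idea; the helper name `has_musa` is invented for this example, and `torch.musa.is_available()` is assumed to be provided by torch_musa:

import torch

def has_musa() -> bool:
    # Probe for the MUSA backend the same way mmcv/__init__.py does:
    # the import either registers the backend or fails cleanly.
    try:
        import torch_musa  # noqa: F401
    except ImportError:
        return False
    return hasattr(torch, 'musa') and torch.musa.is_available()

x = torch.randn(2, 3, device='musa' if has_musa() else 'cpu')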
205 changes: 205 additions & 0 deletions mmcv/ops/bias_act.py
@@ -114,6 +114,81 @@ def __delattr__(self, name: str) -> None:
has_2nd_grad=True),
}

activation_funcs_musa = {
'linear':
EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
musa_idx=1,
ref='',
has_2nd_grad=False),
'relu':
EasyDict(
func=lambda x, **_: torch.nn.functional.relu(x),
def_alpha=0,
def_gain=np.sqrt(2),
musa_idx=2,
ref='y',
has_2nd_grad=False),
'lrelu':
EasyDict(
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
def_alpha=0.2,
def_gain=np.sqrt(2),
musa_idx=3,
ref='y',
has_2nd_grad=False),
'tanh':
EasyDict(
func=lambda x, **_: torch.tanh(x),
def_alpha=0,
def_gain=1,
musa_idx=4,
ref='y',
has_2nd_grad=True),
'sigmoid':
EasyDict(
func=lambda x, **_: torch.sigmoid(x),
def_alpha=0,
def_gain=1,
musa_idx=5,
ref='y',
has_2nd_grad=True),
'elu':
EasyDict(
func=lambda x, **_: torch.nn.functional.elu(x),
def_alpha=0,
def_gain=1,
musa_idx=6,
ref='y',
has_2nd_grad=True),
'selu':
EasyDict(
func=lambda x, **_: torch.nn.functional.selu(x),
def_alpha=0,
def_gain=1,
musa_idx=7,
ref='y',
has_2nd_grad=True),
'softplus':
EasyDict(
func=lambda x, **_: torch.nn.functional.softplus(x),
def_alpha=0,
def_gain=1,
musa_idx=8,
ref='y',
has_2nd_grad=True),
'swish':
EasyDict(
func=lambda x, **_: torch.sigmoid(x) * x,
def_alpha=0,
def_gain=np.sqrt(2),
musa_idx=9,
ref='x',
has_2nd_grad=True),
}

_null_tensor = torch.empty([0])


@@ -167,6 +242,13 @@ def bias_act(input: torch.Tensor,
return _bias_act_cuda(
dim=dim, act=act, alpha=alpha, gain=gain,
clamp=clamp).apply(input, bias)
try:
if use_custom_op and input.is_musa:
return _bias_act_musa(
dim=dim, act=act, alpha=alpha, gain=gain,
clamp=clamp).apply(input, bias)
except AttributeError:
pass
return _bias_act_ref(
input=input,
bias=bias,
@@ -373,3 +455,126 @@ def backward(ctx, d_dx): # pylint: disable=arguments-differ
# Add to cache.
_bias_act_cuda_cache[key] = BiasActCuda
return BiasActCuda


_bias_act_musa_cache: Dict = dict()


def _bias_act_musa(dim: int = 1,
act: str = 'linear',
alpha: Optional[Union[float, int]] = None,
gain: Optional[float] = None,
clamp: Optional[float] = None):
"""Fast MUSA implementation of `bias_act()` using custom ops.

Args:
dim (int): The dimension in `x` corresponding to the elements of `b`.
The value of `dim` is ignored if `b` is not specified.
Defaults to 1.
act (str): Name of the activation function to evaluate, or `"linear"`
to disable. Can be e.g. "relu", "lrelu", "tanh", "sigmoid",
"swish", etc. See `activation_funcs_musa` for a full list. `None`
is not allowed. Defaults to `linear`.
alpha (float | int): Shape parameter for the activation
function, or `None` to use the default. Defaults to None.
gain (float): Scaling factor for the output tensor, or `None`
to use default. See `activation_funcs_musa` for the default scaling
of each activation function. If unsure, consider specifying 1.
Defaults to None.
clamp (float): Clamp the output values to `[-clamp, +clamp]`,
or `None` to disable the clamping (default). Defaults to None.

Returns:
torch.Tensor: Tensor of the same shape and datatype as `x`.
"""
# Parse arguments.
assert clamp is None or clamp >= 0
spec = activation_funcs_musa[act]
alpha = float(alpha if alpha is not None else spec.def_alpha)
gain = float(gain if gain is not None else spec.def_gain)
clamp = float(clamp if clamp is not None else -1)

# Lookup from cache.
key = (dim, act, alpha, gain, clamp)
if key in _bias_act_musa_cache:
return _bias_act_musa_cache[key]

# Forward op.
class BiasActMusa(torch.autograd.Function):

@staticmethod
def forward(ctx, x, b): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(
1) == 1 else torch.contiguous_format
x = x.contiguous(memory_format=ctx.memory_format)
b = b.contiguous() if b is not None else _null_tensor.to(x.device)
y = x
if act != 'linear' or gain != 1 or clamp >= 0 or (
b is not _null_tensor.to(x.device)):
y = ext_module.bias_act(x, b, _null_tensor.to(x.device),
_null_tensor.to(x.device),
_null_tensor.to(x.device), 0, dim,
spec.musa_idx, alpha, gain, clamp)
ctx.save_for_backward(
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor.to(
x.device), b if 'x' in spec.ref or spec.has_2nd_grad else
_null_tensor.to(x.device),
y if 'y' in spec.ref else _null_tensor.to(x.device))
return y

@staticmethod
def backward(ctx, dy): # pylint: disable=arguments-differ
dy = dy.contiguous(memory_format=ctx.memory_format)
x, b, y = ctx.saved_tensors
dx = None
db = None

if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
dx = dy
if act != 'linear' or gain != 1 or clamp >= 0:
dx = BiasActMusaGrad.apply(dy, x, b, y)

if ctx.needs_input_grad[1]:
db = dx.sum([i for i in range(dx.ndim) if i != dim])

return dx, db

# Backward op.
class BiasActMusaGrad(torch.autograd.Function):

@staticmethod
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
ctx.memory_format = torch.channels_last if dy.ndim > 2 and (
dy.stride(1) == 1) else torch.contiguous_format
dx = ext_module.bias_act(dy, b, x, y, _null_tensor.to(x.device), 1,
dim, spec.musa_idx, alpha, gain, clamp)
ctx.save_for_backward(
dy if spec.has_2nd_grad else _null_tensor.to(x.device), x, b,
y)
return dx

@staticmethod
def backward(ctx, d_dx): # pylint: disable=arguments-differ
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
dy, x, b, y = ctx.saved_tensors
d_dy = None
d_x = None
d_b = None
d_y = None

if ctx.needs_input_grad[0]:
d_dy = BiasActMusaGrad.apply(d_dx, x, b, y)

if spec.has_2nd_grad and (ctx.needs_input_grad[1]
or ctx.needs_input_grad[2]):
d_x = ext_module.bias_act(d_dx, b, x, y, dy, 2, dim,
spec.musa_idx, alpha, gain, clamp)

if spec.has_2nd_grad and ctx.needs_input_grad[2]:
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])

return d_dy, d_x, d_b, d_y

# Add to cache.
_bias_act_musa_cache[key] = BiasActMusa
return BiasActMusa
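With the `_bias_act_musa` cache in place, the dispatch added to `bias_act` (the `input.is_musa` branch, guarded by `AttributeError` for PyTorch builds without torch_musa) routes MUSA tensors to the compiled extension. A usage sketch under the assumption that torch_musa is installed and mmcv was built with its MUSA ops:

import torch
import torch_musa  # noqa: F401  # assumed installed; registers the 'musa' device
from mmcv.ops import bias_act

x = torch.randn(4, 16, 8, 8, device='musa')
b = torch.randn(16, device='musa')
# dim=1 matches the channel dimension of `b`; 'lrelu' maps to musa_idx=3 above.
y = bias_act(x, b, dim=1, act='lrelu')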
2 changes: 1 addition & 1 deletion mmcv/ops/carafe.py
@@ -65,7 +65,7 @@ def forward(ctx, features: Tensor, masks: Tensor, kernel_size: int,
def backward(
ctx,
grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]:
assert grad_output.is_cuda
assert grad_output.is_cuda or grad_output.is_musa

features, masks = ctx.saved_tensors
kernel_size = ctx.kernel_size
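The carafe change is only the device assertion: with torch_musa installed, tensors expose an `is_musa` flag analogous to `is_cuda`, so the backward pass now accepts either device. A small guard written in the same defensive style as the bias_act dispatch above; this is a sketch, not part of the PR, and `on_supported_device` is a name chosen for illustration:

import torch

def on_supported_device(t: torch.Tensor) -> bool:
    # True for CUDA tensors, and for MUSA tensors when torch_musa is present.
    if t.is_cuda:
        return True
    try:
        return t.is_musa
    except AttributeError:  # plain PyTorch build without torch_musa
        return False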
7 changes: 5 additions & 2 deletions mmcv/ops/conv2d_gradfix.py
@@ -15,6 +15,7 @@
from typing import Dict, Optional, Tuple, Union

import torch
from mmengine.device import is_cuda_available, is_musa_available
from mmengine.utils import digit_version
from mmengine.utils.dl_utils.parrots_wrapper import is_rocm_pytorch

@@ -95,6 +96,8 @@ def conv_transpose2d(input: torch.Tensor,

def _should_use_custom_op(input):
assert isinstance(input, torch.Tensor)
if enabled and is_musa_available():
return True
if (not enabled) or (not torch.backends.cudnn.enabled):
return False
if input.device.type != 'cuda':
@@ -177,8 +180,8 @@ def forward(ctx, input, weight, bias):
ctx.input_shape = input.shape

# Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere).
if weight_shape[2:] == stride == dilation == (
1, 1) and padding == (
if is_cuda_available() and weight_shape[
2:] == stride == dilation == (1, 1) and padding == (
0, 0) and torch.cuda.get_device_capability(
input.device) < (8, 0):
a = weight.reshape(groups, weight_shape[0] // groups,
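conv2d_gradfix now leans on mmengine's device helpers instead of CUDA-only checks: `is_musa_available()` turns on the custom op path for MUSA even though the cuDNN flag is irrelevant there, and `is_cuda_available()` fences the cuBLAS 1x1 shortcut so it is skipped on MUSA. A small device-selection sketch using those same helpers (they are the ones imported from mmengine.device in the diff above; `pick_device` is a name invented for this example):

from mmengine.device import is_cuda_available, is_musa_available

def pick_device() -> str:
    # Prefer MUSA when present, then CUDA, otherwise fall back to CPU.
    if is_musa_available():
        return 'musa'
    if is_cuda_available():
        return 'cuda'
    return 'cpu'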
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/common/box_iou_rotated_utils.hpp
@@ -5,7 +5,7 @@
#include <cassert>
#include <cmath>

#ifdef __CUDACC__
#if defined(__CUDACC__) || defined(__MUSACC__)
// Designates functions callable from the host (CPU) and the device (GPU)
#define HOST_DEVICE __host__ __device__
#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
@@ -191,7 +191,7 @@ HOST_DEVICE_INLINE int convex_hull_graham(const Point<T> (&p)[24],
dist[i] = dot_2d<T>(q[i], q[i]);
}

#ifdef __CUDACC__
#if defined(__CUDACC__) || defined(__MUSACC__)
// CUDA version
// In the future, we can potentially use thrust
// for sorting here to improve speed (though not guaranteed)
56 changes: 56 additions & 0 deletions mmcv/ops/csrc/common/musa/active_rotated_filter_musa_kernel.muh
@@ -0,0 +1,56 @@
// Copyright (c) OpenMMLab. All rights reserved.
// Modified from
// https://github.com/csuhan/s2anet/blob/master/mmdet/ops/orn/src/cuda/ActiveRotatingFilter_cuda.cu
#ifndef ACTIVE_ROTATED_FILTER_MUSA_KERNEL_MUH
#define ACTIVE_ROTATED_FILTER_MUSA_KERNEL_MUH


#include "pytorch_musa_helper.hpp"

template <typename scalar_t>
__global__ void active_rotated_filter_forward_musa_kernel(
const int nthreads, const scalar_t* weight_data, const int* indices_data,
const int num_input_planes, const int num_output_planes,
const int num_orientations, const int num_rotations, const int nEntry,
scalar_t* output_data) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
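    // Decompose the flat index: l = entry within one filter's orientation block,
    // j = input plane, i = output plane. The weight value is then copied into
    // each of the num_rotations rotated filters at the position given by the
    // (1-based) indices map below.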
int l = index % nEntry;
int j = (index / nEntry) % num_input_planes;
int i = index / nEntry / num_input_planes;
int k;
scalar_t val = *(weight_data + index);
for (k = 0; k < num_rotations; k++) {
int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
scalar_t* target = output_data +
i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + idx;
*target = val;
}
}
}

template <typename scalar_t>
__global__ void active_rotated_filter_backward_musa_kernel(
const int nthreads, const scalar_t* gradWeight_data,
const int* indices_data, const int num_input_planes,
const int num_output_planes, const int num_orientations,
const int num_rotations, const int nEntry, scalar_t* weight_data) {
MUSA_1D_KERNEL_LOOP(index, nthreads) {
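    // Same index decomposition as the forward kernel; here the gradient
    // contributions from all num_rotations rotated copies are summed back
    // into the original weight entry.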
int l = index % nEntry;
int j = (index / nEntry) % num_input_planes;
int i = index / nEntry / num_input_planes;
int k;
scalar_t* val = weight_data + index;
*val = 0;
scalar_t tmp = 0;
for (k = 0; k < num_rotations; k++) {
int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
scalar_t target =
*(gradWeight_data + i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + idx);
tmp = tmp + target;
}
*val = tmp;
}
}
#endif // ACTIVE_ROTATED_FILTER_MUSA_KERNEL_MUH