fix softmax focal loss algorithm #2893

Open · wants to merge 9 commits into base: main · Changes from 1 commit

139 changes: 101 additions & 38 deletions mmcv/ops/csrc/common/cuda/softmax_focal_loss_cuda_kernel.cuh
@@ -10,62 +10,125 @@
 template <typename T>
 __global__ void softmax_focal_loss_forward_cuda_kernel(
-    const int nthreads, const T* softmax, const int64_t* target,
-    const T* weight, T* output, const T gamma, const T alpha,
-    const int num_classes) {
+    const int nthreads, const T* input, const int64_t* target,
+    const T* weight, T* output, T* log_softmax_prob,
+    const T gamma, const T alpha, const int num_classes) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int64_t label = target[index];
-    T pred = softmax[index * num_classes + label];
+    const int n = index / num_classes;
+    const int c = index % num_classes;
+
+    // log_softmax for numerical stability
+    const int start = n * num_classes;
+    const int end = start + num_classes;
+
+    const T max_val_along_class = [&] {
+      T max_val_along_class = -FLT_MAX;
+      for (int c_idx = start; c_idx < end; ++c_idx) {
+        max_val_along_class = max(max_val_along_class, input[c_idx]);
+      }
+      return max_val_along_class;
+    }();
+
+    const T expsum = [&] {
+      T expsum = 0;
+      for (int c_idx = start; c_idx < end; ++c_idx) {
+        expsum += exp(input[c_idx] - max_val_along_class);
+      }
+      return expsum;
+    }();
+
+    const T log_pred = input[index] - max_val_along_class - log(expsum);
+    log_softmax_prob[index] = log_pred;
 
-    if (label >= 0) {
-      output[index] =
-          -alpha * pow((T)1. - pred, gamma) * log(max(pred, (T)FLT_MIN));
+    // focal loss
+    // FL(p) = -alpha * (1 - p)^gamma * log(p)  if curr_class == label
+    const int64_t label = target[n];
+    if (c == label) {
+      const T w = (weight != NULL) ? weight[label] : 1;
+      const T alpha_fac =
+          ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;
+
+      const T pred = exp(log_pred);
+
+      output[index] = -alpha_fac * pow(1 - pred, gamma) * log_pred;
     } else {
       output[index] = 0;
     }
-    if (weight != NULL) {
-      output[index] *= weight[label];
-    }
   }
 }
 
 template <typename T>
-__global__ void softmax_focal_loss_backward_cuda1_kernel(
-    const int nthreads, const T* softmax, const int64_t* target,
-    const T* weight, T* buff, const T gamma, const T alpha,
+__global__ void softmax_focal_loss_backward_cuda_kernel(
+    const int nthreads, const T* log_softmax_prob, const int64_t* target,
+    const T* weight, T* grad_input, const T gamma, const T alpha,
     const int num_classes) {
+  // forward:  x ----> p ----> FL
+  //               SM      FL
+  // backward: x <---- p <---- FL
+  //           j       i       FL
+  // For simplicity, the alpha of FL is ignored here.
+  // dFL/dp = - [(1-p)^gamma / p
+  //             - gamma * (1-p)^(gamma-1) * log(p)]
+  // dp_i/dx_j = dSM/dx_j
+  //           = p_i * (1-p_j)  i==j;
+  //             p_i * (0-p_j)  i!=j;
+  //           = p_i * (delta - p_j), where delta is the Kronecker delta
+  //
+  // Replacing the p of dFL/dp with p_i, then
+  // dFL/dx_j = dFL/dp_i * dp_i/dx_j
+  //          = - (delta - p_j) * [(1-p_i)^gamma
+  //                - gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i]
+  //          = (delta - p_j) * [- (1-p_i)^gamma +
+  //                gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i]
+  //
+  // Let B_i denote [- (1-p_i)^gamma +
+  //                 gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i],
+  // and sum the indices {i} over all classes at index j,
+  // since x_j receives gradients from every p_i.
+  // Then, dFL/dx_j = sum_i{ (delta - p_j) * B_i }
+  //                = sum_i{ delta*B_i - p_j*B_i }
+  //                = B_j - (p_j * sum_i{B_i})
+
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int64_t label = target[index];
-    T pred = softmax[index * num_classes + label];
+    // B_i
+    const int n = index / num_classes;
+    const int c = index % num_classes;
 
-    if (label >= 0) {
-      buff[index] = alpha * (-pow((T)1. - pred, gamma) +
-                             gamma * pow((T)1. - pred, gamma - 1) * pred *
-                                 log(max(pred, (T)FLT_MIN)));
+    const int64_t label = target[n];
+    if (c == label) {
+      const T w = (weight != NULL) ? weight[label] : 1;
+      const T alpha_fac =
+          ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;
+
+      const T log_pred = log_softmax_prob[index];
+      const T pred = exp(log_pred);
+      const T one_minus_pred = 1 - pred;
+
+      grad_input[index] =
+          alpha_fac * (-pow(one_minus_pred, gamma) +
+                       gamma * pow(one_minus_pred, gamma - 1) * log_pred * pred);
     } else {
-      buff[index] = 0;
-    }
-    if (weight != NULL) {
-      buff[index] *= weight[label];
+      grad_input[index] = 0;
     }
   }
-}
 
-template <typename T>
-__global__ void softmax_focal_loss_backward_cuda2_kernel(
-    const int nthreads, const T* softmax, const int64_t* target, const T* buff,
-    T* grad_input, const int num_classes) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index / num_classes;
-    int c = index % num_classes;
-    int64_t label = target[n];
+    // dFL/dx_j
+    const int n = index / num_classes;
 
-    if (label >= 0) {
-      T flag = (label == c ? (T)1. : (T)0.);
-      grad_input[index] = buff[n] * (flag - softmax[index]);
-    } else {
-      grad_input[index] = 0;
-    }
+    const int start = n * num_classes;
+    const int end = start + num_classes;
+
+    const T sum_buff_along_class = [&] {
+      T sum_buff_along_class = 0;
+      for (int c_idx = start; c_idx < end; ++c_idx) {
+        sum_buff_along_class += grad_input[c_idx];
+      }
+      return sum_buff_along_class;
+    }();
+
+    const T pred = exp(log_softmax_prob[index]);
+    grad_input[index] -= pred * sum_buff_along_class;
   }
 }
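
For reference, the per-element loss that the new forward kernel writes into `output` can be reproduced in plain PyTorch. The sketch below is for cross-checking only and is not part of the PR; `F.log_softmax` plays the role of the kernel's max-shifted log-softmax, and `alpha_fac` mirrors the kernel's `((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w` weighting.

```python
import torch
import torch.nn.functional as F


def softmax_focal_loss_forward_ref(input, target, weight=None,
                                   gamma=2.0, alpha=0.25):
    """Pure-PyTorch reference of the per-element loss the CUDA kernel writes.

    input: (N, C) raw logits; target: (N,) int64 labels in [0, C).
    Returns an (N, C) tensor that is nonzero only at the target class,
    matching the layout of the kernel's `output` buffer before reduction.
    """
    log_prob = F.log_softmax(input, dim=1)  # numerically stable log p
    prob = log_prob.exp()
    idx = torch.arange(input.size(0), device=input.device)

    log_p_t = log_prob[idx, target]  # log p at the target class
    p_t = prob[idx, target]

    w = weight[target] if weight is not None else torch.ones_like(p_t)
    # (1 - alpha) for class 0, alpha for classes >= 1, as in the kernel.
    alpha_fac = ((target == 0).to(p_t.dtype) * (1 - alpha) +
                 (target >= 1).to(p_t.dtype) * alpha) * w

    loss = -alpha_fac * (1 - p_t).pow(gamma) * log_p_t

    out = torch.zeros_like(log_prob)
    out[idx, target] = loss
    return out
```

Summing this tensor and dividing by `N` (or just summing it) reproduces the `'mean'` and `'sum'` reductions applied on the Python side.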

55 changes: 20 additions & 35 deletions mmcv/ops/csrc/pytorch/cuda/focal_loss_cuda.cu
@@ -47,64 +47,49 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
+void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor input, Tensor target,
                                                Tensor weight, Tensor output,
+                                               Tensor log_softmax_prob,
                                                const float gamma,
                                                const float alpha) {
   int output_size = output.numel();
-  int num_classes = softmax.size(1);
+  int num_classes = input.size(1);
 
   AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
              "target label should smaller or equal than num classes");
-  at::cuda::CUDAGuard device_guard(softmax.device());
+  at::cuda::CUDAGuard device_guard(input.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
+      input.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
         softmax_focal_loss_forward_cuda_kernel<scalar_t>
             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+                output_size,
+                input.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+                weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                log_softmax_prob.data_ptr<scalar_t>(),
+                gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                                Tensor weight, Tensor buff,
-                                                Tensor grad_input,
+void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor log_softmax_prob, Tensor target,
+                                                Tensor weight, Tensor grad_input,
                                                 const float gamma,
                                                 const float alpha) {
-  int num_classes = softmax.size(1);
+  int output_size = grad_input.numel();
+  int num_classes = log_softmax_prob.size(1);
 
-  int output_size = buff.numel();
   at::cuda::CUDAGuard device_guard(grad_input.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(),
-      "softmax_focal_loss_backward_cuda1_"
-      "kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda1_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
-      });
-
-  AT_CUDA_CHECK(cudaGetLastError());
-
-  output_size = grad_input.numel();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(),
-      "softmax_focal_loss_backward_cuda2_"
-      "kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda2_kernel<scalar_t>
+      log_softmax_prob.scalar_type(), "softmax_focal_loss_backward_cuda_kernel", [&] {
+        softmax_focal_loss_backward_cuda_kernel<scalar_t>
             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
-                grad_input.data_ptr<scalar_t>(), num_classes);
+                output_size,
+                log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+                weight.data_ptr<scalar_t>(), grad_input.data_ptr<scalar_t>(),
+                gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
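
The two-pass `buff`-based launch is gone: the backward now runs as a single kernel launch that first writes `B_i` into `grad_input` and then applies `dFL/dx_j = B_j - p_j * sum_i{B_i}`. The same math can be checked against autograd with a small sketch (again hypothetical, reusing `softmax_focal_loss_forward_ref` from above; not part of the PR):

```python
import torch
import torch.nn.functional as F


def softmax_focal_loss_backward_ref(input, target, weight=None,
                                    gamma=2.0, alpha=0.25):
    """Analytic d(sum of per-sample FL)/d(input), following the
    B_j - p_j * sum_i{B_i} derivation in the kernel comment."""
    log_prob = F.log_softmax(input, dim=1)
    prob = log_prob.exp()
    idx = torch.arange(input.size(0), device=input.device)

    p_t = prob[idx, target]
    log_p_t = log_prob[idx, target]
    w = weight[target] if weight is not None else torch.ones_like(p_t)
    alpha_fac = ((target == 0).to(p_t.dtype) * (1 - alpha) +
                 (target >= 1).to(p_t.dtype) * alpha) * w

    # B is nonzero only at the target class: the loss only touches p_t.
    b_t = alpha_fac * (-(1 - p_t).pow(gamma) +
                       gamma * (1 - p_t).pow(gamma - 1) * log_p_t * p_t)
    B = torch.zeros_like(prob)
    B[idx, target] = b_t

    # dFL/dx_j = B_j - p_j * sum_i{B_i}
    return B - prob * B.sum(dim=1, keepdim=True)


# Quick consistency check against autograd on the reference forward.
x = torch.randn(4, 5, dtype=torch.double, requires_grad=True)
t = torch.randint(0, 5, (4,))
softmax_focal_loss_forward_ref(x, t).sum().backward()
assert torch.allclose(x.grad, softmax_focal_loss_backward_ref(x.detach(), t),
                      atol=1e-6)
```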
21 changes: 12 additions & 9 deletions mmcv/ops/focal_loss.py
@@ -139,40 +139,43 @@ def forward(ctx,
         channel_stats = input_softmax.sum(dim=1)
         input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
 
-        output = input.new_zeros(input.size(0))
+        log_softmax_prob = input.new_zeros(input.size())
+        output = input.new_zeros(input.size())
 
         ext_module.softmax_focal_loss_forward(
-            input_softmax,
+            input,
             target,
             weight,
             output,
+            log_softmax_prob,
             gamma=ctx.gamma,
             alpha=ctx.alpha)
 
         if ctx.reduction == ctx.reduction_dict['mean']:
             output = output.sum() / input.size(0)
         elif ctx.reduction == ctx.reduction_dict['sum']:
             output = output.sum()
-        ctx.save_for_backward(input_softmax, target, weight)
+        ctx.save_for_backward(log_softmax_prob, target, weight)
         return output
 
     @staticmethod
     @once_differentiable
     def backward(ctx, grad_output: torch.Tensor) -> tuple:
-        input_softmax, target, weight = ctx.saved_tensors
-        buff = input_softmax.new_zeros(input_softmax.size(0))
-        grad_input = input_softmax.new_zeros(input_softmax.size())
+        log_softmax_prob, target, weight = ctx.saved_tensors
+
+        grad_input = log_softmax_prob.new_zeros(log_softmax_prob.size())
 
         ext_module.softmax_focal_loss_backward(
-            input_softmax,
+            log_softmax_prob,
             target,
             weight,
-            buff,
             grad_input,
             gamma=ctx.gamma,
             alpha=ctx.alpha)
 
         grad_input *= grad_output
         if ctx.reduction == ctx.reduction_dict['mean']:
-            grad_input /= input_softmax.size(0)
+            grad_input /= log_softmax_prob.size(0)
         return grad_input, None, None, None, None, None
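
On the Python side, `SoftmaxFocalLossFunction` now passes raw logits to the extension, saves `log_softmax_prob` for backward, and drops the intermediate `buff` tensor. A hypothetical end-to-end smoke test on GPU, assuming `mmcv.ops.SoftmaxFocalLoss` keeps its `(gamma, alpha, weight, reduction)` signature and reusing the reference helpers above, could look like:

```python
import torch
from mmcv.ops import SoftmaxFocalLoss

x = torch.randn(8, 4, device='cuda', requires_grad=True)
t = torch.randint(0, 4, (8,), device='cuda')
w = torch.ones(4, device='cuda')

# Extension path (this PR's kernels).
loss = SoftmaxFocalLoss(gamma=2.0, alpha=0.25, weight=w,
                        reduction='mean')(x, t)
loss.backward()

# Reference path (pure PyTorch helpers defined earlier).
x_ref = x.detach().clone().requires_grad_(True)
ref = softmax_focal_loss_forward_ref(x_ref, t, weight=w).sum() / x.size(0)
ref.backward()

assert torch.allclose(loss, ref, atol=1e-4)
assert torch.allclose(x.grad, x_ref.grad, atol=1e-4)
```

With `reduction='mean'` both sides divide by the batch size, so forward values and input gradients should agree to float32 tolerance.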