diff --git a/mmcv/ops/csrc/common/cuda/softmax_focal_loss_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/softmax_focal_loss_cuda_kernel.cuh
index 631b2c6175..4536c591d8 100644
--- a/mmcv/ops/csrc/common/cuda/softmax_focal_loss_cuda_kernel.cuh
+++ b/mmcv/ops/csrc/common/cuda/softmax_focal_loss_cuda_kernel.cuh
@@ -10,62 +10,100 @@
 template <typename T>
 __global__ void softmax_focal_loss_forward_cuda_kernel(
-    const int nthreads, const T* softmax, const int64_t* target,
-    const T* weight, T* output, const T gamma, const T alpha,
-    const int num_classes) {
+    const int nthreads, const T* __restrict__ log_softmax_prob,
+    const int64_t* __restrict__ target, const T* __restrict__ weight,
+    T* __restrict__ output,
+    const T gamma, const T alpha, const int num_classes) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int64_t label = target[index];
-    T pred = softmax[index * num_classes + label];
+    const int n = index / num_classes;
+    const int c = index % num_classes;
 
-    if (label >= 0) {
-      output[index] =
-          -alpha * pow((T)1. - pred, gamma) * log(max(pred, (T)FLT_MIN));
+    // focal loss
+    // FL(p) = -alpha * (1 - p)^gamma * log(p)   if curr_class == label
+    //
+    // note that log_softmax_prob is computed on the Python side
+    // with F.log_softmax()
+    const int64_t label = target[n];
+    if (c == label) {
+      const T w = (weight != NULL) ? weight[label] : T(1);
+      const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;
+
+      const T log_pred = log_softmax_prob[index];
+      const T pred = exp(log_pred);
+
+      output[index] = -alpha_fac * pow(1 - pred, gamma) * log_pred;
     } else {
       output[index] = 0;
     }
-    if (weight != NULL) {
-      output[index] *= weight[label];
-    }
   }
 }
 
 template <typename T>
-__global__ void softmax_focal_loss_backward_cuda1_kernel(
-    const int nthreads, const T* softmax, const int64_t* target,
-    const T* weight, T* buff, const T gamma, const T alpha,
-    const int num_classes) {
+__global__ void softmax_focal_loss_backward_cuda_kernel(
+    const int nthreads, const T* __restrict__ log_softmax_prob,
+    const int64_t* __restrict__ target, const T* __restrict__ weight,
+    T* __restrict__ sum_buff_along_class, T* __restrict__ grad_input,
+    const T gamma, const T alpha, const int num_classes) {
+  // forward  node: x ----> p ----> FL
+  //          func:    SM      FL
+  //
+  // backward node: x <---- p <---- FL
+  //         index: j       i       FL
+  //
+  // For simplicity, the alpha of FL is ignored here
+  // dFL/dp = -[((1-p)^gamma) / p
+  //            - gamma * (1-p)^(gamma-1) * log(p)]
+  // dp_i/dx_j = dSM/dx_j
+  //           = p_i * (1 - p_j)   if i == j;
+  //             p_i * (0 - p_j)   if i != j;
+  //           = p_i * (delta - p_j)   where delta is the Kronecker delta
+  //
+  // Replacing the p in dFL/dp with p_i, we get
+  // dFL/dx_j = dFL/dp_i * dp_i/dx_j
+  //          = -(delta - p_j) * [ (1-p_i)^gamma
+  //                              - gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i]
+  //          = (delta - p_j) * [- (1-p_i)^gamma +
+  //                              gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i]
+  //
+  // Let B_i denote [- (1-p_i)^gamma +
+  //                  gamma * (1-p_i)^(gamma-1) * log(p_i) * p_i];
+  // the sum over i runs over all classes at position j,
+  // since x_j receives gradients from every p_i.
+  // Then, dFL/dx_j = sum_i{ (delta - p_j) * B_i }
+  //                = sum_i{ delta * B_i - p_j * B_i }
+  //                = B_j - (p_j * sum_i{B_i})
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int64_t label = target[index];
-    T pred = softmax[index * num_classes + label];
+    // B_i
+    const int n = index / num_classes;
+    const int c = index % num_classes;
 
-    if (label >= 0) {
-      buff[index] = alpha * (-pow((T)1. - pred, gamma) +
-                             gamma * pow((T)1. - pred, gamma - 1) * pred *
-                                 log(max(pred, (T)FLT_MIN)));
+    const int64_t label = target[n];
+    if (c == label) {
+      const T w = (weight != NULL) ? weight[label] : T(1);
+      const T alpha_fac = ((label == 0) * (1 - alpha) + (label >= 1) * alpha) * w;
+
+      const T log_pred = log_softmax_prob[index];
+      const T pred = exp(log_pred);
+      const T one_minus_pred = 1 - pred;
+
+      const T buff = alpha_fac * (
+          -pow(one_minus_pred, gamma) +
+          gamma * pow(one_minus_pred, gamma - 1) * log_pred * pred
+      );
+      grad_input[index] = buff;
+      sum_buff_along_class[n] += buff;
     } else {
-      buff[index] = 0;
-    }
-    if (weight != NULL) {
-      buff[index] *= weight[label];
+      grad_input[index] = 0;
     }
   }
-}
-
-template <typename T>
-__global__ void softmax_focal_loss_backward_cuda2_kernel(
-    const int nthreads, const T* softmax, const int64_t* target, const T* buff,
-    T* grad_input, const int num_classes) {
   CUDA_1D_KERNEL_LOOP(index, nthreads) {
-    int n = index / num_classes;
-    int c = index % num_classes;
-    int64_t label = target[n];
+    // dFL/dx_j
+    const int n = index / num_classes;
 
-    if (label >= 0) {
-      T flag = (label == c ? (T)1. : (T)0.);
-      grad_input[index] = buff[n] * (flag - softmax[index]);
-    } else {
-      grad_input[index] = 0;
-    }
+    const T pred = exp(log_softmax_prob[index]);
+    grad_input[index] -= pred * sum_buff_along_class[n];
   }
 }
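The derivation in the kernel comments above can be checked numerically. The sketch below is illustrative only: it runs on CPU in double precision, omits the optional per-class `weight`, and simply reads the class-0 / (1 - alpha) weighting off the `alpha_fac` expression in the kernels. It evaluates the closed form dFL/dx_j = B_j - p_j * sum_i{B_i} and compares it with PyTorch autograd on the same loss.

```python
# Numerical check of the gradient identity used by the fused backward kernel:
# grad_x[n, j] = B[n] * (j == t_n) - p[n, j] * B[n]
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 4, 5
gamma, alpha = 2.0, 0.25

x = torch.randn(N, C, dtype=torch.double, requires_grad=True)
target = torch.randint(0, C, (N,))

# loss written the way the forward kernel writes it (per-class weight omitted)
log_p = F.log_softmax(x, dim=1)
log_p_t = log_p.gather(1, target[:, None]).squeeze(1)
p_t = log_p_t.exp()
alpha_fac = torch.where(target == 0,
                        torch.tensor(1 - alpha, dtype=torch.double),
                        torch.tensor(alpha, dtype=torch.double))
loss = (-alpha_fac * (1 - p_t).pow(gamma) * log_p_t).sum()
loss.backward()

# closed form from the comment: dFL/dx_j = B_j - p_j * sum_i{B_i},
# where B is non-zero only at the target class
with torch.no_grad():
    p = log_p.exp()
    B = alpha_fac * (-(1 - p_t).pow(gamma) +
                     gamma * (1 - p_t).pow(gamma - 1) * p_t * log_p_t)
    grad = -p * B[:, None]              # the "- p_j * sum_i{B_i}" term
    grad[torch.arange(N), target] += B  # the "B_j" (Kronecker delta) term

print(torch.allclose(grad, x.grad))     # expected: True
```

Because B_i vanishes for every non-target class, sum_i{B_i} collapses to the single value that the first loop stores in `sum_buff_along_class[n]`, which is what the second loop subtracts after scaling by p_j.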
diff --git a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
index d9359551d2..6efc2e165f 100644
--- a/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
+++ b/mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
@@ -409,13 +409,17 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
                                                 const float gamma,
                                                 const float alpha);
 
-void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                               Tensor weight, Tensor output,
+void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                               const Tensor target,
+                                               const Tensor weight,
+                                               Tensor output,
                                                const float gamma,
                                                const float alpha);
 
-void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                                Tensor weight, Tensor buff,
+void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                                const Tensor target,
+                                                const Tensor weight,
+                                                Tensor sum_buff_along_class,
                                                 Tensor grad_input,
                                                 const float gamma,
                                                 const float alpha);
@@ -433,18 +437,26 @@ void sigmoid_focal_loss_backward_cuda(Tensor input, Tensor target,
                                              gamma, alpha);
 }
 
-void softmax_focal_loss_forward_cuda(Tensor input, Tensor target, Tensor weight,
-                                     Tensor output, float gamma, float alpha) {
-  SoftmaxFocalLossForwardCUDAKernelLauncher(input, target, weight, output,
-                                            gamma, alpha);
+void softmax_focal_loss_forward_cuda(const Tensor log_softmax_prob,
+                                     const Tensor target,
+                                     const Tensor weight,
+                                     Tensor output,
+                                     const float gamma,
+                                     const float alpha) {
+  SoftmaxFocalLossForwardCUDAKernelLauncher(log_softmax_prob, target, weight,
+                                            output, gamma, alpha);
 }
 
-void softmax_focal_loss_backward_cuda(Tensor input, Tensor target,
-                                      Tensor weight, Tensor buff,
-                                      Tensor grad_input, float gamma,
-                                      float alpha) {
-  SoftmaxFocalLossBackwardCUDAKernelLauncher(input, target, weight, buff,
-                                             grad_input, gamma, alpha);
+void softmax_focal_loss_backward_cuda(const Tensor log_softmax_prob,
+                                      const Tensor target,
+                                      const Tensor weight,
+                                      Tensor sum_buff_along_class,
+                                      Tensor grad_input,
+                                      const float gamma,
+                                      const float alpha) {
+  SoftmaxFocalLossBackwardCUDAKernelLauncher(log_softmax_prob, target, weight,
+                                             sum_buff_along_class, grad_input,
+                                             gamma, alpha);
 }
 
 void sigmoid_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
@@ -454,13 +466,20 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
                                       Tensor weight, Tensor grad_input,
                                       float gamma, float alpha);
 
-void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
-                                     Tensor output, float gamma, float alpha);
-
-void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
-                                      Tensor weight, Tensor buff,
-                                      Tensor grad_input, float gamma,
-                                      float alpha);
+void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
+                                     const Tensor target,
+                                     const Tensor weight,
+                                     Tensor output,
+                                     const float gamma,
+                                     const float alpha);
+
+void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
+                                      const Tensor target,
+                                      const Tensor weight,
+                                      Tensor sum_buff_along_class,
+                                      Tensor grad_input,
+                                      const float gamma,
+                                      const float alpha);
 
 REGISTER_DEVICE_IMPL(sigmoid_focal_loss_forward_impl, CUDA,
                      sigmoid_focal_loss_forward_cuda);
diff --git a/mmcv/ops/csrc/pytorch/cuda/focal_loss_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/focal_loss_cuda.cu
index cb899f954f..bfa902a70a 100644
--- a/mmcv/ops/csrc/pytorch/cuda/focal_loss_cuda.cu
+++ b/mmcv/ops/csrc/pytorch/cuda/focal_loss_cuda.cu
@@ -47,64 +47,53 @@ void SigmoidFocalLossBackwardCUDAKernelLauncher(Tensor input, Tensor target,
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-void SoftmaxFocalLossForwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                               Tensor weight, Tensor output,
+void SoftmaxFocalLossForwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                               const Tensor target,
+                                               const Tensor weight,
+                                               Tensor output,
                                                const float gamma,
                                                const float alpha) {
   int output_size = output.numel();
-  int num_classes = softmax.size(1);
+  int num_classes = log_softmax_prob.size(1);
 
   AT_ASSERTM(target.max().item<int64_t>() <= (int64_t)num_classes,
              "target label should smaller or equal than num classes");
-  at::cuda::CUDAGuard device_guard(softmax.device());
+  at::cuda::CUDAGuard device_guard(log_softmax_prob.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      softmax.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
+      log_softmax_prob.scalar_type(), "softmax_focal_loss_forward_cuda_kernel", [&] {
         softmax_focal_loss_forward_cuda_kernel<scalar_t>
             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                output.data_ptr<scalar_t>(), gamma, alpha, num_classes);
+                output_size,
+                log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+                weight.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),
+                gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
 }
 
-void SoftmaxFocalLossBackwardCUDAKernelLauncher(Tensor softmax, Tensor target,
-                                                Tensor weight, Tensor buff,
+void SoftmaxFocalLossBackwardCUDAKernelLauncher(const Tensor log_softmax_prob,
+                                                const Tensor target,
+                                                const Tensor weight,
+                                                Tensor sum_buff_along_class,
                                                 Tensor grad_input,
                                                 const float gamma,
                                                 const float alpha) {
-  int num_classes = softmax.size(1);
+  int output_size = grad_input.numel();
+  int num_classes = log_softmax_prob.size(1);
 
-  int output_size = buff.numel();
   at::cuda::CUDAGuard device_guard(grad_input.device());
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(),
-      "softmax_focal_loss_backward_cuda1_"
-      "kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda1_kernel<scalar_t>
-            <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), weight.data_ptr<scalar_t>(),
-                buff.data_ptr<scalar_t>(), gamma, alpha, num_classes);
-      });
-
-  AT_CUDA_CHECK(cudaGetLastError());
-
-  output_size = grad_input.numel();
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      grad_input.scalar_type(),
-      "softmax_focal_loss_backward_cuda2_"
-      "kernel",
-      [&] {
-        softmax_focal_loss_backward_cuda2_kernel<scalar_t>
+      log_softmax_prob.scalar_type(), "softmax_focal_loss_backward_cuda_kernel", [&] {
+        softmax_focal_loss_backward_cuda_kernel<scalar_t>
             <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
-                output_size, softmax.data_ptr<scalar_t>(),
-                target.data_ptr<int64_t>(), buff.data_ptr<scalar_t>(),
-                grad_input.data_ptr<scalar_t>(), num_classes);
+                output_size,
+                log_softmax_prob.data_ptr<scalar_t>(), target.data_ptr<int64_t>(),
+                weight.data_ptr<scalar_t>(), sum_buff_along_class.data_ptr<scalar_t>(),
+                grad_input.data_ptr<scalar_t>(),
+                gamma, alpha, num_classes);
       });
 
   AT_CUDA_CHECK(cudaGetLastError());
diff --git a/mmcv/ops/csrc/pytorch/focal_loss.cpp b/mmcv/ops/csrc/pytorch/focal_loss.cpp
index 51568ead34..3eeb54ab77 100644
--- a/mmcv/ops/csrc/pytorch/focal_loss.cpp
+++ b/mmcv/ops/csrc/pytorch/focal_loss.cpp
@@ -25,18 +25,26 @@ void sigmoid_focal_loss_backward_impl(Tensor input, Tensor target,
                        grad_input, gamma, alpha);
 }
 
-void softmax_focal_loss_forward_impl(Tensor input, Tensor target, Tensor weight,
-                                     Tensor output, float gamma, float alpha) {
-  DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, input, target, weight,
-                       output, gamma, alpha);
+void softmax_focal_loss_forward_impl(const Tensor log_softmax_prob,
+                                     const Tensor target,
+                                     const Tensor weight,
+                                     Tensor output,
+                                     const float gamma,
+                                     const float alpha) {
+  DISPATCH_DEVICE_IMPL(softmax_focal_loss_forward_impl, log_softmax_prob,
+                       target, weight, output, gamma, alpha);
 }
 
-void softmax_focal_loss_backward_impl(Tensor input, Tensor target,
-                                      Tensor weight, Tensor buff,
-                                      Tensor grad_input, float gamma,
-                                      float alpha) {
-  DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, input, target, weight,
-                       buff, grad_input, gamma, alpha);
+void softmax_focal_loss_backward_impl(const Tensor log_softmax_prob,
+                                      const Tensor target,
+                                      const Tensor weight,
+                                      Tensor sum_buff_along_class,
+                                      Tensor grad_input,
+                                      const float gamma,
+                                      const float alpha) {
+  DISPATCH_DEVICE_IMPL(softmax_focal_loss_backward_impl, log_softmax_prob,
+                       target, weight, sum_buff_along_class, grad_input,
+                       gamma, alpha);
 }
 
 #ifdef MMCV_WITH_DIOPI
@@ -127,14 +135,24 @@ void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
 #endif
 }
 
-void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
-                                Tensor output, float gamma, float alpha) {
-  softmax_focal_loss_forward_impl(input, target, weight, output, gamma, alpha);
+void softmax_focal_loss_forward(const Tensor log_softmax_prob,
+                                const Tensor target,
+                                const Tensor weight,
+                                Tensor output,
+                                const float gamma,
+                                const float alpha) {
+  softmax_focal_loss_forward_impl(log_softmax_prob, target, weight,
+                                  output, gamma, alpha);
 }
 
-void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
-                                 Tensor buff, Tensor grad_input, float gamma,
-                                 float alpha) {
-  softmax_focal_loss_backward_impl(input, target, weight, buff, grad_input,
+void softmax_focal_loss_backward(const Tensor log_softmax_prob,
+                                 const Tensor target,
+                                 const Tensor weight,
+                                 Tensor sum_buff_along_class,
+                                 Tensor grad_input,
+                                 const float gamma,
+                                 const float alpha) {
+  softmax_focal_loss_backward_impl(log_softmax_prob, target, weight,
+                                   sum_buff_along_class, grad_input,
                                    gamma, alpha);
 }
diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp
index c8591a5cc1..b65f4f63e6 100644
--- a/mmcv/ops/csrc/pytorch/pybind.cpp
+++ b/mmcv/ops/csrc/pytorch/pybind.cpp
@@ -103,12 +103,20 @@ void sigmoid_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
 void sigmoid_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
                                  Tensor grad_input, float gamma, float alpha);
 
-void softmax_focal_loss_forward(Tensor input, Tensor target, Tensor weight,
-                                Tensor output, float gamma, float alpha);
-
-void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
-                                 Tensor buff, Tensor grad_input, float gamma,
-                                 float alpha);
+void softmax_focal_loss_forward(const Tensor log_softmax_prob,
+                                const Tensor target,
+                                const Tensor weight,
+                                Tensor output,
+                                const float gamma,
+                                const float alpha);
+
+void softmax_focal_loss_backward(const Tensor log_softmax_prob,
+                                 const Tensor target,
+                                 const Tensor weight,
+                                 Tensor sum_buff_along_class,
+                                 Tensor grad_input,
+                                 const float gamma,
+                                 const float alpha);
 
 void three_interpolate_forward(Tensor points_tensor, Tensor idx_tensor,
                                Tensor weight_tensor, Tensor out_tensor, int b,
@@ -566,13 +574,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("weight"), py::arg("grad_input"), py::arg("gamma"),
         py::arg("alpha"));
   m.def("softmax_focal_loss_forward", &softmax_focal_loss_forward,
-        "softmax_focal_loss_forward", py::arg("input"), py::arg("target"),
-        py::arg("weight"), py::arg("output"), py::arg("gamma"),
-        py::arg("alpha"));
-  m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward,
-        "softmax_focal_loss_backward", py::arg("input"), py::arg("target"),
-        py::arg("weight"), py::arg("buff"), py::arg("grad_input"),
+        "softmax_focal_loss_forward", py::arg("log_softmax_prob"),
+        py::arg("target"), py::arg("weight"), py::arg("output"),
         py::arg("gamma"), py::arg("alpha"));
+  m.def("softmax_focal_loss_backward", &softmax_focal_loss_backward,
+        "softmax_focal_loss_backward", py::arg("log_softmax_prob"),
+        py::arg("target"), py::arg("weight"), py::arg("sum_buff_along_class"),
+        py::arg("grad_input"), py::arg("gamma"), py::arg("alpha"));
   m.def("three_interpolate_forward", &three_interpolate_forward,
         "three_interpolate_forward", py::arg("points_tensor"),
         py::arg("idx_tensor"), py::arg("weight_tensor"), py::arg("out_tensor"),
diff --git a/mmcv/ops/focal_loss.py b/mmcv/ops/focal_loss.py
index 69aab73052..f2d1d4e779 100644
--- a/mmcv/ops/focal_loss.py
+++ b/mmcv/ops/focal_loss.py
@@ -3,6 +3,7 @@
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.autograd import Function
 from torch.autograd.function import once_differentiable
 
@@ -132,16 +133,13 @@ def forward(ctx,
         ctx.alpha = float(alpha)
         ctx.reduction = ctx.reduction_dict[reduction]
 
-        channel_stats, _ = torch.max(input, dim=1)
-        input_softmax = input - channel_stats.unsqueeze(1).expand_as(input)
-        input_softmax.exp_()
+        # log_softmax for numerical stability
+        log_softmax_prob = F.log_softmax(input, dim=1)
 
-        channel_stats = input_softmax.sum(dim=1)
-        input_softmax /= channel_stats.unsqueeze(1).expand_as(input)
+        output = input.new_zeros(input.size())
 
-        output = input.new_zeros(input.size(0))
         ext_module.softmax_focal_loss_forward(
-            input_softmax,
+            log_softmax_prob,
             target,
             weight,
             output,
@@ -152,27 +150,30 @@ def forward(ctx,
             output = output.sum() / input.size(0)
         elif ctx.reduction == ctx.reduction_dict['sum']:
             output = output.sum()
-        ctx.save_for_backward(input_softmax, target, weight)
+        ctx.save_for_backward(log_softmax_prob, target, weight)
         return output
 
     @staticmethod
+    @once_differentiable
     def backward(ctx, grad_output: torch.Tensor) -> tuple:
-        input_softmax, target, weight = ctx.saved_tensors
-        buff = input_softmax.new_zeros(input_softmax.size(0))
-        grad_input = input_softmax.new_zeros(input_softmax.size())
+        log_softmax_prob, target, weight = ctx.saved_tensors
+
+        sum_buff_along_class = log_softmax_prob.new_zeros(
+            log_softmax_prob.size(0))
+        grad_input = log_softmax_prob.new_zeros(log_softmax_prob.size())
 
         ext_module.softmax_focal_loss_backward(
-            input_softmax,
+            log_softmax_prob,
             target,
             weight,
-            buff,
+            sum_buff_along_class,
             grad_input,
             gamma=ctx.gamma,
             alpha=ctx.alpha)
 
         grad_input *= grad_output
         if ctx.reduction == ctx.reduction_dict['mean']:
-            grad_input /= input_softmax.size(0)
+            grad_input /= log_softmax_prob.size(0)
         return grad_input, None, None, None, None, None
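With the log-softmax moved to the Python side and the extension now filling a dense (N, C) buffer, the whole forward/backward path can be exercised end to end with a double-precision gradcheck. This is a sketch only: it assumes a CUDA build of mmcv that contains this patch and the `softmax_focal_loss(input, target, gamma, alpha, weight, reduction)` calling convention already used in tests/test_ops/test_focal_loss.py.

```python
# Illustrative gradcheck of the rewritten op; requires a CUDA build with this
# patch applied. Double precision keeps the finite-difference check tight.
import torch
from mmcv.ops import softmax_focal_loss

torch.manual_seed(0)
x = torch.randn(4, 6, dtype=torch.double, device='cuda', requires_grad=True)
target = torch.randint(0, 6, (4,), device='cuda')

ok = torch.autograd.gradcheck(
    lambda inp: softmax_focal_loss(inp, target, 2.0, 0.25, None, 'mean'),
    (x,))
print(ok)
```

gradcheck perturbs the logits, so it compares the analytic gradient assembled by softmax_focal_loss_backward_cuda_kernel against finite differences of the forward loss.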
diff --git a/tests/test_ops/test_focal_loss.py b/tests/test_ops/test_focal_loss.py
index ee7c9861ae..c56f63ee75 100644
--- a/tests/test_ops/test_focal_loss.py
+++ b/tests/test_ops/test_focal_loss.py
@@ -20,13 +20,13 @@
     ([[1e-6, 2e-6, 3e-6], [4e-6, 5e-5, 6e-4], [7e-3, 8e-2, 9e-1]], [1, 2, 0]),
 ]
 
-softmax_outputs = [(0.00566451, [[-0.00657264, 0.00657264],
-                                 [0.00657264, -0.00657264]]),
-                   (0.34956908, [[0.10165970, 0.03739851, -0.13905823],
-                                 [0.01227554, -0.10298023, 0.09070466]]),
-                   (0.15754992, [[0.02590877, -0.05181759, 0.02590882],
-                                 [0.02589641, 0.02589760, -0.05179400],
-                                 [-0.07307514, 0.02234372, 0.05073142]])]
+softmax_outputs = [(0.01132904, [[-0.01971794, 0.01971793],
+                                 [0.00657264, -0.00657265]]),
+                   (0.34956908, [[0.10165971, 0.03739851, -0.13905823],
+                                 [0.01227554, -0.10298022, 0.09070467]]),
+                   (0.30995172, [[0.02590877, -0.05181758, 0.02590882],
+                                 [0.02589641, 0.02589760, -0.05179401],
+                                 [-0.21922545, 0.06703118, 0.15219429]])]
 
 sigmoid_outputs = [(0.13562961, [[-0.00657264, 0.11185755],
                                  [0.11185755, -0.00657264]]),