diff --git a/oneflow/python/ops/loss_ops.py b/oneflow/python/ops/loss_ops.py index b1ba877b6ba810662c77b5c16adfc00c79523970..6195a11e485e9a5050aa738296094da3a574c999 100644 --- a/oneflow/python/ops/loss_ops.py +++ b/oneflow/python/ops/loss_ops.py @@ -85,3 +85,126 @@ def smooth_l1_loss( ) op.Attr("beta", float(beta)) return op.Build().InferAndTryRun().RemoteBlobList()[0] + + +@oneflow_export("ctc_loss") +def ctc_loss( + log_probs: oneflow_api.BlobDesc, + targets: oneflow_api.BlobDesc, + input_lengths: oneflow_api.BlobDesc, + target_lengths: oneflow_api.BlobDesc, + blank: int = 0, + reduction: str = "mean", + zero_infinity: bool = False, + name: Optional[str] = None, +) -> oneflow_api.BlobDesc: + r"""Computes the CTC (Connectionist Temporal Classification) loss. + This operator implements the CTC loss as presented in Graves et al. (2006). + + + Args: + log_probs (oneflow_api.BlobDesc): A Blob of shape [input_length, batch_size, num_labels]. The logarithmized probabilities of the outputs (e.g. obtained with flow.nn.logsoftmax()). + targets (oneflow_api.BlobDesc): A Blob of shape [batch_size, max_target_length]. It represents the target sequences. Each element in a target sequence is a class index, and no target index may be the blank label (default 0). + input_lengths (oneflow_api.BlobDesc): A Blob of shape [batch_size]. It represents the lengths of the inputs. Lengths are specified for each sequence to achieve masking under the assumption that sequences are padded to equal lengths. + target_lengths (oneflow_api.BlobDesc): A Blob of shape [batch_size]. It represents the lengths of the targets. Lengths are specified for each sequence to achieve masking under the assumption that sequences are padded to equal lengths. + blank (int, optional): Blank label. Defaults to 0. + reduction (str, optional): The reduction type; it can be one of "none", "mean", or "sum". "none": no reduction will be applied; "mean": the output losses will be divided by the target lengths and then the mean over the batch is taken; "sum": the output will be summed. Defaults to "mean". + zero_infinity (bool, optional): Whether to zero infinite losses and the associated gradients. Infinite losses mainly occur when the inputs are too short to be aligned to the targets. Defaults to False. + name (Optional[str], optional): The name for the operation. Defaults to None. + + Returns: + oneflow_api.BlobDesc: The result Blob. + + For example: + + .. 
code-block:: python + + import oneflow as flow + import oneflow.typing as tp + import numpy as np + + + @flow.global_function() + def ctc_loss_job( + log_probs: tp.Numpy.Placeholder(shape=(5, 2, 3)), + targets: tp.Numpy.Placeholder(shape=(2, 3), dtype=flow.int32), + input_lengths: tp.Numpy.Placeholder(shape=(2,), dtype=flow.int32), + target_lengths: tp.Numpy.Placeholder(shape=(2,), dtype=flow.int32), + ) -> tp.Numpy: + loss = flow.ctc_loss( + log_probs, targets, input_lengths, target_lengths, blank=0, reduction="none" + ) + return loss + + + log_probs = np.array( + [ + [[-1.1031, -0.7998, -1.5200], [-0.9808, -1.1363, -1.1908]], + [[-1.2258, -1.0665, -1.0153], [-1.1135, -1.2331, -0.9671]], + [[-1.3348, -0.6611, -1.5118], [-0.9823, -1.2355, -1.0941]], + [[-1.3850, -1.3273, -0.7247], [-0.8235, -1.4783, -1.0994]], + [[-0.9049, -0.8867, -1.6962], [-1.4938, -1.3630, -0.6547]], + ] + ).astype(np.float32) + targets = np.array([[1, 2, 2], [1, 2, 2]]).astype("int32") + input_lengths = np.array([5, 5]).astype("int32") + target_lengths = np.array([3, 3]).astype("int32") + loss = ctc_loss_job(log_probs, targets, input_lengths, target_lengths) + + # loss [3.918017 2.907672] + + """ + name = name if name is not None else id_util.UniqueStr("CTCLoss_") + loss, _ = ( + flow.user_op_builder(name) + .Op("ctc_loss") + .Input("log_probs", [log_probs]) + .Input("targets", [targets]) + .Input("input_lengths", [input_lengths]) + .Input("target_lengths", [target_lengths]) + .Output("loss") + .Output("alpha") + .Attr("blank", int(blank)) + .Attr("zero_infinity", zero_infinity) + .Build() + .InferAndTryRun() + .RemoteBlobList() + ) + + if zero_infinity: + cond = flow.math.equal( + loss, + flow.constant( + float("inf"), + dtype=loss.dtype, + shape=loss.shape, + name=name + "_constant", + ), + name=name + "_equal", + ) + loss = flow.where( + cond, + flow.zeros(dtype=loss.dtype, shape=loss.shape, name=name + "_zeros"), + loss, + name=name + "_where", + ) + + if reduction == "mean": + return flow.math.reduce_mean( + flow.math.xdivy( + loss, + flow.cast( + flow.math.clip_by_value( + target_lengths, min_value=1, name=name + "_clip_by_value" + ), + dtype=log_probs.dtype, + name=name + "_cast", + ), + name=name + "_xdivy", + ), + name=name + "_reduce_mean", + ) + elif reduction == "sum": + return flow.math.reduce_sum(loss, name=name + "_reduce_sum") + else: + return loss diff --git a/oneflow/python/test/ops/test_ctc_loss.py b/oneflow/python/test/ops/test_ctc_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..f60ac036d4e684e298f55b0ba7f080c0d351ef8b --- /dev/null +++ b/oneflow/python/test/ops/test_ctc_loss.py @@ -0,0 +1,352 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" +import unittest +from collections import OrderedDict + +import numpy as np +import oneflow as flow +from test_util import GenArgList, type_name_to_flow_type, type_name_to_np_type +import oneflow.typing as tp +import os + +ninf = -float("inf") + + +def _logsumexp(a, b): + if a < b: + a, b = b, a + if b == ninf: + return a + else: + return a + np.log(1 + np.exp(b - a)) + + +def logsumexp(*args): + res = args[0] + for e in args[1:]: + res = _logsumexp(res, e) + return res + + +def log_softmax(logits, axis=0): + max_value = np.max(logits, axis, keepdims=True) + exp = np.exp(logits - max_value) + exp_sum = np.sum(exp, axis, keepdims=True) + dist = exp / exp_sum + return np.log(dist) + + +def get_target_prime(targets, b, s, blank): + if s % 2 == 0: + return blank + else: + return targets[b, s // 2] + + +def ctc_loss_np(log_probs, targets, input_lengths, target_lengths, blank=0): + + max_input_length, batch_size, _ = log_probs.shape + _, max_target_length = targets.shape + loss = np.zeros(batch_size) + alpha = np.zeros([batch_size, max_input_length, 2 * max_target_length + 1]) + alpha[:, 0] = ninf + + for b in range(0, batch_size): + input_length = input_lengths[b] + target_length = target_lengths[b] + alpha[b, 0, 0] = log_probs[0, b, blank] + if target_length > 0: + current_target_prime = get_target_prime(targets, b, 1, blank) + alpha[b, 0, 1] = log_probs[0, b, current_target_prime] + + for t in range(1, input_length): + for s in range(0, 2 * target_length + 1): + current_target_prime = get_target_prime(targets, b, s, blank) + la1 = alpha[b, t - 1, s] + if s > 0: + la2 = alpha[b, t - 1, s - 1] + else: + la2 = ninf + if ( + s > 1 + and get_target_prime(targets, b, s - 2, blank) + != current_target_prime + ): + la3 = alpha[b, t - 1, s - 2] + else: + la3 = ninf + + alpha[b, t, s] = ( + logsumexp(la1, la2, la3) + log_probs[t, b, current_target_prime] + ) + + if target_length == 0: + loss[b] = -alpha[b, input_length - 1, 0] + else: + l1 = alpha[b, input_length - 1, target_length * 2] + l2 = alpha[b, input_length - 1, target_length * 2 - 1] + loss[b] = -logsumexp(l1, l2) + return loss, alpha + + +def ctc_loss_grad_np( + grad_out, + loss, + alpha, + log_probs, + targets, + input_lengths, + target_lengths, + blank=0, + zero_infinity=False, +): + max_input_length, batch_size, num_labels = log_probs.shape + _, max_target_length = targets.shape + + beta = np.zeros([batch_size, max_input_length, 2 * max_target_length + 1]) + grad = np.zeros(log_probs.shape, dtype=log_probs.dtype) + grad.fill(ninf) + + for b in range(0, batch_size): + input_length = input_lengths[b] + target_length = target_lengths[b] + nll = loss[b] + if zero_infinity and nll == float("inf"): + grad[:, b, :] = 0 + continue + + if input_length > 0: + beta[b, input_length - 1, :] = ninf + beta[b, input_length - 1, 2 * target_length] = log_probs[ + input_length - 1, b, blank + ] + grad[input_length - 1, b, blank] = ( + alpha[b, input_length - 1, 2 * target_length] + + beta[b, input_length - 1, 2 * target_length] + ) + + if target_length > 0: + current_target_prime = get_target_prime( + targets, b, 2 * target_length - 1, blank + ) + beta[b, input_length - 1, 2 * target_length - 1] = log_probs[ + input_length - 1, b, current_target_prime + ] + grad[input_length - 1, b, current_target_prime] = ( + alpha[b, input_length - 1, 2 * target_length - 1] + + beta[b, input_length - 1, 2 * target_length - 1] + ) + + for t in range(input_length - 2, -1, -1): + for s in range(2 * target_length, -1, -1): + current_target_prime = 
get_target_prime(targets, b, s, blank) + lb1 = beta[b, t + 1, s] + if s < 2 * target_length: + lb2 = beta[b, t + 1, s + 1] + else: + lb2 = ninf + if ( + s < 2 * target_length - 1 + and get_target_prime(targets, b, s + 2, blank) + != current_target_prime + ): + lb3 = beta[b, t + 1, s + 2] + else: + lb3 = ninf + + beta[b, t, s] = ( + logsumexp(lb1, lb2, lb3) + log_probs[t, b, current_target_prime] + ) + alpha_beta = alpha[b, t, s] + beta[b, t, s] + lcab = grad[t, b, current_target_prime] + if lcab == ninf: + grad[t, b, current_target_prime] = alpha_beta + else: + grad[t, b, current_target_prime] = logsumexp(lcab, alpha_beta) + + for t in range(0, input_length): + for c in range(0, num_labels): + res = grad[t, b, c] + lp = log_probs[t, b, c] + grad[t, b, c] = (np.exp(lp) - np.exp(res + nll - lp)) * grad_out[b] + if input_length < max_input_length: + grad[input_length:max_input_length, b] = 0 + return grad + + +def compare_with_np( + device_type, + device_num, + data_type, + max_input_length, + batch_size, + num_classes, + max_target_length, + blank, + reduction, + zero_infinity, +): + assert data_type in ["float32", "double"] + assert device_type in ["gpu", "cpu"] + assert reduction in ["none", "mean", "sum"] + assert zero_infinity in [False, True] + + flow.clear_default_session() + if device_type == "cpu": + flow.config.cpu_device_num(device_num) + else: + flow.config.gpu_device_num(device_num) + flow_data_type = type_name_to_flow_type[data_type] + np_data_type = type_name_to_np_type[data_type] + func_config = flow.FunctionConfig() + func_config.default_logical_view(flow.scope.consistent_view()) + func_config.default_data_type(flow_data_type) + func_config.default_placement_scope( + flow.scope.placement(device_type, "0:0-{}".format(device_num - 1)) + ) + + log_probs = np.random.random( + size=(max_input_length, batch_size, num_classes) + ).astype(np_data_type) + log_probs = log_softmax(log_probs, axis=2) + targets = np.random.randint( + 1, high=num_classes, size=(batch_size, max_target_length), dtype=np.int32 + ) + input_lengths = np.random.randint( + max_input_length / 2, high=max_input_length, size=(batch_size,), dtype=np.int32 + ) + target_lengths = np.random.randint( + max_target_length / 2, + high=max_target_length, + size=(batch_size,), + dtype=np.int32, + ) + + np_loss, np_alpha = ctc_loss_np( + log_probs, targets, input_lengths, target_lengths, blank + ) + + np_out = np.where(np_loss == float("inf"), 0, np_loss) if zero_infinity else np_loss + if reduction == "mean": + np_out = np.mean( + np.divide( + np_out, np.clip(target_lengths, 1, a_max=None).astype(np_data_type) + ) + ) + elif reduction == "sum": + np_out = np.sum(np_out) + + np_grad_out = np.ones_like(np_loss, dtype=np_data_type) + if reduction == "mean": + np_grad_out = np.divide( + np_grad_out, np.clip(target_lengths, 1, a_max=None).astype(np_data_type) + ) + np_grad_out /= target_lengths.size + + np_grad = ctc_loss_grad_np( + np_grad_out, + np_loss, + np_alpha, + log_probs, + targets, + input_lengths, + target_lengths, + blank, + zero_infinity, + ) + + def assert_loss_grad(blob: tp.Numpy): + assert np.allclose(blob, np_grad, atol=1e-5, equal_nan=True) + + @flow.global_function(type="train", function_config=func_config) + def ctc_loss_job( + log_probs: tp.Numpy.Placeholder( + shape=(max_input_length, batch_size, num_classes), dtype=flow_data_type + ), + targets: tp.Numpy.Placeholder( + shape=(batch_size, max_target_length), dtype=flow.int32 + ), + input_lengths: tp.Numpy.Placeholder(shape=(batch_size,), dtype=flow.int32), + 
target_lengths: tp.Numpy.Placeholder(shape=(batch_size,), dtype=flow.int32), + ) -> tp.Numpy: + with flow.scope.placement(device_type, "0:0"): + v = flow.get_variable( + shape=log_probs.shape, + dtype=flow_data_type, + initializer=flow.zeros_initializer(), + name="x_var", + ) + x_var = log_probs + v + + flow.watch_diff(x_var, assert_loss_grad) + loss = flow.ctc_loss( + x_var, + targets, + input_lengths, + target_lengths, + blank, + reduction, + zero_infinity, + ) + + with flow.scope.placement(device_type, "0:0"): + flow.optimizer.SGD( + flow.optimizer.PiecewiseConstantScheduler([], [1e-3]), momentum=0 + ).minimize(loss) + + return loss + + of_out = ctc_loss_job(log_probs, targets, input_lengths, target_lengths) + assert np.allclose(of_out, np_out, atol=1e-5) + + +def gen_arg_list(type): + arg_dict = OrderedDict() + if type == "1n2d": + arg_dict["device_type"] = ["gpu"] + arg_dict["device_num"] = [2] + else: + arg_dict["device_type"] = ["cpu", "gpu"] + arg_dict["device_num"] = [1] + arg_dict["data_type"] = ["float32"] + arg_dict["max_input_length"] = [20] + arg_dict["batch_size"] = [4] + arg_dict["num_classes"] = [5] + arg_dict["max_target_length"] = [10] + arg_dict["blank"] = [0, 4] # 0 <= blank < num_classes + arg_dict["reduction"] = ["mean", "none"] + arg_dict["zero_infinity"] = [False, True] + + return GenArgList(arg_dict) + + +@flow.unittest.skip_unless_1n1d() +class TestCTCLoss1n1d(flow.unittest.TestCase): + def test_ctc_loss(test_case): + for arg in gen_arg_list("1n1d"): + compare_with_np(*arg) + + +@flow.unittest.skip_unless_1n2d() +class TestCTCLoss1n2d(flow.unittest.TestCase): + @unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases") + def test_ctc_loss(test_case): + for arg in gen_arg_list("1n2d"): + compare_with_np(*arg) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/user/kernels/ctc_loss_kernel.cpp b/oneflow/user/kernels/ctc_loss_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..23a2dba993c4a2fc5b889cc22e98530fc219a603 --- /dev/null +++ b/oneflow/user/kernels/ctc_loss_kernel.cpp @@ -0,0 +1,142 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/ctc_loss_kernel_util.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, typename IDX> +class CtcLossKernel final : public user_op::OpKernel { + public: + CtcLossKernel() = default; + ~CtcLossKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex("log_probs", 0); + const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex("targets", 0); + const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex("input_lengths", 0); + const user_op::Tensor* target_lengths = ctx->Tensor4ArgNameAndIndex("target_lengths", 0); + user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex("loss", 0); + user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); + + const T* log_probs_ptr = log_probs->dptr<T>(); + const int* targets_ptr = targets->dptr<int>(); + const IDX* input_lengths_ptr = input_lengths->dptr<IDX>(); + const IDX* target_lengths_ptr = target_lengths->dptr<IDX>(); + const int blank = ctx->Attr<int>("blank"); + const int64_t max_input_length = log_probs->shape().At(0); + const int64_t batch_size = log_probs->shape().At(1); + const int64_t num_labels = log_probs->shape().At(2); + const int64_t max_target_length = targets->shape().At(1); + CHECK_EQ(batch_size, targets->shape().At(0)); + CHECK_EQ(batch_size, input_lengths->shape().At(0)); + CHECK_EQ(batch_size, target_lengths->shape().At(0)); + CHECK_GE(blank, 0); + CHECK_LT(blank, num_labels); + NdIndexOffsetHelper<int64_t, 3> input_helper(max_input_length, batch_size, num_labels); + NdIndexOffsetHelper<int64_t, 3> alpha_helper(batch_size, max_input_length, + 2 * max_target_length + 1); + T* loss_ptr = loss->mut_dptr<T>(); + T* alpha_ptr = alpha->mut_dptr<T>(); + CtcLossKernelUtil<device_type, T, IDX>::CtcLossForward( + ctx->device_ctx(), log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr, + alpha_ptr, loss_ptr, input_helper, alpha_helper, batch_size, max_input_length, + max_target_length, blank); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CTC_LOSS_KERNEL(device, dtype, idx_dtype) \ + REGISTER_USER_KERNEL("ctc_loss") \ + .SetCreateFn<CtcLossKernel<device, OF_PP_PAIR_FIRST(dtype), OF_PP_PAIR_FIRST(idx_dtype)>>() \ + .SetIsMatchedHob( \ + (user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("log_probs", 0) == OF_PP_PAIR_SECOND(dtype)) \ + & (user_op::HobDataType("input_lengths", 0) == OF_PP_PAIR_SECOND(idx_dtype))); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CTC_LOSS_KERNEL, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ, + INDEX_DATA_TYPE_SEQ) + +template<DeviceType device_type, typename T, typename IDX> +class CtcLossGradKernel final : public user_op::OpKernel { + public: + CtcLossGradKernel() = default; + ~CtcLossGradKernel() = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* grad_out = ctx->Tensor4ArgNameAndIndex("grad_out", 0); + const user_op::Tensor* loss = ctx->Tensor4ArgNameAndIndex("loss", 0); + const user_op::Tensor* alpha = ctx->Tensor4ArgNameAndIndex("alpha", 0); + const user_op::Tensor* log_probs = ctx->Tensor4ArgNameAndIndex("log_probs", 0); + const user_op::Tensor* targets = ctx->Tensor4ArgNameAndIndex("targets", 0); + const user_op::Tensor* input_lengths = ctx->Tensor4ArgNameAndIndex("input_lengths", 0); + const user_op::Tensor* target_lengths = 
ctx->Tensor4ArgNameAndIndex("target_lengths", 0); + user_op::Tensor* grad = ctx->Tensor4ArgNameAndIndex("grad", 0); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + + const T* grad_out_ptr = grad_out->dptr<T>(); + const T* loss_ptr = loss->dptr<T>(); + const T* alpha_ptr = alpha->dptr<T>(); + const T* log_probs_ptr = log_probs->dptr<T>(); + const int* targets_ptr = targets->dptr<int>(); + const IDX* input_lengths_ptr = input_lengths->dptr<IDX>(); + const IDX* target_lengths_ptr = target_lengths->dptr<IDX>(); + const int blank = ctx->Attr<int>("blank"); + const bool zero_infinity = ctx->Attr<bool>("zero_infinity"); + const int64_t batch_size = log_probs->shape().At(1); + const int64_t num_labels = log_probs->shape().At(2); + CHECK_EQ(batch_size, targets->shape().At(0)); + CHECK_EQ(batch_size, input_lengths->shape().At(0)); + CHECK_EQ(batch_size, target_lengths->shape().At(0)); + CHECK_GE(blank, 0); + CHECK_LT(blank, num_labels); + const int64_t max_input_length = log_probs->shape().At(0); + const int64_t max_target_length = targets->shape().At(1); + NdIndexOffsetHelper<int64_t, 3> input_helper(max_input_length, batch_size, num_labels); + NdIndexOffsetHelper<int64_t, 3> beta_helper(batch_size, max_input_length, + 2 * max_target_length + 1); + T* grad_ptr = grad->mut_dptr<T>(); + T* beta_ptr = tmp_buffer->mut_dptr<T>(); + CtcLossKernelUtil<device_type, T, IDX>::CtcLossBackward( + ctx->device_ctx(), grad_out_ptr, loss_ptr, alpha_ptr, log_probs_ptr, targets_ptr, + input_lengths_ptr, target_lengths_ptr, beta_ptr, grad_ptr, input_helper, beta_helper, + batch_size, max_input_length, max_target_length, num_labels, blank, zero_infinity); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_CTC_LOSS_BACKWARD_KERNEL(device, dtype, idx_dtype) \ + REGISTER_USER_KERNEL("ctc_loss_grad") \ + .SetCreateFn< \ + CtcLossGradKernel<device, OF_PP_PAIR_FIRST(dtype), OF_PP_PAIR_FIRST(idx_dtype)>>() \ + .SetIsMatchedHob( \ + (user_op::HobDeviceTag() == device) \ + & (user_op::HobDataType("log_probs", 0) == OF_PP_PAIR_SECOND(dtype)) \ + & (user_op::HobDataType("input_lengths", 0) == OF_PP_PAIR_SECOND(idx_dtype))) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { \ + const Shape* log_probs_shape = ctx->Shape4ArgNameAndIndex("log_probs", 0); \ + const Shape* targets_shape = ctx->Shape4ArgNameAndIndex("targets", 0); \ + int64_t elem_cnt = \ + log_probs_shape->At(1) * log_probs_shape->At(0) * (2 * targets_shape->At(1) + 1); \ + return elem_cnt * sizeof(OF_PP_PAIR_FIRST(dtype)); \ + }); + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_CTC_LOSS_BACKWARD_KERNEL, DEVICE_TYPE_SEQ, + FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) + +} // namespace oneflow diff --git a/oneflow/user/kernels/ctc_loss_kernel_util.cpp b/oneflow/user/kernels/ctc_loss_kernel_util.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0268880d0712098424fa5ee10baa0b9dc9e7c4c7 --- /dev/null +++ b/oneflow/user/kernels/ctc_loss_kernel_util.cpp @@ -0,0 +1,232 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/user/kernels/ctc_loss_kernel_util.h" + +namespace oneflow { + +int get_target_prime(const int* targets_ptr, int64_t max_target_length, int64_t b, int64_t s, + int blank) { + if (s % 2 == 0) { + return blank; + } else { + int64_t idx = b * max_target_length + s / 2; + return targets_ptr[idx]; + } +} + +template<typename T, typename IDX> +struct CtcLossKernelUtil<DeviceType::kCPU, T, IDX> final { + static void CtcLossForward(DeviceCtx* ctx, const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* alpha_ptr, T* loss_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, + NdIndexOffsetHelper<int64_t, 3>& alpha_helper, + const int64_t batch_size, const int64_t max_input_length, + const int64_t max_target_length, const int blank); + + static void CtcLossBackward(DeviceCtx* ctx, const T* grad_out_ptr, const T* loss_ptr, + const T* alpha_ptr, const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* beta_ptr, T* grad_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, + NdIndexOffsetHelper<int64_t, 3>& beta_helper, + const int64_t batch_size, const int64_t max_input_length, + const int64_t max_target_length, const int64_t num_labels, + const int blank, const bool zero_infinity); +}; + +template<typename T, typename IDX> +void CtcLossKernelUtil<DeviceType::kCPU, T, IDX>::CtcLossForward( + DeviceCtx* ctx, const T* log_probs_ptr, const int* targets_ptr, const IDX* input_lengths_ptr, + const IDX* target_lengths_ptr, T* alpha_ptr, T* loss_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, NdIndexOffsetHelper<int64_t, 3>& alpha_helper, + const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, + const int blank) { + constexpr T neginf = -std::numeric_limits<T>::infinity(); + FOR_RANGE(int64_t, b, 0, batch_size) { + CHECK_GE(max_input_length, input_lengths_ptr[b]); + CHECK_GE(max_target_length, target_lengths_ptr[b]); + } + FOR_RANGE(int32_t, b, 0, batch_size) { + IDX input_length = input_lengths_ptr[b]; + IDX target_length = target_lengths_ptr[b]; + + int64_t alpha_idx = alpha_helper.NdIndexToOffset(b, 0, 0); + for (IDX s = 0; s < 2 * target_length + 1; s++) { alpha_ptr[alpha_idx + s] = neginf; } + alpha_ptr[alpha_idx] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, blank)]; + if (target_length > 0) { + int target = get_target_prime(targets_ptr, max_target_length, b, 1, blank); + alpha_ptr[alpha_idx + 1] = log_probs_ptr[input_helper.NdIndexToOffset(0, b, target)]; + } + + for (IDX t = 1; t < input_length; t++) { + for (IDX s = 0; s < 2 * target_length + 1; s++) { + int current_target_prime = get_target_prime(targets_ptr, max_target_length, b, s, blank); + T la1 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s)]; + T la2, la3, lamax = la1; + if (s > 0) { + la2 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 1)]; + if (la2 > lamax) lamax = la2; + } else { + la2 = neginf; + } + if ((s > 1) + && (get_target_prime(targets_ptr, max_target_length, b, s - 2, blank) + != current_target_prime)) { + la3 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 2)]; + if (la3 > lamax) lamax = la3; + } else { + la3 = neginf; + } + if (lamax == neginf) lamax = 0; + + int64_t idx_t_s = alpha_helper.NdIndexToOffset(b, t, s); + alpha_ptr[idx_t_s] = + std::log(std::exp(la1 - lamax) + std::exp(la2 - lamax) + std::exp(la3 - lamax)) + 
lamax + + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; + } + } + + if (target_length == 0) { + int64_t idx = alpha_helper.NdIndexToOffset(b, input_length - 1, 0); + loss_ptr[b] = -alpha_ptr[idx]; + } else { + int64_t idx1 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2); + int64_t idx2 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2 - 1); + T l1 = alpha_ptr[idx1]; + T l2 = alpha_ptr[idx2]; + T m = std::max(l1, l2); + m = ((m == neginf) ? 0 : m); + T log_likelihood = std::log(std::exp(l1 - m) + std::exp(l2 - m)) + m; + loss_ptr[b] = -log_likelihood; + } + } +} + +template<typename T, typename IDX> +void CtcLossKernelUtil<DeviceType::kCPU, T, IDX>::CtcLossBackward( + DeviceCtx* ctx, const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, + const T* log_probs_ptr, const int* targets_ptr, const IDX* input_lengths_ptr, + const IDX* target_lengths_ptr, T* beta_ptr, T* grad_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, NdIndexOffsetHelper<int64_t, 3>& beta_helper, + const int64_t batch_size, const int64_t max_input_length, const int64_t max_target_length, + const int64_t num_labels, const int blank, const bool zero_infinity) { + constexpr T neginf = -std::numeric_limits<T>::infinity(); + int64_t elem_cnt = max_input_length * batch_size * num_labels; + FOR_RANGE(int64_t, i, 0, elem_cnt) { grad_ptr[i] = neginf; } + + FOR_RANGE(int64_t, b, 0, batch_size) { + IDX input_length = input_lengths_ptr[b]; + IDX target_length = target_lengths_ptr[b]; + T nll = loss_ptr[b]; + if (zero_infinity && nll == std::numeric_limits<T>::infinity()) { + for (IDX t = 0; t < max_input_length; t++) { + for (IDX c = 0; c < num_labels; c++) { + grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = 0; + } + } + continue; + } + + if (input_length > 0) { + int64_t beta_idx = beta_helper.NdIndexToOffset(b, input_length - 1, 0); + for (IDX s = 0; s < 2 * target_length + 1; s++) { beta_ptr[beta_idx + s] = neginf; } + beta_ptr[beta_idx + 2 * target_length] = + log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)]; + grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)] = + alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)] + + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)]; + + if (target_length > 0) { + int target = + get_target_prime(targets_ptr, max_target_length, b, 2 * target_length - 1, blank); + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] = + log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)]; + grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)] = + alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] + + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)]; + } + } + + for (IDX t = input_length - 2; t >= 0; t--) { + for (IDX s = 2 * target_length; s >= 0; s--) { + int current_target_prime = get_target_prime(targets_ptr, max_target_length, b, s, blank); + T lb1 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s)]; + T lb2, lb3, lbmax = lb1; + + if (s < 2 * target_length) { + lb2 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 1)]; + if (lb2 > lbmax) lbmax = lb2; + } else { + lb2 = neginf; + } + + if ((s < 2 * target_length - 1) + && (get_target_prime(targets_ptr, max_target_length, b, s + 2, blank) + != current_target_prime)) { + lb3 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 2)]; + if (lb3 > lbmax) lbmax = 
lb3; + } else { + lb3 = neginf; + } + if (lbmax == neginf) lbmax = 0; + + int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s); + beta_ptr[idx_t_s] = + std::log(std::exp(lb1 - lbmax) + std::exp(lb2 - lbmax) + std::exp(lb3 - lbmax)) + lbmax + + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; + + T log_alpha_beta = alpha_ptr[idx_t_s] + beta_ptr[idx_t_s]; + T& lcab = grad_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; + if (lcab == neginf) { + lcab = log_alpha_beta; + } else { + T m = std::max(lcab, log_alpha_beta); + lcab = std::log(std::exp(lcab - m) + std::exp(log_alpha_beta - m)) + m; + } + } + } + + for (int32_t t = 0; t < input_length; t++) { + for (int64_t c = 0; c < num_labels; c++) { + T& res = grad_ptr[input_helper.NdIndexToOffset(t, b, c)]; + T lp = log_probs_ptr[input_helper.NdIndexToOffset(t, b, c)]; + res = (std::exp(lp) - std::exp(res + nll - lp)) * grad_out_ptr[b]; + } + } + + // zero the remainder + if (input_length < max_input_length) { + for (int64_t t = input_length; t < max_input_length; t++) { + for (int64_t c = 0; c < num_labels; c++) { + int64_t grad_idx = input_helper.NdIndexToOffset(t, b, c); + grad_ptr[grad_idx] = 0; + } + } + } + } +} + +#define INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU(device_type_v, log_probs_dtype_pair, \ + input_lengths_dtype_pair) \ + template struct CtcLossKernelUtil<device_type_v, OF_PP_PAIR_FIRST(log_probs_dtype_pair), \ + OF_PP_PAIR_FIRST(input_lengths_dtype_pair)>; + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU, (DeviceType::kCPU), + FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) +#undef INSTANTIATE_CTC_LOSS_KERNEL_UTIL_CPU + +} // namespace oneflow diff --git a/oneflow/user/kernels/ctc_loss_kernel_util.cu b/oneflow/user/kernels/ctc_loss_kernel_util.cu new file mode 100644 index 0000000000000000000000000000000000000000..03538a623ca7041d73cc0d10063f9c42f4d17bef --- /dev/null +++ b/oneflow/user/kernels/ctc_loss_kernel_util.cu @@ -0,0 +1,267 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/user/kernels/ctc_loss_kernel_util.h" + +namespace oneflow { + +namespace { + +__device__ __inline__ static int get_target_prime(const int* targets_ptr, int64_t max_target_length, + int64_t b, int64_t s, int blank) { + if (s % 2 == 0) { + return blank; + } else { + int64_t idx = b * max_target_length + s / 2; + return targets_ptr[idx]; + } +} + +template<typename T, typename IDX> +__global__ void CtcLossGpu(const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* alpha_ptr, T* loss_ptr, NdIndexOffsetHelper<int64_t, 3> input_helper, + NdIndexOffsetHelper<int64_t, 3> alpha_helper, const int64_t batch_size, + const int64_t max_input_length, const int64_t max_target_length, + const int blank) { + constexpr T neginf = -INFINITY; + const int32_t bid = blockIdx.x; + const int32_t tid = threadIdx.x; + for (int64_t b = bid; b < batch_size; b += gridDim.x) { + if (tid == 0) { + if (input_lengths_ptr[b] > max_input_length) __trap(); + if (target_lengths_ptr[b] > max_target_length) __trap(); + } + } + for (int64_t b = bid; b < batch_size; b += gridDim.x) { + IDX input_length = input_lengths_ptr[b]; + IDX target_length = target_lengths_ptr[b]; + + for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { + alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, s)] = neginf; + } + if (tid == 0) { + alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, 0)] = + log_probs_ptr[input_helper.NdIndexToOffset(0, b, blank)]; + if (target_length > 0) { + int target = get_target_prime(targets_ptr, max_target_length, b, 1, blank); + alpha_ptr[alpha_helper.NdIndexToOffset(b, 0, 1)] = + log_probs_ptr[input_helper.NdIndexToOffset(0, b, target)]; + } + } + __syncthreads(); + for (IDX t = 1; t < input_length; t++) { + for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { + int current_target_prime = get_target_prime(targets_ptr, max_target_length, b, s, blank); + T la1 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s)]; + T la2, la3, lamax = la1; + if (s > 0) { + la2 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 1)]; + if (la2 > lamax) lamax = la2; + } else { + la2 = neginf; + } + if ((s > 1) + && (get_target_prime(targets_ptr, max_target_length, b, s - 2, blank) + != current_target_prime)) { + la3 = alpha_ptr[alpha_helper.NdIndexToOffset(b, t - 1, s - 2)]; + if (la3 > lamax) lamax = la3; + } else { + la3 = neginf; + } + if (lamax == neginf) lamax = 0; + + int64_t idx_t_s = alpha_helper.NdIndexToOffset(b, t, s); + alpha_ptr[idx_t_s] = + log(exp(la1 - lamax) + exp(la2 - lamax) + exp(la3 - lamax)) + lamax + + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; + } + __syncthreads(); + } + if (tid == 0) { + if (target_length == 0) { + int64_t idx = alpha_helper.NdIndexToOffset(b, input_length - 1, 0); + loss_ptr[b] = -alpha_ptr[idx]; + } else { + int64_t idx1 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2); + int64_t idx2 = alpha_helper.NdIndexToOffset(b, input_length - 1, target_length * 2 - 1); + T l1 = alpha_ptr[idx1]; + T l2 = alpha_ptr[idx2]; + T m = max(l1, l2); + m = ((m == neginf) ? 
0 : m); + T log_likelihood = log(exp(l1 - m) + exp(l2 - m)) + m; + loss_ptr[b] = -log_likelihood; + } + } + } +} + +template<typename T, typename IDX> +__global__ void CtcLossGradGpu(const T* grad_out_ptr, const T* loss_ptr, const T* alpha_ptr, + const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* beta_ptr, T* grad_ptr, + NdIndexOffsetHelper<int64_t, 3> input_helper, + NdIndexOffsetHelper<int64_t, 3> beta_helper, + const int64_t batch_size, const int64_t max_input_length, + const int64_t max_target_length, const int64_t num_labels, + const int blank, const bool zero_infinity) { + constexpr T neginf = -INFINITY; + const int32_t bid = blockIdx.x; + const int32_t tid = threadIdx.x; + + for (int64_t b = bid; b < batch_size; b += gridDim.x) { + IDX input_length = input_lengths_ptr[b]; + IDX target_length = target_lengths_ptr[b]; + T nll = loss_ptr[b]; + if (zero_infinity && nll == INFINITY) { + for (IDX t = tid; t < max_input_length; t += blockDim.x) { + for (IDX c = 0; c < num_labels; c++) { + grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = 0; + } + } + __syncthreads(); + continue; + } + + if (input_length > 0) { + for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, s)] = neginf; + } + if (tid == 0) { + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)] = + log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)]; + if (target_length > 0) { + int target = + get_target_prime(targets_ptr, max_target_length, b, 2 * target_length - 1, blank); + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] = + log_probs_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)]; + } + } + __syncthreads(); + } + for (IDX t = input_length - 2; t >= 0; t--) { + for (IDX s = tid; s < 2 * target_length + 1; s += blockDim.x) { + int current_target_prime = get_target_prime(targets_ptr, max_target_length, b, s, blank); + T lb1 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s)]; + T lb2, lb3, lbmax = lb1; + if (s < 2 * target_length) { + lb2 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 1)]; + if (lb2 > lbmax) lbmax = lb2; + } else { + lb2 = neginf; + } + if ((s < 2 * target_length - 1) + && (get_target_prime(targets_ptr, max_target_length, b, s + 2, blank) + != current_target_prime)) { + lb3 = beta_ptr[beta_helper.NdIndexToOffset(b, t + 1, s + 2)]; + if (lb3 > lbmax) lbmax = lb3; + } else { + lb3 = neginf; + } + if (lbmax == neginf) lbmax = 0; + + int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s); + beta_ptr[idx_t_s] = + log(exp(lb1 - lbmax) + exp(lb2 - lbmax) + exp(lb3 - lbmax)) + lbmax + + log_probs_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; + } + __syncthreads(); + } + for (IDX t = tid; t < max_input_length; t += blockDim.x) { + for (IDX c = 0; c < num_labels; c++) { + grad_ptr[input_helper.NdIndexToOffset(t, b, c)] = t < input_length ? 
neginf : 0; + } + } + __syncthreads(); + if (tid == 0) { + grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, blank)] = + alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)] + + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length)]; + if (target_length > 0) { + int target = + get_target_prime(targets_ptr, max_target_length, b, 2 * target_length - 1, blank); + grad_ptr[input_helper.NdIndexToOffset(input_length - 1, b, target)] = + alpha_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)] + + beta_ptr[beta_helper.NdIndexToOffset(b, input_length - 1, 2 * target_length - 1)]; + } + } + __syncthreads(); + for (IDX t = tid; t < input_length; t += blockDim.x) { + for (IDX s = 0; (t < input_length - 1) && (s < 2 * target_length + 1); s += 1) { + int current_target_prime = get_target_prime(targets_ptr, max_target_length, b, s, blank); + int64_t idx_t_s = beta_helper.NdIndexToOffset(b, t, s); + T log_alpha_beta = alpha_ptr[idx_t_s] + beta_ptr[idx_t_s]; + T& lcab = grad_ptr[input_helper.NdIndexToOffset(t, b, current_target_prime)]; + if (lcab == neginf) { + lcab = log_alpha_beta; + } else { + T m = max(lcab, log_alpha_beta); + lcab = log(exp(lcab - m) + exp(log_alpha_beta - m)) + m; + } + } + for (int32_t c = 0; c < num_labels; c++) { + T& res = grad_ptr[input_helper.NdIndexToOffset(t, b, c)]; + T lp = log_probs_ptr[input_helper.NdIndexToOffset(t, b, c)]; + res = (exp(lp) - exp(res + nll - lp)) * grad_out_ptr[b]; + } + } + } +} + +} // namespace + +template<typename T, typename IDX> +struct CtcLossKernelUtil<DeviceType::kGPU, T, IDX> { + static void CtcLossForward(DeviceCtx* ctx, const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* alpha_ptr, T* loss_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, + NdIndexOffsetHelper<int64_t, 3>& alpha_helper, + const int64_t batch_size, const int64_t max_input_length, + const int64_t max_target_length, const int blank) { + int32_t thread_num = batch_size * kCudaThreadsNumPerBlock; + RUN_CUDA_KERNEL((CtcLossGpu<T, IDX>), ctx, thread_num, log_probs_ptr, targets_ptr, + input_lengths_ptr, target_lengths_ptr, alpha_ptr, loss_ptr, input_helper, + alpha_helper, batch_size, max_input_length, max_target_length, blank); + } + + static void CtcLossBackward(DeviceCtx* ctx, const T* grad_out_ptr, const T* loss_ptr, + const T* alpha_ptr, const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* beta_ptr, T* grad_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, + NdIndexOffsetHelper<int64_t, 3>& beta_helper, + const int64_t batch_size, const int64_t max_input_length, + const int64_t max_target_length, const int64_t num_labels, + const int blank, const bool zero_infinity) { + int32_t thread_num = batch_size * kCudaThreadsNumPerBlock; + RUN_CUDA_KERNEL((CtcLossGradGpu<T, IDX>), ctx, thread_num, grad_out_ptr, loss_ptr, alpha_ptr, + log_probs_ptr, targets_ptr, input_lengths_ptr, target_lengths_ptr, beta_ptr, + grad_ptr, input_helper, beta_helper, batch_size, max_input_length, + max_target_length, num_labels, blank, zero_infinity); + } +}; + +#define INSTANTIATE_CTC_LOSS_KERNEL_UTIL_GPU(device_type_v, log_probs_dtype_pair, \ + input_lengths_dtype_pair) \ + template struct CtcLossKernelUtil<device_type_v, OF_PP_PAIR_FIRST(log_probs_dtype_pair), \ + OF_PP_PAIR_FIRST(input_lengths_dtype_pair)>; + +OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_CTC_LOSS_KERNEL_UTIL_GPU, 
(DeviceType::kGPU), + FLOATING_DATA_TYPE_SEQ, INDEX_DATA_TYPE_SEQ) +#undef INSTANTIATE_CTC_LOSS_KERNEL_UTIL_GPU + +} // namespace oneflow diff --git a/oneflow/user/kernels/ctc_loss_kernel_util.h b/oneflow/user/kernels/ctc_loss_kernel_util.h new file mode 100644 index 0000000000000000000000000000000000000000..279ce3fb5d0813511d7b933efdfeaf168fe15a74 --- /dev/null +++ b/oneflow/user/kernels/ctc_loss_kernel_util.h @@ -0,0 +1,47 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_ +#define ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_ + +#include "oneflow/core/device/device_context.h" +#include "oneflow/core/common/nd_index_offset_helper.h" + +namespace oneflow { + +template<DeviceType device_type, typename T, typename IDX> +struct CtcLossKernelUtil final { + static void CtcLossForward(DeviceCtx* ctx, const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* alpha_ptr, T* loss_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, + NdIndexOffsetHelper<int64_t, 3>& alpha_helper, + const int64_t batch_size, const int64_t max_input_length, + const int64_t max_target_length, const int blank); + + static void CtcLossBackward(DeviceCtx* ctx, const T* grad_out_ptr, const T* loss_ptr, + const T* alpha_ptr, const T* log_probs_ptr, const int* targets_ptr, + const IDX* input_lengths_ptr, const IDX* target_lengths_ptr, + T* beta_ptr, T* grad_ptr, + NdIndexOffsetHelper<int64_t, 3>& input_helper, + NdIndexOffsetHelper<int64_t, 3>& beta_helper, + const int64_t batch_size, const int64_t max_input_length, + const int64_t max_target_length, const int64_t num_labels, + const int blank, const bool zero_infinity); +}; + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNELS_CTC_LOSS_KERNEL_UTIL_H_ diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bd2881c33e4b421f7bd31684bc7e4356273d8d8a --- /dev/null +++ b/oneflow/user/ops/ctc_loss_op.cpp @@ -0,0 +1,134 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +REGISTER_USER_OP("ctc_loss") + .Input("log_probs") + .Input("targets") + .Input("input_lengths") + .Input("target_lengths") + .Output("loss") + .Output("alpha") // 'alpha' is just for compute log_probs's grad, alpha's grad will be ignored + .Attr<int>("blank") + .Attr<bool>("zero_infinity") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + const user_op::TensorDesc* log_probs = ctx->TensorDesc4ArgNameAndIndex("log_probs", 0); + const user_op::TensorDesc* targets = ctx->TensorDesc4ArgNameAndIndex("targets", 0); + const user_op::TensorDesc* input_lengths = + ctx->TensorDesc4ArgNameAndIndex("input_lengths", 0); + const user_op::TensorDesc* target_lengths = + ctx->TensorDesc4ArgNameAndIndex("target_lengths", 0); + const int64_t batch_size = log_probs->shape().At(1); + CHECK_EQ_OR_RETURN(batch_size, targets->shape().At(0)); + CHECK_EQ_OR_RETURN(batch_size, input_lengths->shape().At(0)); + CHECK_EQ_OR_RETURN(batch_size, target_lengths->shape().At(0)); + CHECK_GE_OR_RETURN(ctx->Attr<int>("blank"), 0); + *ctx->Dtype4ArgNameAndIndex("loss", 0) = *ctx->Dtype4ArgNameAndIndex("log_probs", 0); + *ctx->Shape4ArgNameAndIndex("loss", 0) = Shape({batch_size}); + *ctx->Dtype4ArgNameAndIndex("alpha", 0) = *ctx->Dtype4ArgNameAndIndex("log_probs", 0); + *ctx->Shape4ArgNameAndIndex("alpha", 0) = + Shape({batch_size, log_probs->shape().At(0), 2 * targets->shape().At(1) + 1}); + return Maybe<void>::Ok(); + }) + .SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> { + *ctx->BatchAxis4ArgNameAndIndex("loss", 0) = + *ctx->BatchAxis4ArgNameAndIndex("input_lengths", 0); + *ctx->BatchAxis4ArgNameAndIndex("alpha", 0) = + *ctx->BatchAxis4ArgNameAndIndex("input_lengths", 0); + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + ctx->NewBuilder() + .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 + .Split(user_op::OpArg("targets", 0), 0) + .Split(user_op::OpArg("input_lengths", 0), 0) + .Split(user_op::OpArg("target_lengths", 0), 0) + .Split(user_op::OpArg("loss", 0), 0) + .Split(user_op::OpArg("alpha", 0), 0) + .Build(); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP("ctc_loss_grad") + .Input("grad_out") + .Input("log_probs") + .Input("targets") + .Input("input_lengths") + .Input("target_lengths") + .Input("loss") + .Input("alpha") + .Output("grad") + .Attr<int>("blank") + .Attr<bool>("zero_infinity") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { + const user_op::TensorDesc* log_probs = ctx->TensorDesc4ArgNameAndIndex("log_probs", 0); + const user_op::TensorDesc* targets = ctx->TensorDesc4ArgNameAndIndex("targets", 0); + const user_op::TensorDesc* input_lengths = + ctx->TensorDesc4ArgNameAndIndex("input_lengths", 0); + const user_op::TensorDesc* target_lengths = + ctx->TensorDesc4ArgNameAndIndex("target_lengths", 0); + const int64_t batch_size = log_probs->shape().At(1); + CHECK_EQ_OR_RETURN(batch_size, targets->shape().At(0)); + CHECK_EQ_OR_RETURN(batch_size, input_lengths->shape().At(0)); + CHECK_EQ_OR_RETURN(batch_size, target_lengths->shape().At(0)); + CHECK_GE_OR_RETURN(ctx->Attr<int>("blank"), 0); + *ctx->Dtype4ArgNameAndIndex("grad", 0) = *ctx->Dtype4ArgNameAndIndex("log_probs", 0); + *ctx->Shape4ArgNameAndIndex("grad", 0) = log_probs->shape(); + return Maybe<void>::Ok(); + }) + .SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe<void> { + *ctx->BatchAxis4ArgNameAndIndex("grad", 0) = 
*ctx->BatchAxis4ArgNameAndIndex("log_probs", 0); + return Maybe<void>::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { + ctx->NewBuilder() + .Split(user_op::OpArg("grad_out", 0), 0) + .Split(user_op::OpArg("log_probs", 0), 1) // `log_probs` batch axis is 1 + .Split(user_op::OpArg("targets", 0), 0) + .Split(user_op::OpArg("input_lengths", 0), 0) + .Split(user_op::OpArg("target_lengths", 0), 0) + .Split(user_op::OpArg("loss", 0), 0) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("grad", 0), 1) + .Build(); + return Maybe<void>::Ok(); + }); + +REGISTER_USER_OP_GRAD("ctc_loss").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { + const auto ctc_loss_grad_op_name = ctx->FwOp().op_name() + "_grad"; + ctx->DefineOp(ctc_loss_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { + return builder.OpTypeName("ctc_loss_grad") + .InputBind("grad_out", ctx->FwOp().output_grad("loss", 0)) + .InputBind("log_probs", ctx->FwOp().input("log_probs", 0)) + .InputBind("targets", ctx->FwOp().input("targets", 0)) + .InputBind("input_lengths", ctx->FwOp().input("input_lengths", 0)) + .InputBind("target_lengths", ctx->FwOp().input("target_lengths", 0)) + .InputBind("loss", ctx->FwOp().output("loss", 0)) + .InputBind("alpha", ctx->FwOp().output("alpha", 0)) + .Attr("blank", ctx->FwOp().attr<int>("blank")) + .Attr("zero_infinity", ctx->FwOp().attr<bool>("zero_infinity")) + .Output("grad") + .Build(); + }); + ctx->FwOp().InputGradBind(user_op::OpArg("log_probs", 0), + [&ctx, &ctc_loss_grad_op_name]() -> const std::string& { + return ctx->GetOp(ctc_loss_grad_op_name).output("grad", 0); + }); +}); + +} // namespace oneflow
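For reference, the "mean" and "sum" reductions described in the ctc_loss docstring can be checked with a small standalone NumPy sketch. This is only an illustration of the documented formula, applied to the per-sample losses printed in the reduction="none" docstring example; it is not code from the patch, and the expected values are computed here rather than taken from it.

import numpy as np

loss = np.array([3.918017, 2.907672])  # per-sample losses from the reduction="none" example
target_lengths = np.array([3, 3])

# "mean": divide each loss by its (clamped) target length, then average over the batch,
# mirroring the clip_by_value / xdivy / reduce_mean chain in the Python wrapper.
mean_loss = np.mean(loss / np.clip(target_lengths, 1, None))  # ~1.1376

# "sum": add up the per-sample losses.
sum_loss = np.sum(loss)  # ~6.8257

Clamping the target lengths to a minimum of 1 matches the wrapper's clip_by_value(min_value=1), which avoids dividing by zero when a target sequence has length 0.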