Debug dim scatter (#5371)
* startup of dev scatter ops * use dim scatter base class * refine(using binop to abstract scatter update and add * refine (use macros to implement kerenl class and functors) * refine(description for register scatter ops/kernels) * refine * add inplace ops * python wraper scatter_add inplace * dev inplace ops * refine dim_gather (using macros register mechanism) * add grad of scatter_add_like * refine (add src, like versions for scatter) * refine src/like tensor * gather refine(no need outplace/inplace versions) * reformat * refine * test case of dim scatter * test case for dim_scatter_add_like * 1n2d test case for dim_scatter_add_like * refine scatter sbp * fail to sccater_add_like on 1n2d * refing sbp * refine test case, unify add and update like ops * test case for scatter_add/update like ops finished * test cases for scatter ops * refine, merge test class * startup of api docs * add scatter api docs and assertion in python * fix make error but still segment fault * annotate sbp infer * rewrite scatter kernel logic * remove inplace proposal and fix macro name * remove outdated atomic add * move sbp infer * add const and throw error * add check * set grad op * add scatter scalar * add scatter scalar gpu kernel * add torch style backprop * add torch style backprop check * align with master * remove redundant sbp check * add test * add float16n register * fix sbp * fix sbp * add api doc * make format * add new line * refine * revert dim gather * extract dim_scatter_add * extracat scatter update ops * add add/update functor * rewrting by functors * refine * remove dim_gather_scatter_uitl.h * add blank line * refine macros for registering kerenls * refine dim_scatter_scalar files name * refine * refine register ops * refine * add F.dim_scatter_scalar * add scatter op * refine docstr * add scatter reduce arg * finally(!): a draft for scatter constitent with pytroch * change import package name * remmove lazy test and add scatter_add and scatter_mul * startup of scatter backward op * add backward for scatter * scatter ops backward finished * add scatter, scatter_add test cases * remove useless scatter_update_like * reformat * refine test cases * refine according to comments * revert op_exprt_helper * fixed index element * fix scatter update like expr for dim gather backward Co-authored-by:doombeaker <later@usopp.net> Co-authored-by:
oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Showing
- docs/source/oneflow.rst 2 additions, 0 deletionsdocs/source/oneflow.rst
- oneflow/core/autograd/gradient_funcs/dim_gather.cpp 1 addition, 1 deletiononeflow/core/autograd/gradient_funcs/dim_gather.cpp
- oneflow/core/autograd/gradient_funcs/dim_scatter.cpp 176 additions, 0 deletionsoneflow/core/autograd/gradient_funcs/dim_scatter.cpp
- oneflow/core/framework/op_expr_helper.cpp 1 addition, 1 deletiononeflow/core/framework/op_expr_helper.cpp
- oneflow/core/functional/functional_api.yaml 16 additions, 0 deletionsoneflow/core/functional/functional_api.yaml
- oneflow/core/functional/impl/array_functor.cpp 138 additions, 0 deletionsoneflow/core/functional/impl/array_functor.cpp
- oneflow/user/kernels/dim_gather_kernel_util.cpp 0 additions, 13 deletionsoneflow/user/kernels/dim_gather_kernel_util.cpp
- oneflow/user/kernels/dim_gather_kernel_util.cu 0 additions, 34 deletionsoneflow/user/kernels/dim_gather_kernel_util.cu
- oneflow/user/kernels/dim_gather_kernel_util.h 0 additions, 4 deletionsoneflow/user/kernels/dim_gather_kernel_util.h
- oneflow/user/kernels/dim_gather_kernels.cpp 0 additions, 57 deletionsoneflow/user/kernels/dim_gather_kernels.cpp
- oneflow/user/kernels/dim_scatter_kernel_util.cpp 39 additions, 0 deletionsoneflow/user/kernels/dim_scatter_kernel_util.cpp
- oneflow/user/kernels/dim_scatter_kernel_util.cu 66 additions, 0 deletionsoneflow/user/kernels/dim_scatter_kernel_util.cu
- oneflow/user/kernels/dim_scatter_kernel_util.h 100 additions, 0 deletionsoneflow/user/kernels/dim_scatter_kernel_util.h
- oneflow/user/kernels/dim_scatter_kernels.cpp 138 additions, 0 deletionsoneflow/user/kernels/dim_scatter_kernels.cpp
- oneflow/user/kernels/dim_scatter_scalar_kernel_util.cpp 37 additions, 0 deletionsoneflow/user/kernels/dim_scatter_scalar_kernel_util.cpp
- oneflow/user/kernels/dim_scatter_scalar_kernel_util.cu 50 additions, 0 deletionsoneflow/user/kernels/dim_scatter_scalar_kernel_util.cu
- oneflow/user/kernels/dim_scatter_scalar_kernel_util.h 97 additions, 0 deletionsoneflow/user/kernels/dim_scatter_scalar_kernel_util.h
- oneflow/user/kernels/dim_scatter_scalar_kernels.cpp 101 additions, 0 deletionsoneflow/user/kernels/dim_scatter_scalar_kernels.cpp
- oneflow/user/ops/dim_gather_op.cpp 4 additions, 93 deletionsoneflow/user/ops/dim_gather_op.cpp
- oneflow/user/ops/dim_scatter_ops.cpp 296 additions, 0 deletionsoneflow/user/ops/dim_scatter_ops.cpp
Please register or sign in to comment