From 5af1c835339d739d4a83417e543b96c41e207f67 Mon Sep 17 00:00:00 2001 From: Yinggang Wang <wyg19970408@gmail.com> Date: Wed, 21 Jul 2021 03:36:26 -0500 Subject: [PATCH] Remove inplace broadcast_add (#5551) * fix(*): remove inplace broadcast_add * fix(BroadcastLike): fix axes bug --- oneflow/core/functional/functional_api.yaml | 2 +- .../core/functional/impl/binary_functor.cpp | 2 +- oneflow/python/nn/modules/broadcast_like.py | 28 +++++++++++++++++-- oneflow/python/nn/modules/math_ops.py | 10 +++---- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index f24b6a5dc..9e3e7920b 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -36,7 +36,7 @@ bind_python: True - name: "broadcast_add" - signature: "Tensor BroadcastAdd(Tensor x, Tensor y, *, Bool inplace=False)" + signature: "Tensor BroadcastAdd(Tensor x, Tensor y)" bind_python: True - name: "sub_scalar_by_tensor" diff --git a/oneflow/core/functional/impl/binary_functor.cpp b/oneflow/core/functional/impl/binary_functor.cpp index 31606d46d..b8640de3a 100644 --- a/oneflow/core/functional/impl/binary_functor.cpp +++ b/oneflow/core/functional/impl/binary_functor.cpp @@ -50,7 +50,7 @@ class PowFunctor : public BinaryFunctor { } }; -class BroadcastAddFunctor : public InplaceableBinaryFunctor { +class BroadcastAddFunctor : public BinaryFunctor { public: BroadcastAddFunctor() { op_ = CHECK_JUST(one::OpBuilder("broadcast_add").Input("x").Input("y").Output("z").Build()); diff --git a/oneflow/python/nn/modules/broadcast_like.py b/oneflow/python/nn/modules/broadcast_like.py index c24846d33..4a59a2ec3 100644 --- a/oneflow/python/nn/modules/broadcast_like.py +++ b/oneflow/python/nn/modules/broadcast_like.py @@ -13,21 +13,43 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ + +from typing import Optional, Sequence + import oneflow as flow from oneflow.python.nn.module import Module from oneflow.python.oneflow_export import oneflow_export, experimental_api +def _calc_broadcast_axes(x, like_tensor): + num_prepend = len(like_tensor.shape) - len(x.shape) + prepend_shape = [1] * num_prepend + list(x.shape) + broadcast_axes = [x for x in range(num_prepend)] + for i in range(num_prepend, len(prepend_shape)): + if prepend_shape[i] != like_tensor.shape[i]: + if prepend_shape[i] != 1: + raise RuntimeError( + f"output with shape {x.shape} doesn't match the broadcast shape {like_tensor.shape}" + ) + else: + broadcast_axes.append(i) + return tuple(broadcast_axes) + + class BroadCastLike(Module): - def __init__(self, broadcast_axes: None) -> None: + def __init__(self, broadcast_axes: Optional[Sequence] = None) -> None: super().__init__() self.broadcast_axes = broadcast_axes def forward(self, x, like_tensor): - return flow.F.broadcast_like(x, like_tensor, broadcast_axes=self.broadcast_axes) + if self.broadcast_axes is None: + broadcast_axes = _calc_broadcast_axes(x, like_tensor) + else: + broadcast_axes = self.broadcast_axes + return flow.F.broadcast_like(x, like_tensor, broadcast_axes=broadcast_axes) @oneflow_export("broadcast_like") @experimental_api -def broadcast_like_op(x, like_tensor, broadcast_axes: None): +def broadcast_like_op(x, like_tensor, broadcast_axes: Optional[Sequence] = None): return BroadCastLike(broadcast_axes=broadcast_axes)(x, like_tensor) diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py index b7ad7edeb..b6acf8b53 100644 --- a/oneflow/python/nn/modules/math_ops.py +++ b/oneflow/python/nn/modules/math_ops.py @@ -391,14 +391,11 @@ class ElementwiseAdd(Module): class BroadcastAdd(Module): - def __init__(self, inplace: bool = False) -> None: + def __init__(self) -> None: super().__init__() - self.inplace = inplace def forward(self, x, y): - if self.inplace: - _check_inplace_valid(x) - return flow.F.broadcast_add(x, y, self.inplace) + return flow.F.broadcast_add(x, y) @oneflow_export("add") @@ -474,7 +471,8 @@ def _add_inplace(x, y): elif y.shape == (1,): return ScalarAddByTensor(inplace=True)(x, y) else: - return BroadcastAdd(inplace=True)(x, y) + y = flow.experimental.broadcast_like(y, x) + return ElementwiseAdd(inplace=True)(x, y) class Asin(Module): -- GitLab