diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml
index f24b6a5dc654af47e9f6a1371318b63c17c75433..9e3e7920bb237c567518dd6371abdbb3e6323297 100644
--- a/oneflow/core/functional/functional_api.yaml
+++ b/oneflow/core/functional/functional_api.yaml
@@ -36,7 +36,7 @@
   bind_python: True
 
 - name: "broadcast_add"
-  signature: "Tensor BroadcastAdd(Tensor x, Tensor y, *, Bool inplace=False)"
+  signature: "Tensor BroadcastAdd(Tensor x, Tensor y)"
   bind_python: True
 
 - name: "sub_scalar_by_tensor"
diff --git a/oneflow/core/functional/impl/binary_functor.cpp b/oneflow/core/functional/impl/binary_functor.cpp
index 31606d46d224a2c32e18adbc1fb2ca94459ee26f..b8640de3a887ab8f56dc4c24c0e029402038ef13 100644
--- a/oneflow/core/functional/impl/binary_functor.cpp
+++ b/oneflow/core/functional/impl/binary_functor.cpp
@@ -50,7 +50,7 @@ class PowFunctor : public BinaryFunctor {
   }
 };
 
-class BroadcastAddFunctor : public InplaceableBinaryFunctor {
+class BroadcastAddFunctor : public BinaryFunctor {
  public:
   BroadcastAddFunctor() {
     op_ = CHECK_JUST(one::OpBuilder("broadcast_add").Input("x").Input("y").Output("z").Build());
diff --git a/oneflow/python/nn/modules/broadcast_like.py b/oneflow/python/nn/modules/broadcast_like.py
index c24846d3337aa23145297183c9e015df9f6486b6..4a59a2ec316d6f9ae3bc1f45d1d16f067f08e1f6 100644
--- a/oneflow/python/nn/modules/broadcast_like.py
+++ b/oneflow/python/nn/modules/broadcast_like.py
@@ -13,21 +13,60 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
+from typing import Optional, Sequence
+
 import oneflow as flow
 from oneflow.python.nn.module import Module
 from oneflow.python.oneflow_export import oneflow_export, experimental_api
 
 
+def _calc_broadcast_axes(x, like_tensor):
+    """Return the axes of ``like_tensor`` along which ``x`` must be broadcast.
+
+    The shape of ``x`` is right-aligned with the shape of ``like_tensor``; each
+    missing leading axis, and each aligned axis where ``x`` has size 1 but
+    ``like_tensor`` does not, is a broadcast axis.
+
+    Raises:
+        RuntimeError: if ``x`` has more dimensions than ``like_tensor``, or an
+            aligned axis of ``x`` is neither 1 nor equal to the target size.
+    """
+    num_prepend = len(like_tensor.shape) - len(x.shape)
+    if num_prepend < 0:
+        # A negative offset would make the loop below start at a negative index
+        # and run past the end of like_tensor.shape.
+        raise RuntimeError(
+            f"output with shape {x.shape} doesn't match the broadcast shape {like_tensor.shape}"
+        )
+    prepend_shape = [1] * num_prepend + list(x.shape)
+    # Every prepended axis is a broadcast axis by construction.
+    broadcast_axes = list(range(num_prepend))
+    for i in range(num_prepend, len(prepend_shape)):
+        if prepend_shape[i] != like_tensor.shape[i]:
+            if prepend_shape[i] != 1:
+                raise RuntimeError(
+                    f"output with shape {x.shape} doesn't match the broadcast shape {like_tensor.shape}"
+                )
+            else:
+                broadcast_axes.append(i)
+    return tuple(broadcast_axes)
+
+
 class BroadCastLike(Module):
-    def __init__(self, broadcast_axes: None) -> None:
+    def __init__(self, broadcast_axes: Optional[Sequence] = None) -> None:
         super().__init__()
         self.broadcast_axes = broadcast_axes
 
     def forward(self, x, like_tensor):
-        return flow.F.broadcast_like(x, like_tensor, broadcast_axes=self.broadcast_axes)
+        if self.broadcast_axes is None:
+            broadcast_axes = _calc_broadcast_axes(x, like_tensor)
+        else:
+            broadcast_axes = self.broadcast_axes
+        return flow.F.broadcast_like(x, like_tensor, broadcast_axes=broadcast_axes)
 
 
 @oneflow_export("broadcast_like")
 @experimental_api
-def broadcast_like_op(x, like_tensor, broadcast_axes: None):
+def broadcast_like_op(x, like_tensor, broadcast_axes: Optional[Sequence] = None):
     return BroadCastLike(broadcast_axes=broadcast_axes)(x, like_tensor)
diff --git a/oneflow/python/nn/modules/math_ops.py b/oneflow/python/nn/modules/math_ops.py
index b7ad7edeb4539c9a1f013ee6ef42094080d05b0b..b6acf8b536faa94b0ebed2e753fe0e7d22c2d50b 100644
--- a/oneflow/python/nn/modules/math_ops.py
+++ b/oneflow/python/nn/modules/math_ops.py
@@ -391,14 +391,11 @@ class ElementwiseAdd(Module):
 
 
 class BroadcastAdd(Module):
-    def __init__(self, inplace: bool = False) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.inplace = inplace
 
     def forward(self, x, y):
-        if self.inplace:
-            _check_inplace_valid(x)
-        return flow.F.broadcast_add(x, y, self.inplace)
+        return flow.F.broadcast_add(x, y)
 
 
 @oneflow_export("add")
@@ -474,7 +471,8 @@ def _add_inplace(x, y):
     elif y.shape == (1,):
         return ScalarAddByTensor(inplace=True)(x, y)
     else:
-        return BroadcastAdd(inplace=True)(x, y)
+        y = flow.experimental.broadcast_like(y, x)
+        return ElementwiseAdd(inplace=True)(x, y)
 
 
 class Asin(Module):