Skip to content
Snippets Groups Projects
Unverified Commit 5af7d1be authored by ZZK's avatar ZZK Committed by GitHub
Browse files

fix bug about pos_weight (#4768)


Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent 92337efb
No related branches found
No related tags found
No related merge requests found
......@@ -3848,13 +3848,6 @@ def bce_with_logits_loss(
reduction
)
assert pos_weight.shape[0] == input.shape[-1], (
"The length of `pos_weight` must be equal to the number of classes. "
"Found the length of pos_weight {} vs classes {}".format(
pos_weight.shape[0], input.shape[-1]
)
)
if name is None:
name = id_util.UniqueStr("BCEWithLogitsLoss")
......@@ -3863,6 +3856,12 @@ def bce_with_logits_loss(
_neg_max_val = flow.math.negative(_max_val)
if pos_weight:
assert pos_weight.shape[0] == input.shape[-1], (
"The length of `pos_weight` must be equal to the number of classes. "
"Found the length of pos_weight {} vs classes {}".format(
pos_weight.shape[0], input.shape[-1]
)
)
_log_weight = ((pos_weight - 1) * target) + 1
_loss = (1 - target) * input + _log_weight * (
flow.math.log(
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment