diff --git a/official/nlp/tinybert/src/tinybert_for_gd_td.py b/official/nlp/tinybert/src/tinybert_for_gd_td.py index 076fd919f45bdf07259d729ea726bbd585585fb6..35319205ca658589ac0729dc5227ecd60390eaca 100644 --- a/official/nlp/tinybert/src/tinybert_for_gd_td.py +++ b/official/nlp/tinybert/src/tinybert_for_gd_td.py @@ -246,7 +246,7 @@ class BertTrainWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) - return (loss, cond, scaling_sens) + return (loss, cond, scaling_sens.value()) class BertTrainCell(nn.Cell): """ @@ -470,7 +470,7 @@ class BertEvaluationWithLossScaleCell(nn.TrainOneStepWithLossScaleCell): overflow = self.process_loss_scale(cond) if not overflow: self.optimizer(grads) - return (loss, cond, scaling_sens) + return (loss, cond, scaling_sens.value()) class BertEvaluationCell(nn.Cell):