diff --git a/official/cv/deeptext/train.py b/official/cv/deeptext/train.py index 7142765fb7469de771b8ed5aebbd2298f33b3c95..ea8ab046a58ea2857738613bd4e9c1d8cc9fb6da 100644 --- a/official/cv/deeptext/train.py +++ b/official/cv/deeptext/train.py @@ -168,6 +168,8 @@ def run_train(): net.to_float(mstype.float16) loss = LossNet() + if device_type == "Ascend": + loss.to_float(mstype.float32) lr = Tensor(dynamic_lr(config, rank_size=device_num), mstype.float32) opt = Momentum(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum,