diff --git a/research/cv/pointnet/train.py b/research/cv/pointnet/train.py
index dcb36868dedab387f60d4c922d842b1fc3045227..35472fbe1725f6fb3f38f8ce62a250d58873b645 100644
--- a/research/cv/pointnet/train.py
+++ b/research/cv/pointnet/train.py
@@ -187,6 +187,11 @@ if __name__ == "__main__":
num_classes = dataset_generator.num_seg_classes
classifier = PointNetDenseCls(k=num_classes, feature_transform=args.feature_transform)
classifier.set_train(True)
+ if context.get_context("device_target") == "Ascend":
+ classifier.to_float(mindspore.float16)
+ for _, cell in classifier.cells_and_names():
+ if isinstance(cell, nn.LogSoftmax):
+ cell.to_float(mindspore.float32)
num_batch = math.ceil(len(dataset_generator) / args.batchSize)