diff --git a/research/cv/pointnet/train.py b/research/cv/pointnet/train.py
index dcb36868dedab387f60d4c922d842b1fc3045227..35472fbe1725f6fb3f38f8ce62a250d58873b645 100644
--- a/research/cv/pointnet/train.py
+++ b/research/cv/pointnet/train.py
@@ -187,6 +187,11 @@ if __name__ == "__main__":
     num_classes = dataset_generator.num_seg_classes
     classifier = PointNetDenseCls(k=num_classes, feature_transform=args.feature_transform)
     classifier.set_train(True)
+    if context.get_context("device_target") == "Ascend":
+        classifier.to_float(mindspore.float16)
+        for _, cell in classifier.cells_and_names():
+            if isinstance(cell, nn.LogSoftmax):
+                cell.to_float(mindspore.float32)
 
     num_batch = math.ceil(len(dataset_generator) / args.batchSize)