diff --git a/research/cv/pointnet/train.py b/research/cv/pointnet/train.py index dcb36868dedab387f60d4c922d842b1fc3045227..35472fbe1725f6fb3f38f8ce62a250d58873b645 100644 --- a/research/cv/pointnet/train.py +++ b/research/cv/pointnet/train.py @@ -187,6 +187,11 @@ if __name__ == "__main__": num_classes = dataset_generator.num_seg_classes classifier = PointNetDenseCls(k=num_classes, feature_transform=args.feature_transform) classifier.set_train(True) + if context.get_context("device_target") == "Ascend": + classifier.to_float(mindspore.float16) + for _, cell in classifier.cells_and_names(): + if isinstance(cell, nn.LogSoftmax): + cell.to_float(mindspore.float32) num_batch = math.ceil(len(dataset_generator) / args.batchSize)