Skip to content
Snippets Groups Projects
Unverified Commit ce32d64b authored by zhaoting's avatar zhaoting Committed by Gitee
Browse files

!3312 [西安交通大学][高校贡献][Mindspore][pointnet性能提升]

Merge pull request !3312 from jialing/master
parents aca98aac 10c99104
No related branches found
No related tags found
No related merge requests found
......@@ -187,6 +187,11 @@ if __name__ == "__main__":
num_classes = dataset_generator.num_seg_classes
classifier = PointNetDenseCls(k=num_classes, feature_transform=args.feature_transform)
classifier.set_train(True)
if context.get_context("device_target") == "Ascend":
classifier.to_float(mindspore.float16)
for _, cell in classifier.cells_and_names():
if isinstance(cell, nn.LogSoftmax):
cell.to_float(mindspore.float32)
num_batch = math.ceil(len(dataset_generator) / args.batchSize)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment