Skip to content
Snippets Groups Projects
Commit 10c99104 authored by jialing's avatar jialing
Browse files

performance of pointnet

parent 9d80a7fe
No related branches found
No related tags found
No related merge requests found
......@@ -187,6 +187,11 @@ if __name__ == "__main__":
num_classes = dataset_generator.num_seg_classes
classifier = PointNetDenseCls(k=num_classes, feature_transform=args.feature_transform)
classifier.set_train(True)
if context.get_context("device_target") == "Ascend":
classifier.to_float(mindspore.float16)
for _, cell in classifier.cells_and_names():
if isinstance(cell, nn.LogSoftmax):
cell.to_float(mindspore.float32)
num_batch = math.ceil(len(dataset_generator) / args.batchSize)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment