diff --git a/main.py b/main.py index 5cb2a28e27f31a6aa2de42d200ee3e6b25e9c39f..59e7c5a703628e465531f8f97e176a0f17df0d63 100644 --- a/main.py +++ b/main.py @@ -19,6 +19,7 @@ from dataset import GetDatasetGenerator from loss import SoftmaxCrossEntropyLoss from learning_rates import exponential_lr, cosine_lr + context.set_context(mode=context.PYNATIVE_MODE, save_graphs=False, device_target='Ascend', device_id=7) # net = UNetMedical(n_channels=3, n_classes=6) @@ -92,7 +93,8 @@ def eval_batch(eval_net, img_lst, crop_size=513, flip=True): resize_hw = [] for l in range(batch_size): img_ = img_lst[l] - img_, resize_h, resize_w = pre_process(img_, crop_size) + resize_h, resize_w, _ = img_.shape + # img_, resize_h, resize_w = pre_process(img_, crop_size) batch_img[l] = img_ resize_hw.append([resize_h, resize_w])