diff --git a/official/cv/unet/train.py b/official/cv/unet/train.py index 3fe67f67561696b97c24123db9f1486133e5c2b5..472c52c8da027419143e3fc5d093c28d649508bc 100644 --- a/official/cv/unet/train.py +++ b/official/cv/unet/train.py @@ -97,12 +97,12 @@ def train_net(cross_valid_ind=1, run_distribute, config.crop, config.image_size) train_data_size = train_dataset.get_dataset_size() print("dataset length is:", train_data_size) - ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path) + ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path, f'ckpt_{rank}') save_ck_steps = train_data_size * epochs ckpt_config = CheckpointConfig(save_checkpoint_steps=save_ck_steps, keep_checkpoint_max=config.keep_checkpoint_max) ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(config.model_name), - directory=ckpt_save_dir+'./ckpt_{}/'.format(rank), + directory=ckpt_save_dir, config=ckpt_config) optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay, @@ -120,7 +120,7 @@ def train_net(cross_valid_ind=1, eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics} eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval, eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True, - ckpt_directory=ckpt_save_dir+'./ckpt_{}/'.format(rank), besk_ckpt_name="best.ckpt", + ckpt_directory=ckpt_save_dir, besk_ckpt_name="best.ckpt", metrics_name=config.eval_metrics) callbacks.append(eval_cb) model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode) diff --git a/official/cv/unet/unet_medical_config.yaml b/official/cv/unet/unet_medical_config.yaml index 559e156770de4d7f431b27105c3e0812894645c0..0e831f5101dc58a64846f0303ea9d2b5ad71b629 100644 --- a/official/cv/unet/unet_medical_config.yaml +++ b/official/cv/unet/unet_medical_config.yaml @@ -39,6 +39,9 @@ show_eval: False color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]] #Eval options +eval_metrics: "dice_coeff" +eval_start_epoch: 0 +eval_interval: 1 keep_checkpoint_max: 10 eval_activate: "Softmax" eval_resize: False diff --git a/official/cv/unet/unet_medical_gpu_config.yaml b/official/cv/unet/unet_medical_gpu_config.yaml index 48b558949feaad1f7a91f8d816bc64190fc0c64c..3507211b9f268d703d7cb46c3a3abac1d0d01465 100644 --- a/official/cv/unet/unet_medical_gpu_config.yaml +++ b/official/cv/unet/unet_medical_gpu_config.yaml @@ -40,6 +40,9 @@ show_eval: False color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]] #Eval options +eval_metrics: "dice_coeff" +eval_start_epoch: 0 +eval_interval: 1 keep_checkpoint_max: 10 eval_activate: "Softmax" eval_resize: False diff --git a/official/cv/unet/unet_nested_cell_config.yaml b/official/cv/unet/unet_nested_cell_config.yaml index 4a4e3ada1a64a0876ee7e37ac8f921ca65b54e3f..a73dd0b46b10caf7ba1b2dbf660883aca25d64e2 100644 --- a/official/cv/unet/unet_nested_cell_config.yaml +++ b/official/cv/unet/unet_nested_cell_config.yaml @@ -15,7 +15,7 @@ enable_profiling: False # Training options model_name: "unet_nested" include_background: True -run_eval: False +run_eval: True run_distribute: False dataset: "Cell_nuclei" crop: None @@ -44,6 +44,9 @@ show_eval: False color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]] #Eval options +eval_metrics: "dice_coeff" +eval_start_epoch: 0 +eval_interval: 1 keep_checkpoint_max: 10 eval_activate: "Softmax" eval_resize: False diff --git a/official/cv/unet/unet_nested_config.yaml b/official/cv/unet/unet_nested_config.yaml index 4fd204bed58be43c4b3efadfb7169a8f15fd694d..beb5e01cc2fa4b5716b8dcf4024b8ac4f87de5f8 100644 --- a/official/cv/unet/unet_nested_config.yaml +++ b/official/cv/unet/unet_nested_config.yaml @@ -42,6 +42,9 @@ show_eval: False color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]] #Eval options +eval_metrics: "dice_coeff" +eval_start_epoch: 0 +eval_interval: 1 keep_checkpoint_max: 10 eval_activate: "Softmax" eval_resize: False diff --git a/official/cv/unet/unet_simple_config.yaml b/official/cv/unet/unet_simple_config.yaml index fb9460979e558c9adfdb2d0c52ae54a306e7f225..5c30da0deda68a0ab303b6b800f258a279dc9c65 100644 --- a/official/cv/unet/unet_simple_config.yaml +++ b/official/cv/unet/unet_simple_config.yaml @@ -39,6 +39,9 @@ show_eval: False color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]] #Eval options +eval_metrics: "dice_coeff" +eval_start_epoch: 0 +eval_interval: 1 keep_checkpoint_max: 10 eval_activate: "Softmax" eval_resize: False