Skip to content
Snippets Groups Projects
Commit 9581d69c authored by zhaoting's avatar zhaoting
Browse files

set run_eval true for unet++

parent c1bb6424
No related branches found
No related tags found
No related merge requests found
......@@ -97,12 +97,12 @@ def train_net(cross_valid_ind=1,
run_distribute, config.crop, config.image_size)
train_data_size = train_dataset.get_dataset_size()
print("dataset length is:", train_data_size)
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path)
ckpt_save_dir = os.path.join(config.output_path, config.checkpoint_path, f'ckpt_{rank}')
save_ck_steps = train_data_size * epochs
ckpt_config = CheckpointConfig(save_checkpoint_steps=save_ck_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix='ckpt_{}_adam'.format(config.model_name),
directory=ckpt_save_dir+'./ckpt_{}/'.format(rank),
directory=ckpt_save_dir,
config=ckpt_config)
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr, weight_decay=config.weight_decay,
......@@ -120,7 +120,7 @@ def train_net(cross_valid_ind=1,
eval_param_dict = {"model": eval_model, "dataset": valid_dataset, "metrics_name": config.eval_metrics}
eval_cb = EvalCallBack(apply_eval, eval_param_dict, interval=config.eval_interval,
eval_start_epoch=config.eval_start_epoch, save_best_ckpt=True,
ckpt_directory=ckpt_save_dir+'./ckpt_{}/'.format(rank), besk_ckpt_name="best.ckpt",
ckpt_directory=ckpt_save_dir, besk_ckpt_name="best.ckpt",
metrics_name=config.eval_metrics)
callbacks.append(eval_cb)
model.train(int(epochs / repeat), train_dataset, callbacks=callbacks, dataset_sink_mode=dataset_sink_mode)
......
......@@ -39,6 +39,9 @@ show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
eval_metrics: "dice_coeff"
eval_start_epoch: 0
eval_interval: 1
keep_checkpoint_max: 10
eval_activate: "Softmax"
eval_resize: False
......
......@@ -40,6 +40,9 @@ show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
eval_metrics: "dice_coeff"
eval_start_epoch: 0
eval_interval: 1
keep_checkpoint_max: 10
eval_activate: "Softmax"
eval_resize: False
......
......@@ -15,7 +15,7 @@ enable_profiling: False
# Training options
model_name: "unet_nested"
include_background: True
run_eval: False
run_eval: True
run_distribute: False
dataset: "Cell_nuclei"
crop: None
......@@ -44,6 +44,9 @@ show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
eval_metrics: "dice_coeff"
eval_start_epoch: 0
eval_interval: 1
keep_checkpoint_max: 10
eval_activate: "Softmax"
eval_resize: False
......
......@@ -42,6 +42,9 @@ show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
eval_metrics: "dice_coeff"
eval_start_epoch: 0
eval_interval: 1
keep_checkpoint_max: 10
eval_activate: "Softmax"
eval_resize: False
......
......@@ -39,6 +39,9 @@ show_eval: False
color: [[0, 0, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255], [255, 255, 255]]
#Eval options
eval_metrics: "dice_coeff"
eval_start_epoch: 0
eval_interval: 1
keep_checkpoint_max: 10
eval_activate: "Softmax"
eval_resize: False
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment