Skip to content
Snippets Groups Projects
Commit dd1cbb17 authored by anzhengqi's avatar anzhengqi
Browse files

modify network tinybert scritps

parent a82682fd
No related branches found
No related tags found
No related merge requests found
......@@ -108,6 +108,8 @@ def run_predistill():
dataset_size = dataset.get_dataset_size()
print('td1 dataset size: ', dataset_size)
print('td1 dataset repeatcount: ', dataset.get_repeat_count())
if args_opt.data_sink_steps == -1:
args_opt.data_sink_steps = dataset_size
if args_opt.enable_data_sink == 'true':
repeat_count = args_opt.td_phase1_epoch_size * dataset_size // args_opt.data_sink_steps
time_monitor_steps = args_opt.data_sink_steps
......@@ -173,6 +175,8 @@ def run_task_distill(ckpt_file):
dataset_size = train_dataset.get_dataset_size()
print('td2 train dataset size: ', dataset_size)
print('td2 train dataset repeatcount: ', train_dataset.get_repeat_count())
if args_opt.data_sink_steps == -1:
args_opt.data_sink_steps = dataset_size
if args_opt.enable_data_sink == 'true':
repeat_count = args_opt.td_phase2_epoch_size * train_dataset.get_dataset_size() // args_opt.data_sink_steps
time_monitor_steps = args_opt.data_sink_steps
......
......@@ -39,7 +39,7 @@ python ${PROJECT_DIR}/../run_task_distill.py \
--td_phase2_epoch_size=3 \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=100 \
--data_sink_steps=-1 \
--save_ckpt_step=100 \
--max_ckpt_num=1 \
--load_teacher_ckpt_path="" \
......
......@@ -87,10 +87,8 @@ class EvalCallBack(Callback):
self.global_acc = 0.0
self.dataset = dataset
def step_end(self, run_context):
def epoch_end(self, run_context):
"""step end and do evaluation"""
cb_params = run_context.original_args()
if cb_params.cur_step_num % 100 == 0:
callback = Accuracy()
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
for data in self.dataset.create_dict_iterator(num_epochs=1):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment