Skip to content
Snippets Groups Projects
Commit 5c254943 authored by anzhengqi's avatar anzhengqi
Browse files

modify some network scripts

parent 04c88140
No related branches found
No related tags found
No related merge requests found
......@@ -233,10 +233,8 @@ class TrainingWrapper(nn.TrainOneStepWithLossScaleCell):
if self.reducer_flag:
grads = self.grad_reducer(grads)
cond = self.get_overflow_status(status, grads)
overflow = self.process_loss_scale(cond)
ret = (loss, cond, sens)
if not overflow:
self.optimizer(grads)
self.optimizer(grads)
return ret
......
......@@ -45,7 +45,7 @@ def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs1, warmup_epochs2,
warmup_steps4 = warmup_steps3 + steps_per_epoch * warmup_epochs4
warmup_steps5 = warmup_steps4 + steps_per_epoch * warmup_epochs5
step_radio = [1e-4, 1e-3, 1e-2, 0.1]
if config.finetune:
if hasattr(config, finetune) and config.finetune:
step_radio = [1e-4, 1e-2, 0.1, 1]
for i in range(total_steps):
if i < warmup_steps1:
......
......@@ -127,7 +127,8 @@ def main():
config.lr_end_rate = ast.literal_eval(config.lr_end_rate)
device_id = get_device_id()
if config.device_target == "Ascend":
context.set_context(mempool_block_size="31GB")
if context.get_context("mode") == context.PYNATIVE_MODE:
context.set_context(mempool_block_size="31GB")
elif config.device_target == "GPU":
set_graph_kernel_context(config.device_target)
elif config.device_target == "CPU":
......@@ -138,7 +139,7 @@ def main():
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.distribute:
init()
device_num = config.device_num
device_num = get_device_num()
rank = get_rank()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
......@@ -163,7 +164,7 @@ def main():
retinanet = retinanet50(backbone, config)
net = retinanetWithLossCell(retinanet, config)
init_net_param(net)
if config.finetune:
if hasattr(config, "finetune") and config.finetune:
init_net_param(net, initialize_mode='XavierUniform')
else:
init_net_param(net)
......
......@@ -46,7 +46,7 @@ from src.save_callback import SaveCallback
if config.isModelArts:
import moxing as mox
if config.net == 'resnet200' or config.net == 'resnet101':
if config.net == 'resnet200' or config.net == 'resnet101' or config.net == 'resnet50':
if config.device_target == "GPU":
config.cast_fp16 = False
......
......@@ -80,6 +80,7 @@ do
--training_shape=640 \
--weight_decay=0.016 \
--loss_scale=1024 \
--num_parallel_workers=32 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
done
......@@ -67,5 +67,6 @@ python train.py \
--training_shape=640 \
--per_batch_size=32 \
--weight_decay=0.016 \
--num_parallel_workers=32 \
--lr_scheduler=cosine_annealing > log.txt 2>&1 &
cd ..
......@@ -316,9 +316,9 @@ def create_yolo_dataset(
CV.Normalize(mean, std),
hwc_to_chw
],
num_parallel_workers=num_parallel_workers
num_parallel_workers=8
)
ds = ds.batch(batch_size, num_parallel_workers=num_parallel_workers, drop_remainder=True)
ds = ds.batch(batch_size, num_parallel_workers=8, drop_remainder=True)
else:
ds = de.GeneratorDataset(
yolo_dataset,
......@@ -331,9 +331,9 @@ def create_yolo_dataset(
input_columns=["image", "img_id"],
output_columns=["image", "image_shape", "img_id"],
column_order=["image", "image_shape", "img_id"],
num_parallel_workers=num_parallel_workers
num_parallel_workers=8
)
ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=hwc_to_chw, input_columns=["image"], num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(max_epoch)
return ds, len(yolo_dataset)
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment