Skip to content
Snippets Groups Projects
Commit abc34438 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!1283 modify yolov5 train scripts to fix fps performance

Merge pull request !1283 from anzhengqi/usability-modify-yolov5
parents 5e095c04 4bb24a52
No related branches found
No related tags found
No related merge requests found
......@@ -91,17 +91,17 @@ def run_train():
network = nn.TrainOneStepCell(network, opt, config.loss_scale // 2)
network.set_train()
data_loader = ds.create_dict_iterator()
data_loader = ds.create_tuple_iterator(do_copy=False)
first_step = True
t_end = time.time()
for epoch_idx in range(config.max_epoch):
for step_idx, data in enumerate(data_loader):
images = data["image"]
images = data[0]
input_shape = images.shape[2:4]
input_shape = ms.Tensor(tuple(input_shape[::-1]), ms.float32)
loss = network(images, data['bbox1'], data['bbox2'], data['bbox3'], data['gt_box1'], data['gt_box2'],
data['gt_box2'], input_shape)
loss = network(images, data[2], data[3], data[4], data[5], data[6],
data[7], input_shape)
loss_meter.update(loss.asnumpy())
# it is used for loss, performance output per config.log_interval steps.
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment