diff --git a/research/cv/CycleGAN/src/models/cycle_gan.py b/research/cv/CycleGAN/src/models/cycle_gan.py index a285768c56648be474fb9deabfef7bf1cc0a288d..11b9e3a10f69a3c8673d587e13025a6df7f30c72 100644 --- a/research/cv/CycleGAN/src/models/cycle_gan.py +++ b/research/cv/CycleGAN/src/models/cycle_gan.py @@ -189,9 +189,12 @@ class TrainOneStepG(nn.Cell): def construct(self, img_A, img_B): weights = self.weights - fake_A, fake_B, lg, lga, lgb, lca, lcb, lia, lib = self.G(img_A, img_B) - sens = ops.Fill()(ops.DType()(lg), ops.Shape()(lg), self.sens) - grads_g = self.grad(self.net, weights)(img_A, img_B, sens) + out = self.G(img_A, img_B) + lg, fake_A, fake_B, lga, lgb, lca, lcb, lia, lib = out + sens_tuple = (ops.ones_like(lg) * self.sens,) + for i in range(1, len(out)): + sens_tuple += (ops.zeros_like(out[i]),) + grads_g = self.grad(self.G, weights)(img_A, img_B, sens_tuple) if self.reducer_flag: # apply grad reducer on grads grads_g = self.grad_reducer(grads_g) diff --git a/research/cv/CycleGAN/src/models/losses.py b/research/cv/CycleGAN/src/models/losses.py index 07fc794a9b6d7ea9f5cec7150d3a21cb92386bba..462ae531a226916a792a33b119f1285f2aaf1de1 100644 --- a/research/cv/CycleGAN/src/models/losses.py +++ b/research/cv/CycleGAN/src/models/losses.py @@ -129,7 +129,7 @@ class GeneratorLoss(nn.Cell): loss_idt_A = 0 loss_idt_B = 0 loss_G = loss_G_A + loss_G_B + loss_C_A + loss_C_B + loss_idt_A + loss_idt_B - return (fake_A, fake_B, loss_G, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B) + return (loss_G, fake_A, fake_B, loss_G_A, loss_G_B, loss_C_A, loss_C_B, loss_idt_A, loss_idt_B) class DiscriminatorLoss(nn.Cell):