diff --git a/main.py b/main.py index f8ed9fcc910410e37e010058b40fbf7ac4236f86..a1119ef49be5204c71bac17f4799cd31ead76a07 100644 --- a/main.py +++ b/main.py @@ -37,6 +37,10 @@ net_opt = nn.Adam(net.trainable_params(), learning_rate=lr_iter) config_ck = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=10) ckpoint = ModelCheckpoint(prefix="checkpoint_lenet", config=config_ck) +def cal_hist(a, b, n): + k = (a >= 0) & (a < n) + return np.bincount(n * a[k].astype(np.int32) + b[k], minlength=n ** 2).reshape(n, n) + def resize_long(img, long_size=513): h, w, _ = img.shape if h > w: