diff --git a/official/cv/octsqueeze/eval.py b/official/cv/octsqueeze/eval.py index ac9a6a4c2007d21ee12699c6249ab64c23af2b51..3d5551999dfd67e73751ea47a6a7f9f4dbfb4675 100644 --- a/official/cv/octsqueeze/eval.py +++ b/official/cv/octsqueeze/eval.py @@ -41,6 +41,7 @@ def def_arguments(): '--model', '-m', type=str, default='/home/OctSqueeze/checkpoint/octsqueeze.ckpt', help='route of model') parser.add_argument('--device_target', type=str, default="Ascend", choices=['Ascend', 'GPU', 'CPU'], help='device where the code will be implemented') + parser.add_argument('--device_id', type=int, default=0) return parser.parse_args() @@ -51,10 +52,6 @@ def compression_decompression_simulation(dataset_path, precision_oct): if filename.endswith('.bin'): frames.append('{}'.format(filename)) - # Load network - # Configure operation information - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=0) - ## Create networks net = network.OctSqueezeNet() param_dict = load_checkpoint(args.model) @@ -148,6 +145,7 @@ def compression_decompression_simulation(dataset_path, precision_oct): if __name__ == '__main__': # Evaluate test data at four bitrate whose max point-to-point error should less then [0.01 0.02, 0.04, 0.08] args = def_arguments() + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) precision_list = [0.01, 0.02, 0.04, 0.08] bpip_CD_all = np.empty([len(precision_list), 2]) if not os.path.exists(args.compression): os.makedirs(args.compression)