Skip to content
Snippets Groups Projects
Unverified Commit ae442bbc authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!3197 ONNX:DDPG

Merge pull request !3197 from 王一凡/ddpg
parents fc420161 7b7d8b3f
Branches r1.5
No related tags found
No related merge requests found
......@@ -98,6 +98,7 @@ ddpg
├── run_910_train.sh # 训练shell脚本
├── run_910_verify.sh # ddpg Ascend 310推理shell脚本
├── run_infer_310.sh # ddpg Ascend 310 推理shell脚本
├── run_onnx_infer.sh # ddpg GPU ONNX 推理shell脚本
├── src
├── agent.py # ddpg 模型脚本
├── ac_net.py # ddpg 基础网络
......@@ -106,6 +107,7 @@ ddpg
├── train.py # ddpg 训练脚本
├── verify.py # ddpg Ascend 910推理脚本
├── verify_310.py # ddpg Ascend 310推理脚本
├── verify_onnx.py # ddpg ONNX推理脚本
├── export.py # ddpg 模型导出脚本
```
......@@ -227,6 +229,37 @@ STEP_TEST = 100 # 测试单轮步数
bash run_infer_310.sh ../test.mindir ../output 0
```
- GPU 运行ONNX推理脚本
在GPU环境运行时评估
评估所需ckpt获取地址:
[ddpg_critictarget](https://www.mindspore.cn/resources/hub/details?MindSpore/1.5/ddpg_critictarget_none)
[ddpg_criticnet](https://www.mindspore.cn/resources/hub/details?MindSpore/1.5/ddpg_criticnet_none)
[ddpg_actornet](https://www.mindspore.cn/resources/hub/details?MindSpore/1.5/ddpg_actornet_none)
[ddpg_actortarget](https://www.mindspore.cn/resources/hub/details?MindSpore/1.5/ddpg_actortarget_none)
在执行评估之前,需要通过export.py导出onnx文件(export.py默认导出MINDIR格式,导出ONNX时需显式指定--file_format ONNX)。
```shell
python export.py --file_format ONNX
```
在运行以下命令之前,请检查用于评估的检查点路径。
请将检查点路径设置为相对路径,例如“../actornet.ckpt”。
```shell
# 脚本需要且仅需要一个参数:DEVICE_ID
bash run_onnx_infer.sh [DEVICE_ID]
# example
cd scripts
bash run_onnx_infer.sh 0
```
上述python命令将在后台运行,您可以通过verify_onnx/verify_onnx.log文件查看结果。测试的准确性如下:
```bash
Final Average Reward: 6.260091564789905
```
# 随机情况说明
在推理过程和训练过程中,我们都使用到gym环境下的随机种子。
......
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
Gamma: 0.99
GAMMA: 0.99
EPISODES: 5000
EP_STEPS: 200
MEMORY_CAPACITY: 10000
......@@ -7,6 +7,6 @@ BATCH_SIZE: 32
EP_TEST: 5
STEP_TEST: 100
TAU: 0.001
LR_ACTOR: 1e-3
LR_CRITIC: 1e-4
LR_ACTOR: 0.001
LR_CRITIC: 0.0001
CRITIC_DECAY: 0.01
\ No newline at end of file
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -14,6 +14,7 @@
# ============================================================================
"""export"""
import argparse
import math
import gym
import mindspore as ms
......@@ -21,6 +22,12 @@ from mindspore import Tensor, export, context, nn, ops
from mindspore import load_checkpoint
from mindspore.common.initializer import Uniform
# Command-line options of the export script.
parser = argparse.ArgumentParser(description='ddpg')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--file_name", type=str, default="ddpg", help="output file name.")
# NOTE(review): the run_export code visible in this file only branches on "MINDIR" and "ONNX";
# confirm how "AIR" is expected to be handled before relying on that choice.
parser.add_argument('--file_format', type=str, choices=["MINDIR", "AIR", "ONNX"], default='MINDIR', help='file format')
# Parsed at import time; run_export reads args.file_format to pick the device target and output format.
args = parser.parse_args()
class ActorNet(nn.Cell):
"""
......@@ -52,8 +59,12 @@ class ActorNet(nn.Cell):
def run_export():
"""export"""
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
env = gym.make('Pendulum-v0')
if args.file_format == "MINDIR":
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
env = gym.make('Pendulum-v0')
elif args.file_format == "ONNX":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
env = gym.make('Pendulum-v1')
env = env.unwrapped
env.seed(1)
state_dim = env.observation_space.shape[0]
......@@ -62,10 +73,16 @@ def run_export():
state = Tensor(state, ms.float32)
expand_dims = ops.ExpandDims()
state = expand_dims(state, 0)
actor_net = ActorNet(state_dim, action_dim)
load_checkpoint("actor_net.ckpt", net=actor_net)
export(actor_net, state, file_name="test", file_format="MINDIR")
print("export MINDIR file at {}".format("./test.mindir"))
if args.file_format == "MINDIR":
export(actor_net, state, file_name="test", file_format="MINDIR")
print("export Actor file at {}".format("./test.mindir"))
elif args.file_format == "ONNX":
export(actor_net, state, file_name="actornet", file_format="ONNX")
print("export Actor file at {}".format("./Actor.ONNX"))
if __name__ == '__main__':
run_export()
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Runs ONNX verification of the exported DDPG actor in a fresh ../verify_onnx
# working directory. Must be started from the scripts directory, because all
# paths below are relative to its parent.
#
# Usage: bash <this script> DEVICE_ID
# Exit status: 0 when verify_onnx.py succeeds, 2 on bad usage or failed run.
if [ $# -ne 1 ];
then
    echo "=============================================================================================================="
    echo "Please run the script as: "
    echo "bash $0 [DEVICE_ID]"
    echo "For example: bash $0 0"
    echo "=============================================================================================================="
    exit 2
fi
set -e
DEVICE_ID=$1
export DEVICE_ID=$DEVICE_ID
EXEC_PATH=$(pwd)
echo "$EXEC_PATH"
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

# Rebuild the working directory from scratch so stale logs/models are not reused.
cd ../
rm -rf verify_onnx/
mkdir -p verify_onnx/src
cp ./default_paras.yaml ./verify_onnx
cp ./verify_onnx.py ./verify_onnx
cp ./src/*.py ./verify_onnx/src
cp ./actornet.onnx ./verify_onnx
cd ./verify_onnx
env > env0.log
echo "Verify begin."
# Run in the foreground inside the `if` condition: `$?` taken right after a
# `cmd &` launch is always 0, so the previous check could never report a
# failure. Commands tested by `if` are also exempt from `set -e`, which lets
# the else-branch run. The device id is forwarded so it is no longer ignored.
if python verify_onnx.py --device_id "$DEVICE_ID" > ./verify_onnx.log 2>&1; then
    echo "evaling success"
else
    echo "evaling failed"
    exit 2
fi
echo "finish"
cd ../
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""ONNX verify"""
import argparse
import gym
from src.config import config
from mindspore import context
import onnxruntime as ort
import numpy as np
# Command-line options of the ONNX verification script.
parser = argparse.ArgumentParser(description='ONNX ddpg Example')
# The help text previously advertised "default: Ascend" although the default is "GPU".
parser.add_argument('--device_target', type=str, default="GPU", choices=['Ascend', 'GPU'],
                    help='device where the code will be implemented (default: GPU)')
# The help text previously described a checkpoint path, which this integer option is not.
parser.add_argument('--device_id', type=int, default=0,
                    help='id of the device used for verification (default: 0)')
args = parser.parse_args()
context.set_context(device_id=args.device_id)

# Number of evaluation episodes and steps per episode, taken from the project config.
EP_TEST = config.EP_TEST
STEP_TEST = config.STEP_TEST
# Divisor applied to episode rewards before they are printed.
# NOTE(review): the origin of this value is not shown in this file -- confirm before changing it.
REWORD_SCOPE = 16.2736044
def create_session(checkpoint_path, target_device):
    """Open an ONNX Runtime inference session for a model file.

    Args:
        checkpoint_path: path of the ONNX model handed to ``ort.InferenceSession``.
        target_device: ``'GPU'`` selects the CUDA execution provider,
            ``'CPU'`` selects the CPU execution provider.

    Returns:
        A ``(session, input_name)`` tuple, where ``input_name`` is the name of
        the session's first input.

    Raises:
        ValueError: if ``target_device`` is neither ``'GPU'`` nor ``'CPU'``.
    """
    # Reject unknown devices up front, before any session is created.
    if target_device not in ('GPU', 'CPU'):
        raise ValueError(f'Unsupported target device {target_device!r}. Expected one of: "CPU", "GPU"')

    execution_provider = 'CUDAExecutionProvider' if target_device == 'GPU' else 'CPUExecutionProvider'
    onnx_session = ort.InferenceSession(checkpoint_path, providers=[execution_provider])
    first_input = onnx_session.get_inputs()[0]
    return onnx_session, first_input.name
def verify():
    """Run the exported ONNX actor on Pendulum-v1 and print scaled rewards.

    Loads ``./actornet.onnx`` through ``create_session`` on the GPU provider,
    plays ``EP_TEST`` episodes of ``STEP_TEST`` steps each, and prints every
    episode's reward sum divided by ``REWORD_SCOPE`` followed by the average
    over all episodes.

    The environment handling tolerates both gym API generations: releases
    before 0.26 seed via ``env.seed()``, return a bare observation from
    ``reset()`` and four values from ``step()``; later releases seed via
    ``reset(seed=...)``, return ``(observation, info)`` from ``reset()`` and
    five values from ``step()``.
    """
    env = gym.make('Pendulum-v1')
    env = env.unwrapped
    # Seed once: through env.seed() when it exists, otherwise on the first reset.
    has_legacy_seed = hasattr(env, 'seed')
    if has_legacy_seed:
        env.seed(1)
    session, input_name = create_session('./actornet.onnx', 'GPU')
    rewards = []
    for i in range(EP_TEST):
        reward_sum = 0
        if has_legacy_seed or i > 0:
            state = env.reset()
        else:
            state = env.reset(seed=1)
        # Newer gym returns (observation, info); keep only the observation.
        if isinstance(state, tuple):
            state = state[0]
        for _ in range(STEP_TEST):
            # The exported actor takes a float32 batch of one observation.
            batch = np.expand_dims(state, axis=0).astype(np.float32)
            action = session.run(None, {input_name: batch})[0]
            # Take the single batch row. Indexing rows with a flattened argmax
            # only worked because the output has exactly one element.
            action = action[0]
            # Index instead of unpacking so both 4- and 5-value step results work.
            step_result = env.step(action)
            state = step_result[0]
            reward_sum += step_result[1]
        print('Episode: ', i, ' Reward:', reward_sum / REWORD_SCOPE)
        rewards.append(reward_sum)
    print('Final Average Reward: ', sum(rewards) / (len(rewards) * REWORD_SCOPE))
if __name__ == '__main__':
verify()
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment