Skip to content
Snippets Groups Projects
Commit b2846f73 authored by i-robot's avatar i-robot Committed by Gitee
Browse files

!137 add gpu support for fat_deepffm

Merge pull request !137 from Shawny/deepffm
parents fc50c2eb 0ea7a874
No related branches found
No related tags found
No related merge requests found
......@@ -97,6 +97,31 @@ Fat - DeepFFM consists of three parts. The FFM component is a factorization mach
[hccl tools](https://gitee.com/mindspore/models/tree/master/utils/hccl_tools).
- running on GPU
```shell
# run training example
python train.py \
--dataset_path='data/mindrecord' \
--ckpt_path='./checkpoint/Fat-DeepFFM' \
--eval_file_name='./auc.log' \
--loss_file_name='./loss.log' \
--device_target='GPU' \
--do_eval=True > output.log 2>&1 &
# run distributed training example
bash scripts/run_distribute_train_gpu.sh 8 /dataset_path
# run evaluation example
python eval.py \
--dataset_path='dataset/mindrecord' \
--ckpt_path='./checkpoint/Fat-DeepFFM.ckpt' \
--device_target='GPU' \
--device_id=0 > eval_output.log 2>&1 &
OR
bash scripts/run_eval.sh 0 GPU /dataset_path /ckpt_path
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
......@@ -281,17 +306,17 @@ Inference result is saved in current path, you can find result like this in acc.
### Inference Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | DeepFM |
| Resource | Ascend 910; OS Euler2.8 |
| Uploaded Date | 06/20/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | Criteo |
| batch_size | 1000 |
| outputs | AUC |
| AUC | 1pc: 80.90%; |
| Model for inference | 87.65M (.ckpt file) |
| Parameters | Ascend | GPU
| ------------------- | --------------------------- | ---------------------------
| Model Version | DeepFM | DeepFM
| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G; OS Euler3.10
| Uploaded Date | 06/20/2021 (month/day/year) | 09/04/2021 (month/day/year)
| MindSpore Version | 1.2.0 | 1.3.0
| Dataset | Criteo | Criteo
| batch_size | 1000 | 1000
| outputs | AUC | AUC
| AUC | 1pc: 80.90%; | 8pc: 80.90%;
| Model for inference | 87.65M (.ckpt file) | 89.35M (.ckpt file)
# [Description of Random Situation](#contents)
......
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Print invocation help; the training itself is launched in the background
# further down, with combined output in log/output.log.
echo "Please run the script as: "
echo "bash scripts/run_distribute_train_gpu.sh DEVICE_NUM DATASET_PATH"
# Keep the example consistent with the documented "bash" invocation above.
echo "for example: bash scripts/run_distribute_train_gpu.sh 8 /dataset_path"
echo "After running the script, the network runs in the background, The log will be generated in log/output.log"

# Validate the argument count before touching the filesystem.
if [ $# != 2 ]; then
  echo "Usage: bash scripts/run_distribute_train_gpu.sh [DEVICE_NUM] [DATASET_PATH]"
  exit 1
fi
#######################################
# Resolve a path to an absolute path.
# Arguments: $1 - path, absolute or relative to the current directory
# Outputs:   the absolute path on stdout (realpath -m: path need not exist)
#######################################
get_real_path() {
  if [ "${1:0:1}" == "/" ]; then
    echo "$1"
  else
    # Quote the expansion so paths containing spaces resolve correctly.
    realpath -m "$PWD/$1"
  fi
}
# Resolve the dataset argument to an absolute path and verify it exists
# before launching anything expensive.
dataset_path=$(get_real_path "$2")
echo "$dataset_path"
if [ ! -d "$dataset_path" ]; then
  # Diagnostics belong on stderr.
  echo "error: dataset_path=$dataset_path is not a directory." >&2
  exit 1
fi
export RANK_SIZE=$1
# Export the RESOLVED absolute path, not the raw "$2": the script cd's into
# ./log below, so a relative dataset path would no longer point at the data.
export DATA_URL=$dataset_path

# Stage a clean working copy of the sources so the run is self-contained.
rm -rf log
mkdir ./log
cp *.py ./log
cp -r ./src ./log
cd ./log || exit
env > env.log

# Launch distributed training in the background: one process per device,
# per-rank logs under log_output/, merged stdout/stderr in output.log.
mpirun --allow-run-as-root -n "$RANK_SIZE" --output-filename log_output --merge-stderr-to-stdout \
  python -u train.py \
  --dataset_path="$DATA_URL" \
  --ckpt_path="Fat-DeepFFM" \
  --eval_file_name='auc.log' \
  --loss_file_name='loss.log' \
  --device_target='GPU' \
  --do_eval=True > output.log 2>&1 &
\ No newline at end of file
......@@ -47,22 +47,20 @@ set_seed(1)
if __name__ == '__main__':
model_config = ModelConfig()
if rank_size > 1:
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
device_id=device_id)
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True,
all_reduce_fusion_config=[9, 11])
device_id = 0
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=device_id)
if args.device_target == "Ascend":
device_id = int(os.getenv('DEVICE_ID', '0'))
if rank_size == 1 or args.device_target == "CPU":
rank_id = 0
elif rank_size > 1:
init()
rank_id = get_rank()
else:
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
device_id=device_id)
rank_size = None
rank_id = None
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=rank_size,
parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
if args.device_target == "Ascend":
context.set_auto_parallel_context(all_reduce_fusion_config=[9, 11])
print("load dataset...")
ds_train = get_mindrecord_dataset(args.dataset_path, train_mode=True, epochs=1, batch_size=model_config.batch_size,
rank_size=rank_size, rank_id=rank_id, line_per_sample=1000)
......@@ -72,12 +70,12 @@ if __name__ == '__main__':
time_callback = TimeMonitor(data_size=ds_train.get_dataset_size())
loss_callback = LossCallback(args.loss_file_name)
cb = [loss_callback, time_callback]
if rank_size == 1 or device_id == 0:
if rank_id == 0:
config_ck = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size() * model_config.epoch_size,
keep_checkpoint_max=model_config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=args.ckpt_path, config=config_ck)
cb += [ckpoint_cb]
if args.do_eval and device_id == 0:
if args.do_eval and rank_id == 0:
ds_test = get_mindrecord_dataset(args.dataset_path, train_mode=False)
eval_callback = AUCCallBack(model, ds_test, eval_file_path=args.eval_file_name)
cb.append(eval_callback)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment