Unverified commit 9931fa60, authored by i-robot, committed by Gitee

!2554 [Huawei University][University Contribution][MindSpore] - regression fixes for NCF issue tickets

Merge pull request !2554 from 迎接光辉岁月/r1.2
parents c1b67dfc 56581dc2
Showing with 282 additions and 3 deletions
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training entry file"""
import argparse
import ast
import numpy as np
from absl import logging
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore import context, export, Tensor
from mindspore.common import set_seed
from ncf import NCFModel, PredictWithSigmoid
import src.constants as rconst
from src.config import cfg
set_seed(1)
logging.set_verbosity(logging.INFO)
parser = argparse.ArgumentParser(description='ncf export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"], help="Dataset.")
parser.add_argument("--file_name", type=str, default="ncf", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument("--device_target", type=str, default="Ascend",
choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)")
parser.add_argument("--is_row_vector_input", type=ast.literal_eval, default=True,
help="Change model input into row vector for MindX SDK inference")
args = parser.parse_args()
def export_for_infer():
"""export method"""
if args.dataset == "ml-1m":
num_eval_users = 6040
num_eval_items = 3706
elif args.dataset == "ml-20m":
num_eval_users = 138493
num_eval_items = 26744
else:
raise ValueError("not supported dataset")
layers = cfg.layers
num_factors = cfg.num_factors
ncf_net = NCFModel(num_users=num_eval_users,
num_items=num_eval_items,
num_factors=num_factors,
model_layers=layers,
mf_regularization=0,
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
mf_dim=16,
is_row_vector_input=args.is_row_vector_input)
frozen_to_air_args = {'ckpt_file': args.ckpt_file,
'topk': rconst.TOP_K,
'num_eval_neg': rconst.NUM_EVAL_NEGATIVES,
'file_name': args.file_name,
'file_format': args.file_format}
frozen_to_air(ncf_net, frozen_to_air_args)
def frozen_to_air(net, args_net):
"""frozen net parameters in the format of air"""
param_dict = load_checkpoint(args_net.get("ckpt_file"))
load_param_into_net(net, param_dict)
network = PredictWithSigmoid(net, args_net.get("topk"), args_net.get("num_eval_neg"))
users = Tensor(np.zeros([1, cfg.eval_batch_size]).astype(np.int32))
items = Tensor(np.zeros([1, cfg.eval_batch_size]).astype(np.int32))
masks = Tensor(np.zeros([1, cfg.eval_batch_size]).astype(np.float32))
input_data = [users, items, masks]
export(network, *input_data, file_name=args_net.get("file_name"), file_format=args_net.get("file_format"))
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
if __name__ == '__main__':
export_for_infer()
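For reference, a minimal smoke test of the exported model, assuming the export above was run with --file_format MINDIR and the default --file_name ncf (the resulting ncf.mindir path is an assumption, not part of this commit):

import numpy as np
import mindspore as ms
from mindspore import nn, Tensor
from src.config import cfg

ms.context.set_context(mode=ms.context.GRAPH_MODE)
graph = ms.load("ncf.mindir")   # load the exported MindIR graph (assumed file name)
net = nn.GraphCell(graph)       # wrap the graph as a callable cell
# The inputs mirror the [1, eval_batch_size] row vectors used at export time.
users = Tensor(np.zeros([1, cfg.eval_batch_size]).astype(np.int32))
items = Tensor(np.zeros([1, cfg.eval_batch_size]).astype(np.int32))
masks = Tensor(np.zeros([1, cfg.eval_batch_size]).astype(np.float32))
print(net(users, items, masks))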
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
New binary files:
official/recommend/ncf/modelarts/res/config_train_work.png (391 KiB)
official/recommend/ncf/modelarts/res/create_algorithm.png (284 KiB)
official/recommend/ncf/modelarts/res/logs.png (859 KiB)
official/recommend/ncf/modelarts/res/obs_directory.png (135 KiB)
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_distribute_train.sh DEVICE_NUM DATASET_PATH RANK_TABLE_FILE"
echo "for example: sh scripts/run_distribute_train.sh 8 /dataset_path /rank_table_8p.json"
current_exec_path=$(pwd)
echo ${current_exec_path}
export RANK_SIZE=$1
data_path=$2
export RANK_TABLE_FILE=$3
# Launch one training process per device; ranks run from 0 to RANK_SIZE-1.
for ((i = 0; i < RANK_SIZE; i++)); do
    rm -rf ${current_exec_path}/device_$i
    mkdir ${current_exec_path}/device_$i
    cd ${current_exec_path}/device_$i || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    python3 -u ${current_exec_path}/train_for_infer.py \
        --data_path $data_path \
        --dataset 'ml-1m' \
        --train_epochs 50 \
        --output_path './output/' \
        --loss_file_name 'loss.log' \
        --checkpoint_path './checkpoint/' \
        --device_target="Ascend" \
        --device_id=$i \
        --is_distributed=1 \
        >log_$i.log 2>&1 &
done
#!/bin/bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "Please run the script as: "
echo "sh scripts/run_train.sh DATASET_PATH CKPT_FILE"
echo "for example: sh scripts/run_train.sh /dataset_path /ncf.ckpt"
data_path=$1
ckpt_file=$2
python3 ./train_for_infer.py \
    --data_path $data_path \
    --dataset 'ml-1m' \
    --train_epochs 20 \
    --batch_size 256 \
    --checkpoint_path $ckpt_file
# Copyright 2021 Huawei Technologies Co., Ltd
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Training entry file"""
import os
import argparse
import ast
from absl import logging
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore import context, Model
from mindspore.context import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from mindspore.common import set_seed
from dataset import create_dataset
from ncf import NCFModel, NetWithLossClass, TrainStepWrap
from src.config import cfg
set_seed(1)
logging.set_verbosity(logging.INFO)
parser = argparse.ArgumentParser(description='NCF')
parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data.
parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"]
parser.add_argument("--train_epochs", type=int, default=14) # The number of epochs used to train.
parser.add_argument("--batch_size", type=int, default=256) # Batch size for training and evaluation
parser.add_argument("--num_neg", type=int, default=4) # The Number of negative instances to pair with a positive instance.
parser.add_argument("--checkpoint_path", type=str, default="./checkpoint/")
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument('--device_id', type=int, default=0, help='device id of GPU or Ascend. (Default: None)')
parser.add_argument('--is_distributed', type=int, default=0, help='if multi device')
parser.add_argument('--rank', type=int, default=0, help='local rank of distributed')
parser.add_argument('--group_size', type=int, default=1, help='world size of distributed')
parser.add_argument("--is_row_vector_input", type=ast.literal_eval, default=True,
help="Change model input into row vector for MindX SDK inference")
args = parser.parse_args()
def train():
"""train method"""
if args.is_distributed:
if args.device_target == "Ascend":
init()
context.set_context(device_id='0')
elif args.device_target == "GPU":
init()
args.rank = get_rank()
args.group_size = get_group_size()
device_num = args.group_size
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
parameter_broadcast=True, gradients_mean=True)
else:
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
layers = cfg.layers
num_factors = cfg.num_factors
epochs = args.train_epochs
if not os.path.exists(args.checkpoint_path):
os.makedirs(args.checkpoint_path)
CKPT_SAVE_DIR = args.checkpoint_path
ds_train, num_train_users, num_train_items = create_dataset(test_train=True, data_dir=args.data_path,
dataset=args.dataset, train_epochs=1,
batch_size=args.batch_size, num_neg=args.num_neg,
row_vector=args.is_row_vector_input)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
ncf_net = NCFModel(num_users=num_train_users,
num_items=num_train_items,
num_factors=num_factors,
model_layers=layers,
mf_regularization=0,
mlp_reg_layers=[0.0, 0.0, 0.0, 0.0],
mf_dim=16,
is_row_vector_input=args.is_row_vector_input)
loss_net = NetWithLossClass(ncf_net)
train_net = TrainStepWrap(loss_net, ds_train.get_dataset_size() * (epochs + 1))
train_net.set_train()
model = Model(train_net)
callback = LossMonitor(per_print_times=ds_train.get_dataset_size())
ckpt_config = CheckpointConfig(save_checkpoint_steps=(4970845+args.batch_size-1)//(args.batch_size),
keep_checkpoint_max=100)
ckpoint_cb = ModelCheckpoint(prefix='ncf', directory=CKPT_SAVE_DIR, config=ckpt_config)
model.train(epochs,
ds_train,
callbacks=[TimeMonitor(ds_train.get_dataset_size()), callback, ckpoint_cb],
dataset_sink_mode=True)
if __name__ == '__main__':
train()
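As a quick arithmetic check of the save_checkpoint_steps formula above, with the default ml-1m setup (batch_size=256, 4970845 training instances per epoch):

batch_size = 256
num_train_instances = 4970845  # ml-1m positives plus 4 sampled negatives each
# Ceiling division: the smallest number of batches that covers all instances.
steps_per_epoch = (num_train_instances + batch_size - 1) // batch_size
assert steps_per_epoch == 19418
print(steps_per_epoch)  # a checkpoint is saved once per 19418-step epoch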