Commit 4882c954 authored by liyijia

cpu modified

parent 60612f69
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# URLs for ModelArts
data_url: ""
train_url: ""
checkpoint_url: ""
# Path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "CPU"
need_modelarts_dataset_unzip: False
modelarts_dataset_unzip_name: ""
# ==============================================================================
# options
n_categories: 19
n_sub_categories: 286
n_words: 74308
default_epochs: 1
epochs: -1
lr: 0.001
print_times: 1000
embedding_file: "MINDlarge_utils/embedding_all.npy"
word_dict_path: "MINDlarge_utils/word_dict_all.pkl"
category_dict_path: "MINDlarge_utils/vert_dict.pkl"
subcategory_dict_path: "MINDlarge_utils/subvert_dict.pkl"
uid2index_path: "MINDlarge_utils/uid2index.pkl"
train_dataset_path: "MINDlarge_train"
eval_dataset_path: "MINDlarge_dev"
# default option
seed: 1
platform: "CPU"
save_graphs: False
dataset: "large"
dataset_path: ""
n_browsed_news: 50
n_words_title: 16
n_words_abstract: 48
word_embedding_dim: 304
category_embedding_dim: 112
query_vector_dim: 208
n_filters: 400
window_size: 3
checkpoint_path: "" # change to naml_large_new.ckpt path when export
batch_size: 64 # change to 16 when export or infer
# train option
beta1: 0.9
beta2: 0.999
epsilon: 0.00000001 # 1e-8
neg_sample: 4 # when training, neg_sample=4; when evaluating, neg_sample=-1
mixed: True
sink_mode: True
weight_decay: True
save_checkpoint: True
save_checkpoint_path: "./checkpoint"
dropout_ratio: 0.2
# eval option
eval_neg_sample: -1
# export option
export_file_dir: "./"
file_format: "MINDIR"
export_neg_sample: -1
# infer option
preprocess_path: "./"
result_path: "./"
label_path: "./"
---
# Help description for each configuration
# default option
seed: "random seed"
platform: "run platform, only support Ascend"
save_graphs: "whether save graphs, default is False."
dataset: "MIND dataset, support large, small and demo."
dataset_path: "MIND dataset path."
n_browsed_news: "number of browsed news per user"
n_words_title: "number of words per title"
n_words_abstract: "number of words per abstract"
word_embedding_dim: "dimension of word embedding vector"
category_embedding_dim: "dimension of category embedding vector"
query_vector_dim: "dimension of the query vector in attention"
n_filters: "number of filters in CNN"
window_size: "size of filter in CNN"
checkpoint_path: "Pre trained checkpoint path, default is None."
batch_size: "size of each batch"
# train option
beta1: "ADAM beta1"
beta2: "ADAM beta2"
epsilon: "ADAM epsilon for numerical stability"
neg_sample: "number of negative samples in negative sampling"
mixed: "whether use mixed precision, default is True."
sink_mode: "whether use dataset sink, default is True."
weight_decay: "whether use weight decay, default is True."
save_checkpoint: "whether save checkpoint, default is True."
save_checkpoint_path: "Save checkpoint path, default is checkpoint."
dropout_ratio: "ratio of dropout"
# export option
file_format: "choices in ['AIR', 'ONNX', 'MINDIR']"
...@@ -74,6 +74,7 @@ You can download the dataset and arrange the directory structure as follows:
├── MINDdemo_config.yaml # Configurations for demo
├── MINDlarge_config.yaml # Configurations for large
├── MINDsmall_config.yaml # Configurations for small
├── MINDlarge_config_cpu.yaml # Configurations for large on CPU
├── ascend310_infer # application for 310 inference
├── train.py # training script
├── eval.py # evaluation script
...@@ -112,6 +113,36 @@ You can start training using python or shell scripts. The usage of shell scripts
<https://gitee.com/mindspore/models/tree/master/utils/hccl_tools>.
- Running on CPU
```shell
# train using python
python train.py --config_path=[CONFIG_PATH] \
                --platform=[PLATFORM] \
                --dataset=[DATASET] \
                --dataset_path=[DATASET_PATH] \
                --save_checkpoint_path=[SAVE_CHECKPOINT_PATH] \
                --weight_decay=False \
                --sink_mode=False

# example
python train.py --config_path=MINDlarge_config_cpu.yaml --platform=CPU --dataset=large --dataset_path=MINDlarge --save_checkpoint_path=./script/checkpoint --weight_decay=False --sink_mode=False

# evaluation using python
python eval.py --config_path=[CONFIG_PATH] \
               --platform=[PLATFORM] \
               --dataset=[DATASET] \
               --dataset_path=[DATASET_PATH] \
               --checkpoint_path=[CHECKPOINT_PATH]

# example
python eval.py --config_path=MINDlarge_config_cpu.yaml --platform=CPU --dataset=large --dataset_path=MINDlarge --checkpoint_path=./script/checkpoint/naml_last.ckpt
```
- `PLATFORM` should be CPU.
- `DEVICE_ID` is the id of the device to run the network on (not required on CPU).
- `DATASET` is the MIND dataset variant; this CPU flow supports large.
- `DATASET_PATH` is the dataset path, structured as described in [Dataset](#dataset); see the layout sketch after this list.
- `SAVE_CHECKPOINT_PATH` is the directory where training checkpoints are saved.
- `CHECKPOINT_PATH` is a pre-trained checkpoint path.
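For reference, the defaults in the configuration above (`train_dataset_path`, `eval_dataset_path`, and the `MINDlarge_utils` file paths) imply a `DATASET_PATH` layout like the following; the [Dataset](#dataset) section remains authoritative:

```text
MINDlarge/
├── MINDlarge_train/             # train_dataset_path
├── MINDlarge_dev/               # eval_dataset_path
└── MINDlarge_utils/
    ├── embedding_all.npy        # embedding_file
    ├── word_dict_all.pkl        # word_dict_path
    ├── vert_dict.pkl            # category_dict_path
    ├── subvert_dict.pkl         # subcategory_dict_path
    └── uid2index.pkl            # uid2index_path
```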
- ModelArts (If you want to run in ModelArts, please check the official documentation of [ModelArts](https://support.huaweicloud.com/modelarts/), and you can start training as follows)
- Train large dataset 1p/8p on ModelArts
...@@ -260,6 +291,19 @@ bash run_infer_310.sh [NEWS_MODEL] [USER_MODEL] [DEVICE_ID]
| outputs | probability |
| Accuracy | AUC: 0.6669 |
### CPU Inference Performance
| Parameters | CPU |
| ----------------- | ------------------------- |
| Model Version | NAML |
| Resource | CPU |
| Uploaded Date | 8/9/2022 (month/day/year) |
| MindSpore Version | 1.8 |
| Dataset | MINDlarge |
| batch_size | 64 |
| outputs | probability |
| Accuracy | AUC: 0.6727 |
# [Description of Random Situation](#contents)
<!-- In dataset.py, we set the seed inside the "create_dataset" function. We also use random seed in train.py. -->
...
# Copyright 2021-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -141,7 +141,10 @@ def run_train():
    for _, cell in net_with_loss.cells_and_names():
        if isinstance(cell, (nn.Embedding, nn.Softmax, nn.SoftmaxCrossEntropyWithLogits)):
            cell.to_float(mstype.float32)
    if config.platform == 'CPU':
        # CPU kernels run in float32, so no loss-scale manager is needed
        model = Model(net_with_loss, optimizer=opt, loss_scale_manager=None)
    else:
        model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager)
else:
    model = Model(net_with_loss, optimizer=opt)
...
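For context, the branch added above can be read as one self-contained decision. Below is a minimal sketch assuming MindSpore 1.8's `Model` and `DynamicLossScaleManager` APIs, with `net_with_loss` and `opt` standing in for the objects `run_train()` builds earlier; it is an illustration, not the repository's code:

```python
# Minimal sketch of the platform-dependent Model construction (assumed
# MindSpore 1.8 APIs; net_with_loss and opt are placeholders).
from mindspore import Model, DynamicLossScaleManager

def build_model(net_with_loss, opt, platform, mixed):
    if not mixed:
        # full-precision training needs no loss scaling at all
        return Model(net_with_loss, optimizer=opt)
    if platform == 'CPU':
        # CPU training stays in float32, so loss scaling would be a no-op
        return Model(net_with_loss, optimizer=opt, loss_scale_manager=None)
    # with float16 mixed precision on Ascend, dynamic loss scaling guards
    # against gradient underflow
    return Model(net_with_loss, optimizer=opt,
                 loss_scale_manager=DynamicLossScaleManager())
```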