diff --git a/research/recommend/mmoe/README_CN.md b/research/recommend/mmoe/README_CN.md
index 78073a84296bbf8d526bab77b70cb447ad86d0ec..3b4754c3c1c2321d5ce7f67398831adb789cacc8 100644
--- a/research/recommend/mmoe/README_CN.md
+++ b/research/recommend/mmoe/README_CN.md
@@ -13,11 +13,11 @@
- [脚本参数](#脚本参数)
- [训练过程](#训练过程)
- [用法](#用法)
- - [Ascend处理器或GPU环境运行](#Ascend处理器或GPU环境运行)
+ - [Ascend处理器或GPU环境或CPU环境运行](#Ascend处理器或GPU环境或CPU环境运行)
- [结果](#结果)
- [评估过程](#评估过程)
- [评估用法](#评估用法)
- - [Ascend处理器或GPU环境运行评估](#Ascend处理器或GPU环境运行评估)
+ - [Ascend处理器或GPU环境或CPU环境运行评估](#Ascend处理器或GPU环境或CPU环境运行评估)
- [结果](#结果)
- [Ascend310推理过程](#推理过程)
- [导出MindIR](#导出MindIR)
@@ -156,8 +156,10 @@ Usage: bash run_standalone_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONF
└── mmoe_utils.py # 每一层架构
├── eval.py # 910评估网络
├── default_config.yaml # 默认的参数配置
+ ├── default_config_cpu.yaml # 针对CPU环境默认的参数配置
├── default_config_gpu.yaml # 针对GPU环境默认的参数配置
├── export.py # 910导出网络
+ ├── fine_tune.py # CPU训练网络
├── postprocess.py # 310推理精度计算
├── preprocess.py # 310推理前数据处理
└── train.py # 910训练网络
@@ -182,6 +184,21 @@ Usage: bash run_standalone_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONF
"warmup_epochs":5, # 热身周期
```
+- CPU环境下参数设置
+
+```Python
+"num_features":499, # 每一条数据的特征数
+"num_experts":8, # 专家数
+"units":4, # 每一层的unit数
+"batch_size":32, # 输入张量的批次大小
+"epoch_size":10, # 训练周期大小
+"learning_rate":0.0001, # 初始学习率
+"save_checkpoint":True, # 是否保存检查点
+"save_checkpoint_epochs":1, # 两个检查点之间的周期间隔;默认情况下,最后一个检查点将在最后一个周期完成后保存
+"keep_checkpoint_max":10, # 只保存最后一个keep_checkpoint_max检查点
+"warmup_epochs":5, # 热身周期
+```
+
# 训练过程
## 用法
@@ -231,6 +248,26 @@ Usage: bash run_standalone_eval_gpu.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [CONF
训练结果保存在示例路径中,文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果,如下所示。
+## CPU环境运行
+
+### 数据处理
+
+[根据提供的数据集链接加载数据集](http://github.com/drawbridge/keras-mmoe)在train.py文件同级目录下新建data文件夹,执行src中文件
+
+```shell
+python data.py --local_data_path ../data
+```
+
+即可得到所需的测试集与验证集。
+
+### 用法
+
+您可以通过python脚本开始训练
+
+```shell
+python train.py --config_path ./default_config_cpu.yaml
+```
+
## 结果
- 使用census-income数据集训练MMoE
@@ -270,6 +307,27 @@ infer data finished, start eval...
result : income_auc=0.9876599788311389, marital_auc=0.9663552616198483, use time 1s
The best income_auc is 0.9876599788311389, the best marital_auc is 0.9663552616198483, the best income_marital_auc_avg is 0.9770076202254936
...
+
+# 单卡CPU训练结果
+epoch: 1 step: 6235, loss is 0.8481878638267517
+Train epoch time: 27365.547 ms, per step time: 4.389 ms
+start infer...
+infer data finished, start eval...
+result : income_auc=0.9528846425952942, marital_auc=0.7993896372126021, use time 8s
+The best income_auc is 0.9528846425952942, the best marital_auc is 0.7993896372126021, the best income_marital_auc_avg is 0.8761371399039481
+epoch: 2 step: 6235, loss is 0.5404471158981323
+Train epoch time: 17965.760 ms, per step time: 2.881 ms
+start infer...
+infer data finished, start eval...
+result : income_auc=0.9833082917947681, marital_auc=0.9176945078776066, use time 5s
+The best income_auc is 0.9833082917947681, the best marital_auc is 0.9176945078776066, the best income_marital_auc_avg is 0.9505013998361873
+epoch: 3 step: 6235, loss is 0.26600515842437744
+Train epoch time: 20357.339 ms, per step time: 3.265 ms
+start infer...
+infer data finished, start eval...
+result : income_auc=0.9843190639741299, marital_auc=0.9634857856721967, use time 4s
+The best income_auc is 0.9843190639741299, the best marital_auc is 0.9634857856721967, the best income_marital_auc_avg is 0.9739024248231634
+...
```
# 评估过程
@@ -292,6 +350,16 @@ Usage: bash run_standalone_eval_ascend.sh [DATA_PATH] [CKPT_PATH] [DEVICE_ID] [C
bash run_standalone_eval_ascend.sh /home/mmoe/data/ /home/mmoe/MMoE_train-50_6236.ckpt 1 /home/mmoe/default_config.yaml
```
+### CPU环境运行评估
+
+您可以通过python脚本开始进行评估
+
+```shell
+python eval.py --data_path ././data/mindrecord/ --ckpt_file ./ckpt/best_marital_auc.ckpt
+```
+
+其中././data/mindrecord/验证集路径,./ckpt/ckpt/best_marital_auc.ckpt为选择的最好ckpt文件。
+
## 结果
评估结果保存在示例路径中,您可在此路径下的日志找到如下结果:
@@ -302,6 +370,12 @@ bash run_standalone_eval_ascend.sh /home/mmoe/data/ /home/mmoe/MMoE_train-50_6
result: {'income_auc': 0.9969135802469136, 'marital_auc': 1.0}
```
+- cpu环境下使用census-income数据集评估MMoE
+
+```text
+result : income_auc=0.9872372448355503, marital_auc=0.9820659214506045
+```
+
# Ascend310推理过程
## 导出MindIR
@@ -359,6 +433,24 @@ bash run_infer_310.sh [MINDIR_PATH] [DATA_PATH] [NEED_PREPROCESS] [DEVICE_ID]
| 微调检查点 | 2.66MB(.ckpt文件) |893.8KB(.ckpt文件)|
| 脚本 | [链接](https://gitee.com/mindspore/models/tree/master/research/recommend/mmoe) |[链接](https://gitee.com/mindspore/models/tree/master/research/recommend/mmoe)|
+| 参数 | i5-10400 CPU |
+|---|---|
+| 模型版本 | MMoE |
+| 资源 | i5-10400 CPU 2.90GHz |
+| 上传日期 |2022-9-30 ; |
+| MindSpore版本 | 1.8.0 |
+| 数据集 | census-income |
+| 训练参数 | epoch=10, batch_size = 32 |
+| 优化器 | Adam |
+| 损失函数 | BCELoss |
+| 输出 | 概率 |
+| 损失 | 0.209593266248703 |
+|速度|3.437毫秒/步 |
+|总时长 | 17分钟 |
+|参数 | 23.55KB |
+|精度指标 | best income_auc:0.9872 best marital_auc:0.9820 |
+| 微调检查点 | 4.92 MB (.ckpt文件) |
+
# 随机情况说明
train.py中使用随机种子
diff --git a/research/recommend/mmoe/default_config_cpu.yaml b/research/recommend/mmoe/default_config_cpu.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..81fc981284a10827e2ddcfa7df2aab301ab3b260
--- /dev/null
+++ b/research/recommend/mmoe/default_config_cpu.yaml
@@ -0,0 +1,63 @@
+# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
+enable_modelarts: False
+data_url: ""
+train_url: ""
+checkpoint_url: ""
+data_path: "././data/mindrecord/"
+output_path: "././train"
+load_path: "/cache/checkpoint_path"
+checkpoint_path: './ckpt/'
+checkpoint_file: './ckpt/MMoE_train-8_6235.ckpt'
+device_target: CPU
+enable_profiling: False
+
+ckpt_path: "./ckpt/"
+ckpt_file: "./ckpt/MMoE_train-9_6235.ckpt"
+# ==============================================================================
+# Training options
+epoch_size: 20
+keep_checkpoint_max: 5
+learning_rate: 0.0001
+batch_size: 32
+num_features: 499
+num_experts: 8
+units: 4
+MINDIR_name: 'MMoE.MINDIR'
+ckpt_file_path: './MMoE_train-50_6236.ckpt'
+
+dataset_name: 'census_income'
+dataset_sink_mode: False
+run_distribute: False
+device_id: 0
+save_checkpoint: True
+save_checkpoint_epochs: 1
+lr: 0.0005
+local_data_path: '../data'
+
+# Model Description
+model_name: MMoE
+file_name: 'MMoE'
+file_format: 'MINDIR'
+
+# 'preprocess.'
+result_path: './preprocess_Result'
+
+# 'postprocess.'
+label1_path: './scripts/preprocess_Result/01_label1'
+label2_path: './scripts/preprocess_Result/02_label2'
+result_bin_path: './scripts/result_Files'
+income_path: './result_Files/income_output'
+marital_path: './result_Files/marital_output'
+---
+# Config description for each option
+enable_modelarts: 'Whether training on modelarts, default: False'
+data_url: 'Dataset url for obs'
+train_url: 'Training output url for obs'
+data_path: 'Dataset path for local'
+output_path: 'Training output path for local'
+
+device_target: 'Target device type'
+enable_profiling: 'Whether enable profiling while training, default: False'
+
+---
+device_target: ['Ascend', 'GPU', 'CPU']
diff --git a/research/recommend/mmoe/eval.py b/research/recommend/mmoe/eval.py
index 1700ca58d5178899cef63bebb9e356168f6c18bc..a22aff064acde71da72a8b2652a8a7ce4f3ca2a7 100644
--- a/research/recommend/mmoe/eval.py
+++ b/research/recommend/mmoe/eval.py
@@ -39,6 +39,7 @@ def modelarts_process():
@moxing_wrapper(pre_process=modelarts_process)
def eval_mmoe():
"""MMoE eval"""
+ modelarts_process()
device_num = get_device_num()
if device_num > 1:
context.set_context(mode=context.GRAPH_MODE,
diff --git a/research/recommend/mmoe/train.py b/research/recommend/mmoe/train.py
index 2f2235855c737914e38dc701cb6e093e3b4533dc..f84f81b069d790e44993f3d42a414e22eeb9338e 100644
--- a/research/recommend/mmoe/train.py
+++ b/research/recommend/mmoe/train.py
@@ -36,7 +36,7 @@ from src.model_utils.device_adapter import get_device_id, get_device_num, get_ra
from src.get_lr import get_lr
from src.callback import EvalCallBack
-set_seed(1)
+set_seed(5)
def modelarts_pre_process():
@@ -57,6 +57,19 @@ def run_train():
context.set_context(save_graphs=False)
device_num = get_device_num()
+ if device_target == 'CPU':
+ config.epoch = 10
+ config.lr = 0.0001
+
+ if config.run_distribute:
+ context.reset_auto_parallel_context()
+ context.set_auto_parallel_context(device_num=device_num,
+ parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
+ if device_target == "Ascend":
+ context.set_context(device_id=get_device_id())
+ init()
+ elif device_target == "GPU":
+ init()
if config.run_distribute:
context.reset_auto_parallel_context()
@@ -105,14 +118,17 @@ def run_train():
eps=1e-7,
weight_decay=0.0,
loss_scale=1.0)
- scale_update_cell = DynamicLossScaleUpdateCell(
- loss_scale_value=2 ** 12 if config.device_target == 'Ascend' else 1.0,
- scale_factor=2,
- scale_window=1000)
- train_net = TrainStepWrap(
- loss_net, opt, scale_update_cell, config.device_target)
- train_net.set_train()
- model = Model(train_net)
+ if device_target == 'CPU':
+ model = Model(loss_net, optimizer=opt)
+ else:
+ scale_update_cell = DynamicLossScaleUpdateCell(
+ loss_scale_value=2 ** 12 if config.device_target == 'Ascend' else 1.0,
+ scale_factor=2,
+ scale_window=1000)
+ train_net = TrainStepWrap(
+ loss_net, opt, scale_update_cell, config.device_target)
+ train_net.set_train()
+ model = Model(train_net)
time_cb = TimeMonitor()
loss_cb = LossMonitor()