diff --git a/research/recommend/IntTower/README.md b/research/recommend/IntTower/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c7638f91b482609e52c854147151cc7cb09d9634
--- /dev/null
+++ b/research/recommend/IntTower/README.md
@@ -0,0 +1,194 @@
+
+# Contents
+
+- [Contents](#contents)
+- [IntTower Description](#IntTower-description)
+- [Dataset](#dataset)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+        - [Training](#training)
+    - [Evaluation Process](#evaluation-process)
+        - [Evaluation](#evaluation)
+    - [Inference Process](#inference-process)
+        - [Export MindIR](#export-mindir)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+        - [Inference Performance](#inference-performance)
+- [Description of Random Situation](#description-of-random-situation)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [IntTower Description](#contents)
+
+The proposed model, IntTower (short for Interaction enhanced Two-Tower), consists of Light-SE, FE-Block and CIR modules.
+Specifically, lightweight Light-SE module is used to identify the importance of different features and obtain refined feature representations in each tower. FE-Block module performs fine-grained and early feature interactions to capture the interactive signals between user and item towers explicitly and CIR module leverages a contrastive interaction regularization to further enhance the interactions implicitly.
+
+IntTower: the Next Generation of Two-Tower Model for
+Pre-Ranking System
+
+CIKM2022
+
+# [Dataset](#contents)
+
+- [Movie-Lens-1M](https://grouplens.org/datasets/movielens/1m/)
+
+# [Environment Requirements](#contents)
+
+- Hardware锛圕PU锛�
+    - Prepare hardware environment with CPU  processor.
+- Framework
+    - [MindSpore-1.8.1](https://www.mindspore.cn/install/en)
+- Requirements
+  - pandas
+  - numpy
+  - random
+  - mindspre==1.8.1
+  - tqdm
+  - sklearn
+- For more information, please check the resources below锛�
+  - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
+  - [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html)
+
+# [Quick Start](#contents)
+
+After installing MindSpore via the official website, you can start training and evaluation as follows:
+
+- running on CPU
+
+  ```python
+  # run training and evaluation example
+  python main.py
+  ```
+
+# [Script Description](#contents)
+
+## [Script and Sample Code](#contents)
+
+```bash
+.
+鈹斺攢IntTower
+  鈹溾攢README.md             # descriptions of warpctc
+  鈹溾攢eval.py               # model evaluation processing
+  鈹溾攢export.py             # export model to MindIR format
+  鈹溾攢get_dataset.py        # data process  
+  鈹溾攢model.py              # IntTower structure
+  鈹溾攢model_config.py       # model training parameters
+  鈹溾攢module.py             # modules in IntTower
+  鈹溾攢train.py              # train IntTower
+  鈹溾攢util.py               # some process function
+  鈹斺攢requirements.txt      # model requirements
+```
+
+## [Script Parameters](#contents)
+
+Parameters for both training and evaluation can be set in `model_config.py`
+
+- Parameters for Movielens-1M Dataset
+
+```python
+mlp_layers = [300, 300, 128]   # mlp units in every layer
+feblock_size = 256 # number of units in Fe-block
+head_num = 4 # number of pieces in Fe-block
+user_embedding_dim = 129 # size of user embedding
+item_embedding_dim = 33 # size of item embedding
+sparse_embedding_dim = 32 # size of single sparse feature embedding dim
+use_multi_layer = True # use every user layer
+user_sparse_field = 4 # number of user sparse feature
+keep_rate = 0.9 # dropout keep_rate
+epoch = 10 # training epoch
+batch_size = 2048 # training batch size
+seed = 3047 # random seed
+lr = 0.0005 # learn rate
+ ```
+
+## [Training Process](#contents)
+
+### Training
+
+- running on Ascend
+
+  ```python
+  python train.py > ms_log/output.log 2>&1 &
+  ```
+
+- The python command above will run in the background, you can view the results through the file `ms_log/output.log`.
+
+  ```txt
+   13%|鈻堚枎        | 31/230 [00:23<02:26,  1.36it/s, train_auc=0.813, train_loss=0.60894054]
+   ...
+  ```
+
+- The model checkpoint will be saved in the current directory.
+
+## [Evaluation Process](#contents)
+
+### Evaluation
+
+- evaluation on dataset
+
+  Before running the command below, please check the checkpoint path used for evaluation.
+
+  ```python
+  python eval.py > ms_log/eval_output.log 2>&1 &
+  ```
+
+  The above python command will run in the background. You can view the results through the file "eval_output.log". The accuracy is saved in auc.log file.
+
+  ```txt
+   [00:31,  2.29it/s, test_auc=0.896, test_loss=0.3207327]
+  ```
+
+## Inference Process
+
+### [Export MindIR](#contents)
+
+- Export on local
+
+  ```shell
+  python export.py
+  ```
+
+# [Model Description](#contents)
+
+## [Performance](#contents)
+
+### Training Performance
+
+| Parameters          | CPU                               |
+|---------------------|-----------------------------------|
+| Model Version       | IntTower                          |
+| Resource            | CPU 2.90GHz;16Core;32G Memory     |
+| uploaded Date       | 09/24/2022 (month/day/year)       |
+| MindSpore Version   | 1.8.1                             |
+| Dataset             | [1]                               |
+| Training Parameters | epoch=8, batch_size=2048, lr=1e-3 |
+| Optimizer           | Adam                              |
+| Loss Function       | Sigmoid Cross Entropy With Logits |
+| outputs             | AUC                               |
+| Loss                | 0.892                             |
+| Per Step Time       | 34.50 ms                          |
+
+### Inference Performance
+
+| Parameters        | CPU                           |
+|-------------------|-------------------------------|
+| Model Version     | IntTower                      |
+| Resource          | CPU 2.90GHz;16Core;32G Memory |                        |
+| Uploaded Date     | 09/24/2022 (month/day/year)   |
+| MindSpore Version | 1.8.1                         |
+| Dataset           | [1]                           |
+| batch_size        | 2048                          |
+| outputs           | AUC                           |
+| AUC               | 0.896                         |
+
+# [Description of Random Situation](#contents)
+
+We set the random seed before training in model_config.py.
+
+# [ModelZoo Homepage](#contents)
+
+ Please check the official [homepage](https://gitee.com/mindspore/models)
\ No newline at end of file
diff --git a/research/recommend/IntTower/eval.py b/research/recommend/IntTower/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..9542076999b231f1d30a22f0eb877d6ba423d769
--- /dev/null
+++ b/research/recommend/IntTower/eval.py
@@ -0,0 +1,35 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore as ms
+from mindspore import nn
+from model import IntTower
+from util import test_epoch, setup_seed
+from get_dataset import process_struct_data, construct_dataset
+import model_config as cfg
+
+if __name__ == '__main__':
+    seed = 2012
+    setup_seed(seed)
+    ms.set_context(mode=ms.PYNATIVE_MODE)
+    batch_size = cfg.batch_size
+    data_path = './data/movielens.txt'
+    _, _, test_generator = process_struct_data(data_path)
+    test_dataset = construct_dataset(test_generator, batch_size)
+    network = IntTower()
+    loss_fn = nn.BCELoss(reduction='mean')
+    param_dict = ms.load_checkpoint("./IntTower.ckpt")
+    ms.load_param_into_net(network, param_dict)
+    test_epoch(test_dataset, network, loss_fn, test_generator, batch_size)
diff --git a/research/recommend/IntTower/export.py b/research/recommend/IntTower/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..99f1f7f8983e5cad4a229257b32677f5bcaf6108
--- /dev/null
+++ b/research/recommend/IntTower/export.py
@@ -0,0 +1,25 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import mindspore as ms
+from model import IntTower
+
+if __name__ == '__main__':
+    network = IntTower()
+    input_tensor = ms.Tensor(np.ones([2048, 7]).astype(np.float32))
+    param_dict = ms.load_checkpoint("./IntTower.ckpt")
+    ms.load_param_into_net(network, param_dict)
+    ms.export(net=network, inputs=input_tensor, file_name='./IntTower', file_format="MINDIR")
diff --git a/research/recommend/IntTower/get_dataset.py b/research/recommend/IntTower/get_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..a85e55945d43b0a52da6ec3a9e53a573de2b2887
--- /dev/null
+++ b/research/recommend/IntTower/get_dataset.py
@@ -0,0 +1,97 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import pandas as pd
+import mindspore.dataset as ds
+from sklearn.preprocessing import LabelEncoder, MinMaxScaler
+from sklearn.model_selection import train_test_split
+from util import MyDataset
+
+
+def data_process(data_path):
+    data = pd.read_csv(data_path)
+    data = data.drop(data[data['rating'] == 3].index)
+    data['rating'] = data['rating'].apply(lambda x: 1 if x > 3 else 0)
+    data = data.sort_values(by='timestamp', ascending=True)
+    train, test = train_test_split(data, test_size=0.2)
+    train, valid = train_test_split(train, test_size=0.2)
+    return [train, valid, test, data]
+
+
+def create_dataset(data_set, batch_size=32):
+
+    dataset = ds.GeneratorDataset(data_set, column_names=['data', 'label'])
+    dataset = dataset.batch(batch_size)
+    return dataset
+
+
+def get_user_feature(data):
+    data_group = data[['user_id', 'rating']].groupby('user_id').agg('mean').reset_index()
+    data_group.rename(columns={'rating': 'user_mean_rating'}, inplace=True)
+    data = pd.merge(data_group, data, on='user_id')
+    return data
+
+
+def get_item_feature(data):
+    data_group = data[['movie_id', 'rating']].groupby('movie_id').agg('mean').reset_index()
+    data_group.rename(columns={'rating': 'item_mean_rating'}, inplace=True)
+    data = pd.merge(data_group, data, on='movie_id')
+    return data
+
+
+def process_struct_data(data_path):
+    data_list = data_process(data_path)
+    train, valid, test, data = data_list[0], data_list[1], data_list[2], data_list[3]
+
+    train = get_user_feature(train)
+    train = get_item_feature(train)
+
+    valid = get_user_feature(valid)
+    valid = get_item_feature(valid)
+
+    test = get_user_feature(test)
+    test = get_item_feature(test)
+
+    sparse_features = ['user_id', 'movie_id', 'gender', 'age', 'occupation']
+    dense_features = ['user_mean_rating', 'item_mean_rating']
+
+    for feat in sparse_features:
+        lbe = LabelEncoder()
+        lbe.fit(data[feat])
+        train[feat] = lbe.transform(train[feat])
+        valid[feat] = lbe.transform(valid[feat])
+        test[feat] = lbe.transform(test[feat])
+
+    mms = MinMaxScaler(feature_range=(0, 1))
+    mms.fit(train[dense_features])
+    mms.fit(valid[dense_features])
+    mms.fit(test[dense_features])
+    train[dense_features] = mms.transform(train[dense_features])
+    valid[dense_features] = mms.transform(valid[dense_features])
+    test[dense_features] = mms.transform(test[dense_features])
+
+    train_dataset_generator = MyDataset(train[sparse_features + dense_features],
+                                        train['rating'])
+    valid_dataset_generator = MyDataset(valid[sparse_features + dense_features],
+                                        valid['rating'])
+    test_dataset_generator = MyDataset(test[sparse_features + dense_features],
+                                       test['rating'])
+    return train_dataset_generator, valid_dataset_generator, test_dataset_generator
+
+
+def construct_dataset(dataset_generator, batch_size):
+    dataset = create_dataset(dataset_generator, batch_size)
+    return dataset
diff --git a/research/recommend/IntTower/model.py b/research/recommend/IntTower/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e898ea08830c491d253bc3774783100e5d2e89e
--- /dev/null
+++ b/research/recommend/IntTower/model.py
@@ -0,0 +1,170 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import mindspore as ms
+import mindspore.nn as nn
+from mindspore import ops
+from mindspore.ops import Reshape, ExpandDims, Transpose, matmul, concat
+from mindspore.nn import Embedding
+from module import LightSE
+import model_config as cfg
+
+class IntTower(nn.Cell):
+    """
+    IntTower Model Structure
+    """
+
+    def __init__(self):
+        super(IntTower, self).__init__()
+        self.mlp_layers = cfg.mlp_layers
+        self.use_multi_layer = cfg.use_multi_layer
+        self.activation = 'relu'
+        self.head_num = cfg.head_num
+        self.feblock_size = cfg.feblock_size
+        self.user_embedding_dim = cfg.user_embedding_dim
+
+        self.item_embedding_dim = cfg.item_embedding_dim
+        self.sparse_embedding_dim = cfg.sparse_embedding_dim
+        self.dropout = nn.Dropout(cfg.keep_rate)
+        self.User_SE = LightSE(cfg.user_sparse_field)
+        self.user_fe_embedding = None
+        self.item_fe_embedding = None
+
+        self.user_bn_list = []
+        self.item_bn_list = []
+
+        self.user_dense_layer_1 = nn.Dense(self.user_embedding_dim, self.mlp_layers[0], weight_init='normal',
+                                           activation=self.activation)
+        self.user_dense_layer_2 = nn.Dense(self.mlp_layers[0], self.mlp_layers[1], weight_init='normal',
+                                           activation=self.activation)
+        self.user_dense_layer_3 = nn.Dense(self.mlp_layers[1], self.mlp_layers[2], weight_init='normal',
+                                           activation=self.activation)
+
+        self.user_fe_layer_1 = nn.Dense(self.mlp_layers[0], self.feblock_size, weight_init='normal',
+                                        activation=self.activation)
+        self.user_fe_layer_2 = nn.Dense(self.mlp_layers[1], self.feblock_size, weight_init='normal',
+                                        activation=self.activation)
+        self.user_fe_layer_3 = nn.Dense(self.mlp_layers[2], self.feblock_size, weight_init='normal',
+                                        activation=self.activation)
+
+        self.item_dense_layer_1 = nn.Dense(self.item_embedding_dim, self.mlp_layers[0], weight_init='normal',
+                                           activation=self.activation)
+        self.item_dense_layer_2 = nn.Dense(self.mlp_layers[0], self.mlp_layers[1], weight_init='normal',
+                                           activation=self.activation)
+        self.item_dense_layer_3 = nn.Dense(self.mlp_layers[1], self.mlp_layers[2], weight_init='normal',
+                                           activation=self.activation)
+
+        self.user_bn_layer_1 = nn.BatchNorm1d(self.mlp_layers[0])
+        self.user_bn_layer_2 = nn.BatchNorm1d(self.mlp_layers[1])
+        self.user_bn_layer_3 = nn.BatchNorm1d(self.mlp_layers[2])
+
+        self.item_bn_layer_1 = nn.BatchNorm1d(self.mlp_layers[0])
+        self.item_bn_layer_2 = nn.BatchNorm1d(self.mlp_layers[1])
+        self.item_bn_layer_3 = nn.BatchNorm1d(self.mlp_layers[2])
+
+        self.user_dense_list = [self.user_dense_layer_1, self.user_dense_layer_2, self.user_dense_layer_3]
+        self.item_dense_list = [self.item_dense_layer_1, self.item_dense_layer_2, self.item_dense_layer_3]
+        self.user_fe_list = [self.user_fe_layer_1, self.user_fe_layer_2, self.user_fe_layer_3]
+
+        self.user_bn_list = [self.user_bn_layer_1, self.user_bn_layer_2, self.user_bn_layer_3]
+        self.item_bn_list = [self.item_bn_layer_1, self.item_bn_layer_2, self.item_bn_layer_3]
+
+        self.user_fe_dense = nn.Dense(self.mlp_layers[-1], self.feblock_size)
+        self.item_fe_dense = nn.Dense(self.mlp_layers[-1], self.feblock_size)
+        self.user_id_embedding = Embedding(6040, self.sparse_embedding_dim)
+        self.gender_embedding = Embedding(2, self.sparse_embedding_dim)
+        self.age_embedding = Embedding(7, self.sparse_embedding_dim)
+        self.occupation_embedding = Embedding(21, self.sparse_embedding_dim)
+        self.item_id_embedding = Embedding(3668, self.sparse_embedding_dim)
+        self.user_bacth_norm = nn.BatchNorm1d(self.user_embedding_dim)
+        self.item_bacth_norm = nn.BatchNorm1d(self.item_embedding_dim)
+
+    def construct(self, inputs):
+        user_list, user_dense_list = [], []
+        item_list, item_dense_list = [], []
+        user_id = ms.Tensor(inputs[:, 0:1], dtype=ms.int32)
+        user_list.append(self.user_id_embedding(user_id))
+        user_gender = ms.Tensor(inputs[:, 2:3], dtype=ms.int32)
+        user_list.append(self.gender_embedding(user_gender))
+        user_age = ms.Tensor(inputs[:, 3:4], dtype=ms.int32)
+        user_list.append(self.age_embedding(user_age))
+        user_occ = ms.Tensor(inputs[:, 4:5], dtype=ms.int32)
+        user_list.append(self.occupation_embedding(user_occ))
+
+        item_id = ms.Tensor(inputs[:, 1:2], dtype=ms.int32)
+        item_list.append(self.item_id_embedding(item_id))
+
+        user_mean_rating = ms.Tensor(inputs[:, 5:6], dtype=ms.float32)
+        item_mean_rating = ms.Tensor(inputs[:, 6:7], dtype=ms.float32)
+
+        user_dense_list.append(user_mean_rating)
+        item_dense_list.append(item_mean_rating)
+
+        user_sparse_embedding = concat(user_list, axis=1)
+        user_sparse_embedding = self.User_SE(user_sparse_embedding)
+
+        user_sparse_input = ops.flatten(user_sparse_embedding)
+
+        item_sparse_embedding = concat(item_list)
+        item_sparse_input = ops.flatten(item_sparse_embedding)
+
+        user_dense_input = concat(user_dense_list)
+        item_dense_input = concat(item_dense_list)
+
+        user_input = concat([user_sparse_input, user_dense_input], axis=-1)
+        item_input = concat([item_sparse_input, item_dense_input], axis=-1)
+
+        user_embed = self.user_bacth_norm(user_input)
+        item_embed = self.item_bacth_norm(item_input)
+        user_fe_reps = []
+
+        for i in range(len(self.mlp_layers)):
+            user_embed = self.dropout(self.user_bn_list[i](
+                self.user_dense_list[i](user_embed)))
+            if self.use_multi_layer:
+                user_fe_rep = self.user_fe_list[i](user_embed)
+                user_fe_reps.append(user_fe_rep)
+
+            item_embed = self.dropout(self.item_bn_list[i](
+                self.item_dense_list[i](item_embed)))
+        item_fe_rep = self.item_fe_dense(item_embed)
+
+        if self.use_multi_layer:
+            score = []
+            for i in range(len(user_fe_reps)):
+                user_temp = Reshape()(user_fe_reps[i],
+                                      (-1, self.head_num, int(user_fe_reps[i].shape[1] // self.head_num)))
+                item_temp = Reshape()(item_fe_rep, (-1, self.head_num, int(item_fe_rep.shape[1] // self.head_num)))
+                item_temp = Transpose()((item_temp), (0, 2, 1))
+                dot_col = matmul(user_temp, item_temp)
+                max_col = dot_col.max(axis=2)
+                sum_col = max_col.sum(axis=1)
+                expand_col = ExpandDims()(sum_col, 1)
+                score.append(expand_col)
+            model_output = concat(score, axis=1).sum(axis=1)
+            model_output = nn.Sigmoid()(Reshape()(model_output, (-1, 1)))
+        else:
+            user_fe_rep = self.user_fe_dense(user_embed)
+            user_temp = Reshape()(user_fe_rep, (-1, self.head_num, int(user_fe_rep.shape[1] // self.head_num)))
+            item_temp = Reshape()(item_fe_rep, (-1, self.head_num, int(item_fe_rep.shape[1] // self.head_num)))
+            item_temp = Transpose()((item_temp), (0, 2, 1))
+            dot_col = matmul(user_temp, item_temp)
+            max_col = dot_col.max(axis=2)
+            sum_col = max_col.sum(axis=1)
+            expand_col = ExpandDims()(sum_col, 1)
+            score = expand_col
+            model_output = nn.Sigmoid()(score)
+
+        return model_output
diff --git a/research/recommend/IntTower/model_config.py b/research/recommend/IntTower/model_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e93d013e90759263438340642be520d5f2e1b706
--- /dev/null
+++ b/research/recommend/IntTower/model_config.py
@@ -0,0 +1,28 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+mlp_layers = [300, 300, 128]
+feblock_size = 256
+head_num = 4
+user_embedding_dim = 129
+item_embedding_dim = 33
+sparse_embedding_dim = 32
+use_multi_layer = True
+user_sparse_field = 4
+keep_rate = 0.9
+epoch = 10
+batch_size = 2048
+seed = 3047
+lr = 0.0005
diff --git a/research/recommend/IntTower/module.py b/research/recommend/IntTower/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd2b361745ca7be8f662c2dfcedcb447d966e871
--- /dev/null
+++ b/research/recommend/IntTower/module.py
@@ -0,0 +1,78 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+from mindspore import nn, ops
+from mindspore.ops import L2Normalize, broadcast_to
+
+
+class LightSE(nn.Cell):
+    """LightSELayer used in IntTower.
+      Input shape
+        - A list of 3D tensor with shape: ``(batch_size,filed_size,embedding_size)``.
+      Output shape
+        - A list of 3D tensor with shape: ``(batch_size,filed_size,embedding_size)``.
+      Arguments
+        - **filed_size** : Positive integer, number of feature groups.
+        - **seed** : A Python integer to use as random seed.
+      References
+      """
+
+    def __init__(self, field_size, seed=1024):
+        super(LightSE, self).__init__()
+        self.seed = seed
+        self.softmax = nn.Softmax(axis=1)
+        self.field_size = field_size
+        self.excitation = nn.Dense(self.field_size, self.field_size)
+
+    def construct(self, inputs):
+        if len(inputs.shape) != 3:
+            raise ValueError(
+                "Unexpected inputs dimensions %d, expect to be 3 dimensions" % (len(inputs.shape)))
+        Z = ops.mean(inputs, axis=-1)
+        A = self.excitation(Z)  # (batch,reduction_size)
+        A = self.softmax(A)  # (batch,reduction_size)
+        out = inputs * ops.expand_dims(A, axis=2)
+
+        return inputs + out
+
+
+class ContrastLoss(nn.LossBase):
+    def __init__(self, reduction="mean"):
+        """compute contrast loss
+              Input shape
+        - Two 2D tensors with shape: ``(batch_size,embedding_size)``.
+      Output shape
+        - A loss scalar.
+
+        """
+
+        super(ContrastLoss, self).__init__(reduction)
+        self.norm = L2Normalize(axis=-1)
+        self.cos_sim = nn.CosineEmbeddingLoss()
+        self.abs = ops.Abs()
+        self.lam = 1
+        self.pos = 0
+        self.all = 0
+        self.tau = 1
+
+    def construct(self, user_embedding, item_embedding, target):
+        user_embedding = self.norm(user_embedding)
+        item_embedding = self.norm(item_embedding)
+        pos_index = broadcast_to(target, (target.shape[0], item_embedding.shape[1]))
+        self.pos += self.abs(ops.mean(user_embedding * item_embedding * pos_index)) / self.tau
+        self.all += self.abs(ops.mean(user_embedding * item_embedding)) / self.tau
+        contrast_loss = -ops.log(ops.exp(self.pos) / ops.exp(self.all)) * self.lam
+        return contrast_loss
diff --git a/research/recommend/IntTower/requirements.txt b/research/recommend/IntTower/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..36966f30d000d1bc6b199235fbc4abc3a9cab8dd
--- /dev/null
+++ b/research/recommend/IntTower/requirements.txt
@@ -0,0 +1,6 @@
+pandas
+numpy
+random
+mindspre==1.8.1
+tqdm
+sklearn
\ No newline at end of file
diff --git a/research/recommend/IntTower/train.py b/research/recommend/IntTower/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..2832600aecd48f263a8477a2db305eb532a9f027
--- /dev/null
+++ b/research/recommend/IntTower/train.py
@@ -0,0 +1,77 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+
+import model_config as cfg
+import mindspore as ms
+from mindspore import nn
+from tqdm import tqdm
+from sklearn.metrics import roc_auc_score
+from model import IntTower
+from util import AvgMeter, valid_epoch, setup_seed, train_epoch
+from get_dataset import process_struct_data
+from get_dataset import construct_dataset
+
+
+def forward_fn(inputs, targets):
+    logits = network(inputs)
+    Loss = loss_fn(logits, targets.reshape(-1, 1))
+    n_logit = logits.asnumpy()
+    n_target = targets.reshape(-1, 1).asnumpy()
+    Auc = roc_auc_score(n_target.astype(int), n_logit)
+    return Loss, Auc
+
+
+if __name__ == '__main__':
+    print("1")
+
+    ms.set_context(mode=ms.PYNATIVE_MODE)
+    epoch = cfg.epoch
+    batch_size = cfg.batch_size
+    seed = cfg.seed
+    lr = cfg.lr
+
+    setup_seed(seed)
+
+    data_path = './data/movielens.txt'
+    train_dataset_generator, valid_dataset_generator, _ = process_struct_data(data_path)
+    train_dataset = construct_dataset(train_dataset_generator, batch_size)
+    valid_dataset = construct_dataset(valid_dataset_generator, batch_size)
+
+    network = IntTower()
+    loss_fn = nn.BCELoss(reduction='mean')
+    net_opt = nn.Adam(network.trainable_params(), learning_rate=lr)
+
+    best_loss = float('inf')
+    best_auc = 0
+    for i in range(epoch):
+        loss_meter = AvgMeter()
+        auc_meter = AvgMeter()
+        tqdm_object = tqdm(train_dataset.create_dict_iterator(), total=len(train_dataset_generator) // batch_size)
+        print("epoch %d :" % (i))
+        train_loss = 0
+        count = 0
+        for batch in tqdm_object:
+            loss, auc = train_epoch(forward_fn,
+                                    batch["data"], batch["label"], net_opt)
+            count = len(batch["label"])
+            loss_meter.update(loss, count)
+            auc_meter.update(auc, count)
+            tqdm_object.set_postfix(train_loss=loss_meter.avg, train_auc=auc_meter.avg)
+
+        valid_loss, valid_auc = valid_epoch(valid_dataset, network, loss_fn, valid_dataset_generator, batch_size)
+        if valid_auc.avg > best_auc:
+            best_auc = valid_auc.avg
+            ms.save_checkpoint(network, "./IntTower.ckpt")
diff --git a/research/recommend/IntTower/util.py b/research/recommend/IntTower/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cb8a55a029df7957245b94f854c41a94905e3a4
--- /dev/null
+++ b/research/recommend/IntTower/util.py
@@ -0,0 +1,108 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import random
+import numpy as np
+import mindspore as ms
+from mindspore.ops import value_and_grad, set_seed
+from tqdm import tqdm
+from sklearn.metrics import roc_auc_score
+
+
+class AvgMeter:
+    def __init__(self, name="Metric"):
+        self.name = name
+        self.reset()
+
+    def reset(self):
+        self.avg, self.sum, self.count = [0] * 3
+
+    def update(self, val, count=1):
+        self.count += count
+        self.sum += val * count
+        self.avg = self.sum / self.count
+
+    def __repr__(self):
+        text = f"{self.name}: {self.avg:.4f}"
+        return text
+
+
+class MyDataset:
+
+    def __init__(self, data, label):
+        self.data = data
+        self.label = label
+
+    def __getitem__(self, index):
+        return self.data.iloc[index], ms.Tensor(self.label.iloc[index], dtype=ms.float32)
+
+    def __len__(self):
+        return len(self.label)
+
+
+def test_epoch(test_dataset, model, loss_fn, test_dataset_generator, batch_size):
+    loss_meter = AvgMeter()
+    auc_meter = AvgMeter()
+    tqdm_object = tqdm(test_dataset.create_dict_iterator(), total=len(test_dataset_generator) // batch_size)
+    for batch in tqdm_object:
+        logits = model(batch["data"])
+        loss = loss_fn(logits, batch["label"].reshape(-1, 1))
+        n_logit = logits.asnumpy()
+        n_target = batch["label"].reshape(-1, 1).asnumpy()
+        auc = roc_auc_score(n_target.astype(int), n_logit)
+        count = len(batch["label"])
+        loss_meter.update(loss, count)
+        auc_meter.update(auc, count)
+        tqdm_object.set_postfix(test_loss=loss_meter.avg, test_auc=auc_meter.avg)
+    return loss_meter, auc_meter
+
+
+def valid_epoch(valid_dataset, model, loss_fn, valid_dataset_generator, batch_size):
+    loss_meter = AvgMeter()
+    auc_meter = AvgMeter()
+    tqdm_object = tqdm(valid_dataset.create_dict_iterator(), total=len(valid_dataset_generator) // batch_size)
+    for batch in tqdm_object:
+        logits = model(batch["data"])
+        loss = loss_fn(logits, batch["label"].reshape(-1, 1))
+        n_logit = logits.asnumpy()
+        n_target = batch["label"].reshape(-1, 1).asnumpy()
+        auc = roc_auc_score(n_target.astype(int), n_logit)
+        count = len(batch["label"])
+        loss_meter.update(loss, count)
+        auc_meter.update(auc, count)
+        tqdm_object.set_postfix(valid_loss=loss_meter.avg, valid_auc=auc_meter.avg)
+    return loss_meter, auc_meter
+
+
+def setup_seed(seed):
+    np.random.seed(seed)
+    random.seed(seed)
+    set_seed(seed)
+
+
+def forward_fn(network, inputs, targets, loss_fn):
+    logits = network(inputs)
+    loss = loss_fn(logits, targets.reshape(-1, 1))
+    n_logit = logits.asnumpy()
+    n_target = targets.reshape(-1, 1).asnumpy()
+    auc = roc_auc_score(n_target.astype(int), n_logit)
+    return loss, auc
+
+
+def train_epoch(forward, inputs, targets, net_opt):
+    grad_fn = value_and_grad(forward, None, net_opt.parameters)
+    (loss, auc), grads = grad_fn(inputs, targets)
+    net_opt(grads)
+    return loss, auc