From d1e69415cc1239cbcb178cc2cadf1f4f378a2c9a Mon Sep 17 00:00:00 2001
From: anzhengqi <anzhengqi1@huawei.com>
Date: Mon, 26 Sep 2022 21:57:51 +0800
Subject: [PATCH] modify wide_and_deep_multitable network scripts

---
 .../wide_and_deep_multitable/src/datasets.py  | 21 +++++++++----------
 1 file changed, 10 insertions(+), 11 deletions(-)

diff --git a/official/recommend/wide_and_deep_multitable/src/datasets.py b/official/recommend/wide_and_deep_multitable/src/datasets.py
index 1aec7cae9..bb40c7003 100644
--- a/official/recommend/wide_and_deep_multitable/src/datasets.py
+++ b/official/recommend/wide_and_deep_multitable/src/datasets.py
@@ -164,7 +164,6 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
 
 
 def _get_tf_dataset(data_dir,
-                    schema_dict,
                     input_shape_dict,
                     train_mode=True,
                     epochs=1,
@@ -186,10 +185,16 @@ def _get_tf_dataset(data_dir,
 
     float_key_list = ["label", "continue_val"]
 
-    columns_list = []
-    for key, attr_dict in schema_dict.items():
-        print("key: {}; shape: {}".format(key, attr_dict["tf_shape"]))
-        columns_list.append(key)
+    columns_list = ["label", "continue_val", "indicator_id", "emb_128_id",
+                    "emb_64_single_id", "multi_doc_event_category_id",
+                    "multi_doc_event_category_id_mask", "multi_doc_ad_entity_id",
+                    "multi_doc_ad_entity_id_mask", "multi_doc_event_entity_id",
+                    "multi_doc_event_entity_id_mask", "multi_doc_ad_topic_id",
+                    "multi_doc_ad_topic_id_mask", "multi_doc_ad_category_id",
+                    "multi_doc_ad_category_id_mask", "multi_doc_event_topic_id",
+                    "multi_doc_event_topic_id_mask", "ad_id", "display_ad_and_is_leak",
+                    "display_id", "is_leak"]
+    for key in columns_list:
         if key in set(float_key_list):
             ms_dtype = mstype.float32
         else:
@@ -223,7 +228,6 @@ def _get_tf_dataset(data_dir,
     print("input_shape_dict start logging")
     print(input_shape_dict)
     print("input_shape_dict end logging")
-    print(schema_dict)
 
     def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):
         a = np.asarray(a.reshape(batch_size,))
@@ -322,16 +326,11 @@ def create_dataset(data_dir,
     create_dataset
     """
     if is_tf_dataset:
-        with open(os.path.join(data_dir, 'dataformat', "schema_dict.pkl"),
-                  "rb") as file_in:
-            print(os.path.join(data_dir, 'dataformat', "schema_dict.pkl"))
-            schema_dict = pickle.load(file_in)
         with open(
                 os.path.join(data_dir, 'dataformat', "input_shape_dict.pkl"),
                 "rb") as file_in:
             input_shape_dict = pickle.load(file_in)
         return _get_tf_dataset(data_dir,
-                               schema_dict,
                                input_shape_dict,
                                train_mode,
                                epochs,
-- 
GitLab