From d1e69415cc1239cbcb178cc2cadf1f4f378a2c9a Mon Sep 17 00:00:00 2001 From: anzhengqi <anzhengqi1@huawei.com> Date: Mon, 26 Sep 2022 21:57:51 +0800 Subject: [PATCH] modify wide_and_deep_multitable network scripts --- .../wide_and_deep_multitable/src/datasets.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/official/recommend/wide_and_deep_multitable/src/datasets.py b/official/recommend/wide_and_deep_multitable/src/datasets.py index 1aec7cae9..bb40c7003 100644 --- a/official/recommend/wide_and_deep_multitable/src/datasets.py +++ b/official/recommend/wide_and_deep_multitable/src/datasets.py @@ -164,7 +164,6 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000): def _get_tf_dataset(data_dir, - schema_dict, input_shape_dict, train_mode=True, epochs=1, @@ -186,10 +185,16 @@ def _get_tf_dataset(data_dir, float_key_list = ["label", "continue_val"] - columns_list = [] - for key, attr_dict in schema_dict.items(): - print("key: {}; shape: {}".format(key, attr_dict["tf_shape"])) - columns_list.append(key) + columns_list = ["label", "continue_val", "indicator_id", "emb_128_id", + "emb_64_single_id", "multi_doc_event_category_id", + "multi_doc_event_category_id_mask", "multi_doc_ad_entity_id", + "multi_doc_ad_entity_id_mask", "multi_doc_event_entity_id", + "multi_doc_event_entity_id_mask", "multi_doc_ad_topic_id", + "multi_doc_ad_topic_id_mask", "multi_doc_ad_category_id", + "multi_doc_ad_category_id_mask", "multi_doc_event_topic_id", + "multi_doc_event_topic_id_mask", "ad_id", "display_ad_and_is_leak", + "display_id", "is_leak"] + for key in columns_list: if key in set(float_key_list): ms_dtype = mstype.float32 else: @@ -223,7 +228,6 @@ def _get_tf_dataset(data_dir, print("input_shape_dict start logging") print(input_shape_dict) print("input_shape_dict end logging") - print(schema_dict) def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u): a = np.asarray(a.reshape(batch_size,)) @@ -322,16 
+326,11 @@ def create_dataset(data_dir, create_dataset """ if is_tf_dataset: - with open(os.path.join(data_dir, 'dataformat', "schema_dict.pkl"), - "rb") as file_in: - print(os.path.join(data_dir, 'dataformat', "schema_dict.pkl")) - schema_dict = pickle.load(file_in) with open( os.path.join(data_dir, 'dataformat', "input_shape_dict.pkl"), "rb") as file_in: input_shape_dict = pickle.load(file_in) return _get_tf_dataset(data_dir, - schema_dict, input_shape_dict, train_mode, epochs, -- GitLab