Skip to content
Snippets Groups Projects
Commit d1e69415 authored by anzhengqi's avatar anzhengqi
Browse files

modify wide_and_deep_multitable network scipts

parent 03f7c052
No related branches found
No related tags found
No related merge requests found
......@@ -164,7 +164,6 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
def _get_tf_dataset(data_dir,
schema_dict,
input_shape_dict,
train_mode=True,
epochs=1,
......@@ -186,10 +185,16 @@ def _get_tf_dataset(data_dir,
float_key_list = ["label", "continue_val"]
columns_list = []
for key, attr_dict in schema_dict.items():
print("key: {}; shape: {}".format(key, attr_dict["tf_shape"]))
columns_list.append(key)
columns_list = ["label", "continue_val", "indicator_id", "emb_128_id",
"emb_64_single_id", "multi_doc_event_category_id",
"multi_doc_event_category_id_mask", "multi_doc_ad_entity_id",
"multi_doc_ad_entity_id_mask", "multi_doc_event_entity_id",
"multi_doc_event_entity_id_mask", "multi_doc_ad_topic_id",
"multi_doc_ad_topic_id_mask", "multi_doc_ad_category_id",
"multi_doc_ad_category_id_mask", "multi_doc_event_topic_id",
"multi_doc_event_topic_id_mask", "ad_id", "display_ad_and_is_leak",
"display_id", "is_leak"]
for key in columns_list:
if key in set(float_key_list):
ms_dtype = mstype.float32
else:
......@@ -223,7 +228,6 @@ def _get_tf_dataset(data_dir,
print("input_shape_dict start logging")
print(input_shape_dict)
print("input_shape_dict end logging")
print(schema_dict)
def mixup(a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r, s, t, u):
a = np.asarray(a.reshape(batch_size,))
......@@ -322,16 +326,11 @@ def create_dataset(data_dir,
create_dataset
"""
if is_tf_dataset:
with open(os.path.join(data_dir, 'dataformat', "schema_dict.pkl"),
"rb") as file_in:
print(os.path.join(data_dir, 'dataformat', "schema_dict.pkl"))
schema_dict = pickle.load(file_in)
with open(
os.path.join(data_dir, 'dataformat', "input_shape_dict.pkl"),
"rb") as file_in:
input_shape_dict = pickle.load(file_in)
return _get_tf_dataset(data_dir,
schema_dict,
input_shape_dict,
train_mode,
epochs,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment