Skip to content
Snippets Groups Projects
Commit d546807b authored by 张宪顺's avatar 张宪顺
Browse files

Update 第二阶段/Model_file/模型训练使用使用说明.md, 第二阶段/Model_file/main.py files

parent 4a6cdfbe
No related branches found
No related tags found
No related merge requests found
......@@ -11,9 +11,26 @@ from tensorflow.keras.optimizers import *
from tensorflow.keras.datasets import *
from pathlib import Path
MANIFEST_DIR = sys.path[0] + "\\train.csv"
Mode_DIR = sys.path[0] + "\\models"
isExists=os.path.exists(Mode_DIR)
# 判断结果
if not isExists:
# 如果不存在则创建目录
# 创建目录操作函数
os.makedirs(Mode_DIR)
print(Mode_DIR + ' 创建成功')
else:
# 如果目录存在则不创建,并提示目录已存在
print(Mode_DIR + ' 目录已存在')
MANIFEST_DIR = sys.path[0] + "\\train.csv"
Batch_size = 20
Long = 792
......@@ -58,25 +75,25 @@ def xs_gen(path=MANIFEST_DIR,batch_size = Batch_size,train=True,Lens=Lens,flag =
batch_y = np.array([convert2oneHot(label,10) for label in batch_list[:,-1]])
yield batch_x, batch_y
TEST_MANIFEST_DIR = sys.path[0] + "\\test_data.csv"
# TEST_MANIFEST_DIR = sys.path[0] + "\\test_data.csv"
def ts_gen(path=TEST_MANIFEST_DIR,batch_size = Batch_size):
# def ts_gen(path=TEST_MANIFEST_DIR,batch_size = Batch_size):
img_list = pd.read_csv(path)
# img_list = pd.read_csv(path)
img_list = np.array(img_list)[:Lens]
print("Found %s train items."%len(img_list))
print("list 1 is",img_list[0,-1])
steps = math.ceil(len(img_list) / batch_size) # 确定每轮有多少个batch
# img_list = np.array(img_list)[:Lens]
# print("Found %s train items."%len(img_list))
# print("list 1 is",img_list[0,-1])
# steps = math.ceil(len(img_list) / batch_size) # 确定每轮有多少个batch
for i in range(steps):
# for i in range(steps):
batch_list = img_list[i * batch_size : i * batch_size + batch_size]
#np.random.shuffle(batch_list)
batch_x = np.array([file for file in batch_list[:,1:]])
#batch_y = np.array([convert2oneHot(label,10) for label in batch_list[:,-1]])
# batch_list = img_list[i * batch_size : i * batch_size + batch_size]
# #np.random.shuffle(batch_list)
# batch_x = np.array([file for file in batch_list[:,1:]])
# #batch_y = np.array([convert2oneHot(label,10) for label in batch_list[:,-1]])
yield batch_x
# yield batch_x
......@@ -124,7 +141,7 @@ if __name__ == "__main__":
ckpt = tensorflow.keras.callbacks.ModelCheckpoint(
filepath= r"C:\Users\years\Desktop\learn_some\model\'best_model.{epoch:02d}-{val_loss:.4f}.h5",
filepath= Mode_DIR + "\\best_model.{epoch:02d}-{val_loss:.4f}.h5",
monitor='val_loss', save_best_only=True,verbose=1)
model = build_model()
......@@ -143,7 +160,7 @@ if __name__ == "__main__":
callbacks=[ckpt],
)
keras_file = r"C:\Users\years\Desktop\learn_some\model\finishModel.h5"
keras_file = Mode_DIR + "\\finishModel.h5"
model.save(keras_file, save_format="h5")
......@@ -171,7 +188,7 @@ if __name__ == "__main__":
tflite_model = converter.convert()
tflite_file = Path(r"C:/Users/years/Desktop/learn_some/model/finishModel.tflite")
tflite_file = Path(Mode_DIR + "\\finishModel.tflite")
tflite_file.write_bytes(tflite_model)
print("convert model to tflite done...")
# ######
......@@ -210,7 +227,7 @@ if __name__ == "__main__":
keras_file_tflite = "C:/Users/years/Desktop/learn_some/model/finishModel.tflite"
keras_file_tflite = Mode_DIR + "\\finishModel.tflite"
interpreter = tensorflow.lite.Interpreter(model_path=keras_file_tflite)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
......
环境说明:
​ vscode
​ tensorflow 2.5.0 - GPU
​ CUDA Version: 11.4
​ cuDNN:8.2.2
cuDNN:8.2.2
请将训练数据与本程序放在同一目录下
模型部分路径使用的绝对路径,使用时还请做出修改
本程序适用于文件夹下的数据,如若使用其他数据请另作修改
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment