Commit 2f595bee authored by wangbixing123

fix bug for ecapa_tdnn

parent c1bb6424
@@ -33,28 +33,28 @@ channels: 1024 # channel size o
 base_lrate: 0.000001 # base learning rate of cyclic LR
 max_lrate: 0.0001 # max learning rate of cyclic LR
 momentum: 0.95 # momentum for optimizer
-weightDecay: 0.000002 # momentum for optimizer
-num_epochs: 2 # training epoch
+weight_decay: 0.000002 # weight decay for optimizer
+num_epochs: 3 # number of training epochs
 minibatch_size: 192 # batch size
 emb_size: 192 # embedding dim
 step_size: 65000 # steps to reach max learning rate of cyclic LR
 class_num: 7205 # speaker num of voxceleb1&2
 pre_trained: False # whether a pre-trained model exists
-train_data_path: "/Temp/abc000/feat_train/" # path to fbank training data
+train_data_path: "/Temp/abc/feat_train/" # path to fbank training data
 keep_checkpoint_max: 20 # max number of checkpoints to keep
 checkpoint_path: "train_ecapa_vox2_full-2_664204.ckpt" # path to pre-trained model
 ckpt_save_dir: "./ckpt/" # dir to store trained models
 # eval
-eval_data_path: "/home/abc000/feat_eval/" # path to eval fbank data
-veri_file_path: /home/abc000/feat_eval/veri_test_bleeched.txt # trials
+eval_data_path: "/Temp/abc/feat_eval/" # path to eval fbank data
+veri_file_path: /Temp/abc/feat_eval/veri_test_bleeched.txt # trials
 cut_wav: false # cut wav to 3s (same as train data)
-model_path: "ckpt/train_ecapa_vox2_full-2_664204.ckpt" # path of eval model
-train_norm_path: "/home/abc000/feat_norm/" # fbank data for norm
+model_path: "/Temp/abc/ckpt/train_ecapa_vox12-1_660204.ckpt" # path to eval model
+train_norm_path: "/Temp/abc/feat_norm/" # fbank data for norm
 score_norm: "s-norm" # to apply score normalization, uncomment these two lines
 cohort_size: 20000 # max number of utts used for norm
-npy_file_path: './npys/' # dir to save intermediate result
+npy_file_path: '/Temp/abc/npys/' # dir to save intermediate results
 # # export option
 exp_ckpt_file: './train_ecapa_vox2_full-2_664204.ckpt'
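The schedule above pairs a small base_lrate with a much larger max_lrate, which is characteristic of a triangular cyclic learning rate: the rate climbs linearly from base_lrate to max_lrate over step_size steps, then descends symmetrically. A minimal sketch of that policy under the assumption the standard triangular variant is used (the function below is illustrative, not the repository's implementation):

import math

def cyclic_lr(step, base_lrate=0.000001, max_lrate=0.0001, step_size=65000):
    """Triangular cyclic LR: rises from base_lrate to max_lrate over
    step_size steps, then falls back, repeating every 2 * step_size steps."""
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)
    return base_lrate + (max_lrate - base_lrate) * max(0.0, 1 - x)

# step 0      -> 0.000001 (base)
# step 65000  -> 0.0001   (peak)
# step 130000 -> 0.000001 (back to base)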
@@ -50,6 +50,8 @@ def create_dataset(cfg, data_home, shuffle=False):
""" """
dataset_generator = DatasetGenerator(data_home) dataset_generator = DatasetGenerator(data_home)
distributed_sampler = None
if cfg.run_distribute:
distributed_sampler = DistributedSampler(len(dataset_generator), cfg.group_size, cfg.rank, shuffle=True) distributed_sampler = DistributedSampler(len(dataset_generator), cfg.group_size, cfg.rank, shuffle=True)
vox2_ds = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=shuffle, sampler=distributed_sampler) vox2_ds = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=shuffle, sampler=distributed_sampler)
cnt = int(len(dataset_generator) / cfg.group_size) cnt = int(len(dataset_generator) / cfg.group_size)
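The fix in this hunk constructs the DistributedSampler only when cfg.run_distribute is set, so a single-device run passes sampler=None and relies on the shuffle argument instead. A minimal sketch of what such a sampler typically does, assuming it shards a shuffled index space round-robin across ranks (the repository's own DistributedSampler may differ in detail):

import numpy as np

class DistributedSampler:
    def __init__(self, dataset_size, group_size, rank, shuffle=True, seed=0):
        self.dataset_size = dataset_size
        self.group_size = group_size   # total number of workers
        self.rank = rank               # this worker's id in [0, group_size)
        self.shuffle = shuffle
        self.seed = seed               # bumped per epoch so shuffles differ

    def __iter__(self):
        indices = np.arange(self.dataset_size)
        if self.shuffle:
            self.seed = (self.seed + 1) & 0xFFFFFFFF
            np.random.RandomState(self.seed).shuffle(indices)
        # every worker takes every group_size-th index, offset by its rank,
        # so the shards are disjoint and cover the whole dataset
        return iter(indices[self.rank::self.group_size].tolist())

    def __len__(self):
        return self.dataset_size // self.group_size

Note that per-rank length matches the cnt computed in create_dataset above: each worker sees roughly len(dataset_generator) / cfg.group_size samples per epoch.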