diff --git a/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml b/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml
index 131cd29feda6db85769f585965a2b6fa719ff887..381b62ba1a8a4dd9ab368fbc7b604f90105292d7 100644
--- a/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml
+++ b/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml
@@ -28,33 +28,33 @@ run_distribute: 0
 # ==============================================================================
 # options
-in_channels: 80 # input channel size, same as the dim of fbank feature
+in_channels: 80 # input channel size, same as the dim of fbank feature
 channels: 1024 # channel size of middle layer feature map
 base_lrate: 0.000001 # base learning rate of cyclic LR
 max_lrate: 0.0001 # max learning rate of cyclic LR
 momentum: 0.95 # weight decay for optimizer
-weightDecay: 0.000002 # momentum for optimizer
-num_epochs: 2 # training epoch
+weight_decay: 0.000002 # momentum for optimizer
+num_epochs: 3 # training epoch
 minibatch_size: 192 # batch size
 emb_size: 192 # embedding dim
 step_size: 65000 # steps to achieve max learning rate cyclic LR
 class_num: 7205 # speaker num pf voxceleb1&2
 pre_trained: False # if pre-trained model exist
-train_data_path: "/Temp/abc000/feat_train/" # path to fbank training data
+train_data_path: "/Temp/abc/feat_train/" # path to fbank training data
 keep_checkpoint_max: 20
 checkpoint_path: "train_ecapa_vox2_full-2_664204.ckpt" # path to pre-trained model
 ckpt_save_dir: "./ckpt/" # path to store train model
 
 # eval
-eval_data_path: "/home/abc000/feat_eval/" # path to eval fbank data
-veri_file_path: /home/abc000/feat_eval/veri_test_bleeched.txt # trials
+eval_data_path: "/Temp/abc/feat_eval/" # path to eval fbank data
+veri_file_path: /Temp/abc/feat_eval/veri_test_bleeched.txt # trials
 cut_wav: false # cut wav to 3s (cut wav to 3s, same as train data)
-model_path: "ckpt/train_ecapa_vox2_full-2_664204.ckpt" # path of eval model
-train_norm_path: "/home/abc000/feat_norm/" # fbank data for norm
+model_path: "/Temp/abc/ckpt/train_ecapa_vox12-1_660204.ckpt" # path of eval model
+train_norm_path: "/Temp/abc/feat_norm/" # fbank data for norm
 score_norm: "s-norm" # if do norm, uncomment this two line
 cohort_size: 20000 # max number of utts to do norm
-npy_file_path: './npys/' # dir to save intermediate result
+npy_file_path: '/Temp/abc/npys/' # dir to save intermediate result
 #
 # export option
 exp_ckpt_file: './train_ecapa_vox2_full-2_664204.ckpt'
diff --git a/official/audio/ecapa_tdnn/train.py b/official/audio/ecapa_tdnn/train.py
index c28e18f6d79455c6af29db9cf9c57fa31eba383f..955325d2eed40f27ed852343c46cd103bd128a58 100644
--- a/official/audio/ecapa_tdnn/train.py
+++ b/official/audio/ecapa_tdnn/train.py
@@ -50,7 +50,9 @@ def create_dataset(cfg, data_home, shuffle=False):
     """
 
     dataset_generator = DatasetGenerator(data_home)
-    distributed_sampler = DistributedSampler(len(dataset_generator), cfg.group_size, cfg.rank, shuffle=True)
+    distributed_sampler = None
+    if cfg.run_distribute:
+        distributed_sampler = DistributedSampler(len(dataset_generator), cfg.group_size, cfg.rank, shuffle=True)
     vox2_ds = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=shuffle, sampler=distributed_sampler)
     cnt = int(len(dataset_generator) / cfg.group_size)
     return vox2_ds, cnt
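
Note on the config hunk: `base_lrate`, `max_lrate` and `step_size` describe a cyclic learning-rate schedule that ramps from `base_lrate` up to `max_lrate` over `step_size` steps. The snippet below is only a minimal sketch of that triangular shape, assuming the optimizer is given a per-step list of learning rates; the helper name `triangular_cyclic_lr` is illustrative and not part of the repository.

```python
def triangular_cyclic_lr(base_lr, max_lr, step_size, total_steps):
    """Return one learning rate per training step, rising linearly from
    base_lr to max_lr and back down, with a full cycle every 2*step_size steps."""
    lrs = []
    for step in range(total_steps):
        cycle_pos = step % (2 * step_size)       # position inside the current cycle
        if cycle_pos < step_size:                # rising half of the cycle
            frac = cycle_pos / step_size
        else:                                    # falling half of the cycle
            frac = 1.0 - (cycle_pos - step_size) / step_size
        lrs.append(base_lr + (max_lr - base_lr) * frac)
    return lrs

# e.g. with the values from ecapa-tdnn_config.yaml (assumed usage):
# lr_list = triangular_cyclic_lr(0.000001, 0.0001, 65000, num_epochs * steps_per_epoch)
```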
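
Note on the train.py hunk: `DistributedSampler` is now built only when `cfg.run_distribute` is set, so `cfg.rank` and `cfg.group_size` must still carry sane single-device values (0 and 1) for `len(dataset_generator) / cfg.group_size` and the `GeneratorDataset` call to keep working. Below is a minimal sketch of such an initialisation using MindSpore's communication API; the function name `setup_parallel` is an illustrative assumption, not code from train.py.

```python
from mindspore.communication.management import init, get_rank, get_group_size

def setup_parallel(cfg):
    """Populate cfg.rank / cfg.group_size for both single-device and distributed runs."""
    if cfg.run_distribute:
        init()                             # initialise the communication backend
        cfg.rank = get_rank()              # id of this device within the group
        cfg.group_size = get_group_size()  # total number of devices
    else:
        cfg.rank = 0                       # single-device defaults so the
        cfg.group_size = 1                 # len(dataset) / group_size math still holds
    return cfg
```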