From 2f595bee1cc540fd22b10fb8f92bc5f9485a14b9 Mon Sep 17 00:00:00 2001
From: wangbixing123 <wangbixing@huawei.com>
Date: Sat, 18 Jun 2022 11:04:42 +0800
Subject: [PATCH] fix bug for ecapa_tdnn

---
 .../audio/ecapa_tdnn/ecapa-tdnn_config.yaml    | 18 +++++++++---------
 official/audio/ecapa_tdnn/train.py             |  4 +++-
 2 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml b/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml
index 131cd29fe..381b62ba1 100644
--- a/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml
+++ b/official/audio/ecapa_tdnn/ecapa-tdnn_config.yaml
@@ -28,33 +28,33 @@ run_distribute: 0
 
 # ==============================================================================
 # options
-in_channels: 80                                                  # input channel size, same as the dim of fbank feature
+in_channels: 80                                                 # input channel size, same as the dim of fbank feature
 channels: 1024                                                  # channel size of middle layer feature map
 base_lrate: 0.000001                                            # base learning rate of cyclic LR
 max_lrate: 0.0001                                               # max learning rate of cyclic LR
 momentum: 0.95                                                  # weight decay for optimizer
-weightDecay: 0.000002                                           # momentum for optimizer
-num_epochs: 2                                                   # training epoch
+weight_decay: 0.000002                                          # weight decay for optimizer
+num_epochs: 3                                                   # training epoch
 minibatch_size: 192                                             # batch size
 emb_size: 192                                                   # embedding dim
 step_size: 65000                                                # steps to achieve max learning rate cyclic LR
 class_num: 7205                                                 # speaker num pf voxceleb1&2 
 pre_trained: False                                              # if pre-trained model exist
 
-train_data_path: "/Temp/abc000/feat_train/"                     # path to fbank training data
+train_data_path: "/Temp/abc/feat_train/"                        # path to fbank training data
 keep_checkpoint_max: 20                           
 checkpoint_path: "train_ecapa_vox2_full-2_664204.ckpt"          # path to pre-trained model
 ckpt_save_dir: "./ckpt/"                                        # path to store train model
 
 # eval
-eval_data_path: "/home/abc000/feat_eval/"                       # path to eval fbank data
-veri_file_path: /home/abc000/feat_eval/veri_test_bleeched.txt   # trials
+eval_data_path: "/Temp/abc/feat_eval/"                          # path to eval fbank data
+veri_file_path: /Temp/abc/feat_eval/veri_test_bleeched.txt      # trials
 cut_wav: false                                                  # cut wav to 3s (cut wav to 3s, same as train data)
-model_path: "ckpt/train_ecapa_vox2_full-2_664204.ckpt"          # path of eval model
-train_norm_path: "/home/abc000/feat_norm/"                      # fbank data for norm 
+model_path: "/Temp/abc/ckpt/train_ecapa_vox12-1_660204.ckpt"    # path of eval model
+train_norm_path: "/Temp/abc/feat_norm/"                         # fbank data for norm 
 score_norm: "s-norm"                                            # if do norm, uncomment this two line
 cohort_size: 20000                                              # max number of utts to do norm
-npy_file_path: './npys/'                                        # dir to save intermediate result
+npy_file_path: '/Temp/abc/npys/'                                # dir to save intermediate result
 
 # # export option
 exp_ckpt_file: './train_ecapa_vox2_full-2_664204.ckpt'
diff --git a/official/audio/ecapa_tdnn/train.py b/official/audio/ecapa_tdnn/train.py
index c28e18f6d..955325d2e 100644
--- a/official/audio/ecapa_tdnn/train.py
+++ b/official/audio/ecapa_tdnn/train.py
@@ -50,7 +50,9 @@ def create_dataset(cfg, data_home, shuffle=False):
     """
 
     dataset_generator = DatasetGenerator(data_home)
-    distributed_sampler = DistributedSampler(len(dataset_generator), cfg.group_size, cfg.rank, shuffle=True)
+    distributed_sampler = None
+    if cfg.run_distribute:
+        distributed_sampler = DistributedSampler(len(dataset_generator), cfg.group_size, cfg.rank, shuffle=True)
     vox2_ds = ds.GeneratorDataset(dataset_generator, ["data", "label"], shuffle=shuffle, sampler=distributed_sampler)
     cnt = int(len(dataset_generator) / cfg.group_size)
     return vox2_ds, cnt
-- 
GitLab