diff --git a/research/nlp/lstm_crf/train.py b/research/nlp/lstm_crf/train.py
index 8af7e3266d4cbbb1f22b8c191a0dfe8a6045ca3a..3046b41bbf64c8cfeb4461c77c1985f5cab9b77c 100644
--- a/research/nlp/lstm_crf/train.py
+++ b/research/nlp/lstm_crf/train.py
@@ -62,6 +62,9 @@ class EvalCabllBack(TimeMonitor):
         self.callback = F1(len(self.tags_to_index_map))
         self._best_val_F1 = 0
 
+    def epoch_begin(self, run_context):
+        self.network.is_training = True
+
     def epoch_end(self, run_context):
         """save .ckpt files"""
         self.network.is_training = False