diff --git a/research/nlp/lstm_crf/train.py b/research/nlp/lstm_crf/train.py index 8af7e3266d4cbbb1f22b8c191a0dfe8a6045ca3a..3046b41bbf64c8cfeb4461c77c1985f5cab9b77c 100644 --- a/research/nlp/lstm_crf/train.py +++ b/research/nlp/lstm_crf/train.py @@ -62,6 +62,9 @@ class EvalCabllBack(TimeMonitor): self.callback = F1(len(self.tags_to_index_map)) self._best_val_F1 = 0 + def epoch_begin(self, run_context): + self.network.is_training = True + def epoch_end(self, run_context): """save .ckpt files""" self.network.is_training = False