Skip to content
Snippets Groups Projects
Commit 80a6ff59 authored by caifubi's avatar caifubi
Browse files

Add JitConfig for LSTM

parent 8ced7bf4
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,7 @@ from mindspore.train.callback import LossMonitor, CheckpointConfig, ModelCheckpo
from mindspore.train.serialization import load_param_into_net, load_checkpoint
from mindspore.communication.management import init, get_rank
from mindspore.context import ParallelMode
from mindspore.common import JitConfig
set_seed(1)
......@@ -107,6 +107,7 @@ def train_lstm():
opt = nn.Momentum(network.trainable_params(), lr, config.momentum)
loss_cb = LossMonitor()
network.set_jit_config(JitConfig(jit_level="O2"))
model = Model(network, loss, opt, {'acc': Accuracy()})
print("============== Starting Training ==============")
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment