diff --git a/research/nlp/textrcnn/train.py b/research/nlp/textrcnn/train.py index 9d0c3cde303901a226aa9d7eeb6bd318b2593efe..73c1e963912316e2c7fffd5b2c6faacfdc20acc0 100644 --- a/research/nlp/textrcnn/train.py +++ b/research/nlp/textrcnn/train.py @@ -52,7 +52,7 @@ def run_train(): device_id = get_device_id() context.set_context(device_id=device_id) - if cfg.device_target == "GPU": + if cfg.device_target == "GPU" and cfg.cell == "lstm": context.set_context(enable_graph_kernel=True) if cfg.preprocess == 'true':