diff --git a/research/nlp/textrcnn/train.py b/research/nlp/textrcnn/train.py
index 9d0c3cde303901a226aa9d7eeb6bd318b2593efe..73c1e963912316e2c7fffd5b2c6faacfdc20acc0 100644
--- a/research/nlp/textrcnn/train.py
+++ b/research/nlp/textrcnn/train.py
@@ -52,7 +52,7 @@ def run_train():
 
     device_id = get_device_id()
     context.set_context(device_id=device_id)
-    if cfg.device_target == "GPU":
+    if cfg.device_target == "GPU" and cfg.cell == "lstm":
         context.set_context(enable_graph_kernel=True)
 
     if cfg.preprocess == 'true':