diff --git a/oneflow/python/test/generator/test_generator.py b/oneflow/python/test/ops/generator/test_generator.py similarity index 64% rename from oneflow/python/test/generator/test_generator.py rename to oneflow/python/test/ops/generator/test_generator.py index 94294360edea0c8a47b23f0a5153947c7c556343..56e2f1629e6bcd2919ff8dae4ac6ba6e200bda5d 100644 --- a/oneflow/python/test/generator/test_generator.py +++ b/oneflow/python/test/ops/generator/test_generator.py @@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ +import os import unittest import oneflow as flow @@ -20,14 +21,16 @@ import oneflow as flow class TestGenerator(flow.unittest.TestCase): def test_different_devices(test_case): auto_gen = flow.Generator(device="auto") - cuda_gen = flow.Generator(device="cuda") cpu_gen = flow.Generator(device="cpu") - test_case.assertTrue(auto_gen.initial_seed(), cuda_gen.initial_seed()) test_case.assertTrue(auto_gen.initial_seed(), cpu_gen.initial_seed()) with test_case.assertRaises(Exception) as context: flow.Generator(device="invalid") test_case.assertTrue("unimplemented" in str(context.exception)) + if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): + cuda_gen = flow.Generator(device="cuda") + test_case.assertTrue(auto_gen.initial_seed(), cuda_gen.initial_seed()) + def test_generator_manual_seed(test_case): generator = flow.Generator() generator.manual_seed(1) @@ -41,19 +44,20 @@ class TestDefaultGenerator(flow.unittest.TestCase): global_seed = 10 flow.manual_seed(10) auto_gen = flow.default_generator(device="auto") - cuda_gen = flow.default_generator(device="cuda") - cuda0_gen = flow.default_generator(device="cuda:0") cpu_gen = flow.default_generator(device="cpu") - for gen in [auto_gen, cuda_gen, cuda0_gen, cpu_gen]: + + test_gens = [auto_gen, cpu_gen] + if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): + cuda_gen = flow.default_generator(device="cuda") + cuda0_gen = flow.default_generator(device="cuda:0") + test_gens += [cuda_gen, cuda0_gen] + + for gen in test_gens: test_case.assertTrue(gen.initial_seed() == global_seed) def test_different_devices(test_case): auto_gen = flow.default_generator(device="auto") - cuda_gen = flow.default_generator(device="cuda") - cuda0_gen = flow.default_generator(device="cuda:0") cpu_gen = flow.default_generator(device="cpu") - for gen in [cuda_gen, cuda0_gen, cpu_gen]: - test_case.assertTrue(auto_gen.initial_seed() == gen.initial_seed()) with test_case.assertRaises(Exception) as context: flow.default_generator(device="invalid") @@ -63,26 +67,45 @@ class TestDefaultGenerator(flow.unittest.TestCase): flow.default_generator(device="cpu:1000") test_case.assertTrue("check_failed" in str(context.exception)) - with test_case.assertRaises(Exception) as context: - flow.default_generator(device="cuda:1000") - test_case.assertTrue("check_failed" in str(context.exception)) + test_gens = [cpu_gen] + if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): + with test_case.assertRaises(Exception) as context: + flow.default_generator(device="cuda:1000") + test_case.assertTrue("check_failed" in str(context.exception)) + + cuda_gen = flow.default_generator(device="cuda") + cuda0_gen = flow.default_generator(device="cuda:0") + test_gens += [cuda_gen, cuda0_gen] + + for gen in test_gens: + test_case.assertTrue(auto_gen.initial_seed() == gen.initial_seed()) def test_generator_manual_seed(test_case): auto_gen = flow.default_generator(device="auto") - cuda_gen = flow.default_generator(device="cuda") cpu_gen = flow.default_generator(device="cpu") + test_gens = [auto_gen, cpu_gen] + if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): + cuda_gen = flow.default_generator(device="cuda") + cuda0_gen = flow.default_generator(device="cuda:0") + test_gens += [cuda_gen, cuda0_gen] + for seed in [1, 2]: auto_gen.manual_seed(seed) - for gen in [auto_gen, cuda_gen, cpu_gen]: + for gen in test_gens: test_case.assertTrue(gen.initial_seed() == seed) def test_generator_seed(test_case): auto_gen = flow.default_generator(device="auto") - cuda_gen = flow.default_generator(device="cuda") cpu_gen = flow.default_generator(device="cpu") - for gen in [auto_gen, cuda_gen, cpu_gen]: + test_gens = [auto_gen, cpu_gen] + if not os.getenv("ONEFLOW_TEST_CPU_ONLY"): + cuda_gen = flow.default_generator(device="cuda") + cuda0_gen = flow.default_generator(device="cuda:0") + test_gens += [cuda_gen, cuda0_gen] + + for gen in test_gens: seed = gen.seed() test_case.assertTrue(seed == gen.initial_seed())