Skip to content
Snippets Groups Projects
Unverified Commit 2fca42b7 authored by Bowen Chen's avatar Bowen Chen Committed by GitHub
Browse files

move generator test into ops folder to accelerate tests (#5472)


* move generator test into ops folder to accelerate tests

* ensure ONEFLOW_TEST_CPU_ONLY

* auto format by CI

Co-authored-by: default avataroneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: default avataroneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
parent fc056b3a
No related branches found
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
import oneflow as flow
......@@ -20,14 +21,16 @@ import oneflow as flow
class TestGenerator(flow.unittest.TestCase):
def test_different_devices(test_case):
auto_gen = flow.Generator(device="auto")
cuda_gen = flow.Generator(device="cuda")
cpu_gen = flow.Generator(device="cpu")
test_case.assertTrue(auto_gen.initial_seed(), cuda_gen.initial_seed())
test_case.assertTrue(auto_gen.initial_seed(), cpu_gen.initial_seed())
with test_case.assertRaises(Exception) as context:
flow.Generator(device="invalid")
test_case.assertTrue("unimplemented" in str(context.exception))
if not os.getenv("ONEFLOW_TEST_CPU_ONLY"):
cuda_gen = flow.Generator(device="cuda")
test_case.assertTrue(auto_gen.initial_seed(), cuda_gen.initial_seed())
def test_generator_manual_seed(test_case):
generator = flow.Generator()
generator.manual_seed(1)
......@@ -41,19 +44,20 @@ class TestDefaultGenerator(flow.unittest.TestCase):
global_seed = 10
flow.manual_seed(10)
auto_gen = flow.default_generator(device="auto")
cuda_gen = flow.default_generator(device="cuda")
cuda0_gen = flow.default_generator(device="cuda:0")
cpu_gen = flow.default_generator(device="cpu")
for gen in [auto_gen, cuda_gen, cuda0_gen, cpu_gen]:
test_gens = [auto_gen, cpu_gen]
if not os.getenv("ONEFLOW_TEST_CPU_ONLY"):
cuda_gen = flow.default_generator(device="cuda")
cuda0_gen = flow.default_generator(device="cuda:0")
test_gens += [cuda_gen, cuda0_gen]
for gen in test_gens:
test_case.assertTrue(gen.initial_seed() == global_seed)
def test_different_devices(test_case):
auto_gen = flow.default_generator(device="auto")
cuda_gen = flow.default_generator(device="cuda")
cuda0_gen = flow.default_generator(device="cuda:0")
cpu_gen = flow.default_generator(device="cpu")
for gen in [cuda_gen, cuda0_gen, cpu_gen]:
test_case.assertTrue(auto_gen.initial_seed() == gen.initial_seed())
with test_case.assertRaises(Exception) as context:
flow.default_generator(device="invalid")
......@@ -63,26 +67,45 @@ class TestDefaultGenerator(flow.unittest.TestCase):
flow.default_generator(device="cpu:1000")
test_case.assertTrue("check_failed" in str(context.exception))
with test_case.assertRaises(Exception) as context:
flow.default_generator(device="cuda:1000")
test_case.assertTrue("check_failed" in str(context.exception))
test_gens = [cpu_gen]
if not os.getenv("ONEFLOW_TEST_CPU_ONLY"):
with test_case.assertRaises(Exception) as context:
flow.default_generator(device="cuda:1000")
test_case.assertTrue("check_failed" in str(context.exception))
cuda_gen = flow.default_generator(device="cuda")
cuda0_gen = flow.default_generator(device="cuda:0")
test_gens += [cuda_gen, cuda0_gen]
for gen in test_gens:
test_case.assertTrue(auto_gen.initial_seed() == gen.initial_seed())
def test_generator_manual_seed(test_case):
auto_gen = flow.default_generator(device="auto")
cuda_gen = flow.default_generator(device="cuda")
cpu_gen = flow.default_generator(device="cpu")
test_gens = [auto_gen, cpu_gen]
if not os.getenv("ONEFLOW_TEST_CPU_ONLY"):
cuda_gen = flow.default_generator(device="cuda")
cuda0_gen = flow.default_generator(device="cuda:0")
test_gens += [cuda_gen, cuda0_gen]
for seed in [1, 2]:
auto_gen.manual_seed(seed)
for gen in [auto_gen, cuda_gen, cpu_gen]:
for gen in test_gens:
test_case.assertTrue(gen.initial_seed() == seed)
def test_generator_seed(test_case):
auto_gen = flow.default_generator(device="auto")
cuda_gen = flow.default_generator(device="cuda")
cpu_gen = flow.default_generator(device="cpu")
for gen in [auto_gen, cuda_gen, cpu_gen]:
test_gens = [auto_gen, cpu_gen]
if not os.getenv("ONEFLOW_TEST_CPU_ONLY"):
cuda_gen = flow.default_generator(device="cuda")
cuda0_gen = flow.default_generator(device="cuda:0")
test_gens += [cuda_gen, cuda0_gen]
for gen in test_gens:
seed = gen.seed()
test_case.assertTrue(seed == gen.initial_seed())
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment