Commit 83ff0336 authored by scxfjiang, committed by Shenghang Tsai

matmul (#2084)

* matmul

* np.allclose
parent 6c500602
@@ -8,38 +8,6 @@
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
from oneflow.python.oneflow_export import oneflow_export
@oneflow_export('keras.maths.matmul')
def matmul(a,
           b,
           transpose_a=False,
           transpose_b=False,
           adjoint_a=False,
           adjoint_b=False,
           a_is_sparse=False,
           b_is_sparse=False,
           name=None):
    assert adjoint_a == False
    assert adjoint_b == False
    assert a_is_sparse == False
    assert b_is_sparse == False
    op_conf = op_conf_util.OperatorConf()
    if name is None:
        op_conf.name = id_util.UniqueStr('Matmul_')
    else:
        op_conf.name = name
    op_conf.matmul_conf.a = a.logical_blob_name
    op_conf.matmul_conf.b = b.logical_blob_name
    op_conf.matmul_conf.transpose_a = transpose_a
    op_conf.matmul_conf.transpose_b = transpose_b
    op_conf.matmul_conf.out = "out"
    compile_context.CurJobAddOp(op_conf)
    lbi = logical_blob_id_util.LogicalBlobId()
    lbi.op_name = op_conf.name
    lbi.blob_name = "out"
    return remote_blob_util.RemoteBlob(lbi)
@oneflow_export('keras.maths.add')
def add(x,
......
from __future__ import absolute_import
import oneflow.python.framework.compile_context as compile_context
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.id_util as id_util
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
from oneflow.python.oneflow_export import oneflow_export
# TODO: Support exporting multiple interfaces, e.g. @oneflow_export("matmul", "linalg.matmul")
@oneflow_export("matmul")
def matmul(a, b, transpose_a=False, transpose_b=False, name=None):
    op_conf = op_conf_util.OperatorConf()
    setattr(op_conf, "name", name if name is not None else id_util.UniqueStr("Matmul_"))
    setattr(op_conf.matmul_conf, "a", a.logical_blob_name)
    setattr(op_conf.matmul_conf, "b", b.logical_blob_name)
    setattr(op_conf.matmul_conf, "transpose_a", transpose_a)
    setattr(op_conf.matmul_conf, "transpose_b", transpose_b)
    setattr(op_conf.matmul_conf, "out", "out")
    compile_context.CurJobAddOp(op_conf)
    out_lbi = logical_blob_id_util.LogicalBlobId()
    setattr(out_lbi, "op_name", op_conf.name)
    setattr(out_lbi, "blob_name", "out")
    return remote_blob_util.RemoteBlob(out_lbi)
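
The TODO at the top of this file asks for exporting one function under several dotted paths, e.g. both "matmul" and "linalg.matmul". Below is a minimal sketch of what such a variadic decorator could look like; this is an assumption about the intended behaviour, not the actual oneflow_export implementation, and oneflow_export_multi / _EXPORT_TABLE are hypothetical names.

# Hypothetical sketch: register one callable under several export names.
_EXPORT_TABLE = {}

def oneflow_export_multi(*api_names):
    def decorator(func):
        for name in api_names:
            # the same callable becomes reachable under every requested dotted path
            _EXPORT_TABLE[name] = func
        return func
    return decorator

@oneflow_export_multi("matmul", "linalg.matmul")
def _demo_matmul(a, b):
    return (a, b)

assert _EXPORT_TABLE["matmul"] is _EXPORT_TABLE["linalg.matmul"]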
import oneflow as flow
import numpy as np
import torch

config = flow.ConfigProtoBuilder()
config.gpu_device_num(1)
flow.init(config)

def MatmulJob(a=flow.input_blob_def((4, 5)), b=flow.input_blob_def((5, 4))):
    job_conf = flow.get_cur_job_conf_builder()
    job_conf.batch_size(4).data_part_num(1).default_data_type(flow.float)
    return flow.keras.maths.matmul(a, b)

flow.add_job(MatmulJob)

a = np.arange(1, 21).reshape((4, 5)).astype(np.float32)
b = np.arange(1, 21).reshape((5, 4)).astype(np.float32)
with flow.Session() as sess:
    x = sess.run(MatmulJob, a, b).get()
    y = torch.matmul(torch.Tensor(a), torch.Tensor(b))
    result = np.isclose(np.array(x), y.numpy(), rtol=1e-03, atol=1e-05)
    for i in result.ravel():
        assert i, "the matmul test is wrong!"
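
The element-wise np.isclose loop above can be collapsed into a single check with np.allclose (the change called out in the commit message). A short equivalent, reusing x, y, and the same tolerances from the script above:

# Equivalent one-line check with np.allclose, same rtol/atol as above.
assert np.allclose(np.array(x), y.numpy(), rtol=1e-03, atol=1e-05), \
    "matmul output does not match torch.matmul"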
import tensorflow as tf
import oneflow as flow
import numpy as np

tf.enable_eager_execution()
assert tf.executing_eagerly()

def test_matmul(a_shape, b_shape, transpose_a=False, transpose_b=False):
    a = np.random.random_sample(a_shape).astype(np.float32)
    b = np.random.random_sample(b_shape).astype(np.float32)

    # OneFlow
    config = flow.ConfigProtoBuilder()
    config.gpu_device_num(1)
    flow.init(config)

    def MatmulTestJob(a=flow.input_blob_def(a_shape), b=flow.input_blob_def(b_shape)):
        job_conf = flow.get_cur_job_conf_builder()
        job_conf.batch_size(1).data_part_num(1).default_data_type(flow.float)
        return flow.matmul(a, b, transpose_a, transpose_b)

    flow.add_job(MatmulTestJob)
    with flow.Session() as sess:
        of_out = sess.run(MatmulTestJob, a, b).get()

    # TensorFlow
    tf_out = tf.matmul(tf.Variable(a), tf.Variable(b), transpose_a, transpose_b).numpy()

    assert np.allclose(of_out, tf_out, atol=1e-7)

# run one example each time
if __name__ == "__main__":
    test_matmul(a_shape=(512, 256), b_shape=(256, 1024))
    # test_matmul(a_shape=(256, 512), b_shape=(256, 1024), transpose_a=True)
    # test_matmul(a_shape=(512, 256), b_shape=(1024, 256), transpose_b=True)
    # test_matmul(
    #     a_shape=(256, 512), b_shape=(1024, 256), transpose_a=True, transpose_b=True
    # )
    # test_matmul(a_shape=(10, 10, 64, 32), b_shape=(10, 10, 32, 128))
    # test_matmul(a_shape=(10, 10, 32, 64), b_shape=(10, 10, 32, 128), transpose_a=True)
    # test_matmul(a_shape=(10, 10, 64, 32), b_shape=(10, 10, 128, 32), transpose_b=True)
    # test_matmul(
    #     a_shape=(10, 10, 32, 64),
    #     b_shape=(10, 10, 128, 32),
    #     transpose_a=True,
    #     transpose_b=True,
    # )
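
The "run one example each time" comment suggests each case should run in a fresh interpreter, since test_matmul calls flow.init itself. A sketch of one way to drive several cases under that assumption, using the standard-library multiprocessing module; the cases list simply mirrors some of the commented-out calls:

import multiprocessing as mp

cases = [
    dict(a_shape=(512, 256), b_shape=(256, 1024)),
    dict(a_shape=(256, 512), b_shape=(256, 1024), transpose_a=True),
    dict(a_shape=(512, 256), b_shape=(1024, 256), transpose_b=True),
]

if __name__ == "__main__":
    for kwargs in cases:
        # run each case in its own process so flow.init happens once per run
        p = mp.Process(target=test_matmul, kwargs=kwargs)
        p.start()
        p.join()
        assert p.exitcode == 0, "matmul case failed: {}".format(kwargs)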