Skip to content
Snippets Groups Projects
Commit 7778df95 authored by lixinqi's avatar lixinqi
Browse files

add RemoteBlob.shape and RemoteBlob.dtype

parent 47716d11
No related branches found
No related tags found
No related merge requests found
......@@ -3,39 +3,32 @@ from __future__ import absolute_import
import oneflow.core.common.data_type_pb2 as data_type_util
class BlobDesc(object):
def __init__(self, shape,
dtype = data_type_util.kFloat,
has_batch_dim = True,
is_dynamic = False,
split_axis = None,
broadcast = None):
self.shape_ = shape
self.dtype_ = dtype
self.has_batch_dim_ = has_batch_dim
self.is_dynamic_ = is_dynamic
self.split_axis_ = split_axis
self.broadcast_ = broadcast
def __init__(self):
pass
@property
def shape(self): return self.static_shape
@property
def shape(self):
return self.shape_
def static_shape(self):
raise NotImplementedError
@property
def dtype(self):
return self.dtype_
raise NotImplementedError
@property
def has_batch_dim(self):
return self.has_batch_dim_
raise NotImplementedError
@property
def is_dynamic(self):
return self.is_dynamic_
raise NotImplementedError
@property
def split_axis(self):
return self.split_axis_
raise NotImplementedError
@property
def broadcast(self):
return self.broadcast_
raise NotImplementedError
......@@ -103,7 +103,7 @@ def JobBuildAndInferCtx_GetStaticShape(job_name, lbn):
error = text_format.Parse(error_str, error_util.Error())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
int_list = text_format.Parse(axis_str, record_util.Int64List())
return tuple(int_list)
return tuple(map(int, int_list.value))
def JobBuildAndInferCtx_GetDataType(job_name, lbn):
job_name = str(job_name)
......@@ -111,7 +111,7 @@ def JobBuildAndInferCtx_GetDataType(job_name, lbn):
dtype, erro_str = oneflow_internal.JobBuildAndInferCtx_GetDataType(job_name, lbn)
error = text_format.Parse(erro_str, error_util.Error())
if error.HasField("error_type"): raise JobBuildAndInferError(error)
return dtype
return int(dtype)
def JobBuildAndInferCtx_GetHasBatchDim(job_name, lbn):
job_name = str(job_name)
......
......@@ -17,12 +17,39 @@ class input_blob_def(blob_desc.BlobDesc):
split_axis = None,
broadcast = None):
if split_axis == None and broadcast == None: split_axis = 0
blob_desc.BlobDesc.__init__(
self, shape, dtype, has_batch_dim, is_dynamic, split_axis, broadcast)
assert type(shape) is tuple
for dim in shape: assert type(dim) is int
self.shape_ = shape
self.dtype_ = dtype
self.has_batch_dim_ = has_batch_dim
self.is_dynamic_ = is_dynamic
self.split_axis_ = split_axis
self.broadcast_ = broadcast
self.lbi_ = lbi_util.LogicalBlobId()
self.lbi_.op_name = id_util.UniqueStr("Input_")
self.lbi_.blob_name = "out"
@property
def static_shape(self): return self.shape_
@property
def shape(self): return self.shape_
@property
def dtype(self): return self.dtype_
@property
def has_batch_dim(self): return self.has_batch_dim_
@property
def split_axis(self): return self.split_axis_
@property
def broadcast(self): return self.broadcast_
@property
def is_dynamic(self): return self.is_dynamic_
@property
def lbi(self): return self.lbi_
......
......@@ -45,6 +45,9 @@ def GetStaticShape(job_name, lbn):
def GetDataType(job_name, lbn):
return c_api_util.JobBuildAndInferCtx_GetDataType(job_name, lbn)
def GetHasBatchDim(job_name, lbn):
return c_api_util.JobBuildAndInferCtx_GetHasBatchDim(job_name, lbn)
def GetHasSplitDimFromProducerView(job_name, lbn):
return c_api_util.JobBuildAndInferCtx_GetHasSplitDimFromProducerView(job_name, lbn)
......
......@@ -2,20 +2,24 @@ from __future__ import absolute_import
import oneflow.python.framework.blob_desc as blob_desc
import oneflow.python.framework.inter_user_job_util as inter_user_job_util
import oneflow.python.framework.job_builder as job_builder
import oneflow.core.common.data_type_pb2 as data_type_util
import oneflow
class RemoteBlob(blob_desc.BlobDesc):
def __init__(self, lbi,
shape = None,
dtype = data_type_util.kFloat,
has_batch_dim = True,
is_dynamic = False,
split_axis = None,
broadcast = None):
blob_desc.BlobDesc.__init__(
self, shape, dtype, has_batch_dim, is_dynamic, split_axis, broadcast)
def __init__(self, lbi):
self.job_name_ = job_builder.GetCurCtxJobName()
self.lbi_ = lbi
self.lbn_ = lbi.op_name + "/" + lbi.blob_name
@property
def static_shape(self): return job_builder.GetStaticShape(self.job_name_, self.lbn_)
@property
def dtype(self): return job_builder.GetDataType(self.job_name_, self.lbn_)
@property
def has_batch_dim(self): return job_builder.GetHasBatchDim(self.job_name_, self.lbn_)
@property
def op_name(self):
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment