Skip to content
Snippets Groups Projects
Commit 1b4edcfc authored by willzhang4a58's avatar willzhang4a58
Browse files

DataContentDesc

parent 60328eea
No related branches found
No related tags found
No related merge requests found
......@@ -33,6 +33,50 @@ void CopyFromFirstToOtherBlobs(DeviceCtx* ctx, std::function<Blob*(const std::st
FOR_RANGE(size_t, i, 1, bns.size()) { (BnInOp2Blob(bns.Get(i))->*Copy)(ctx, blob_0); }
}
class DataContentDesc final {
public:
OF_DISALLOW_COPY_AND_MOVE(DataContentDesc);
DataContentDesc() = delete;
~DataContentDesc() = default;
DataContentDesc(std::function<Blob*(const std::string&)> BnInOp2Blob,
const PbRpf<std::string>* bns, int32_t axis) {
BnInOp2Blob_ = BnInOp2Blob;
seg_num_ = BnInOp2Blob(bns->Get(0))->shape().Count(0, axis);
elem_sum_.assign(bns->size(), 0);
FOR_RANGE(size_t, i, 0, elem_sum_.size()) {
elem_sum_[i] = BnInOp2Blob(bns->Get(i))->shape().Count(axis);
if (i > 0) { elem_sum_[i] += elem_sum_[i - 1]; }
}
bns_ = bns;
axis_ = axis;
}
std::tuple<int64_t, char*> CalcContinuousElemNumStartFrom(int64_t idx) {
std::tuple<int64_t, char*> ret(0, nullptr);
int64_t seg_idx = idx / elem_sum_.back();
int64_t idx_in_seg = idx % elem_sum_.back();
auto elem_sum_it = std::upper_bound(elem_sum_.begin(), elem_sum_.end(), idx_in_seg);
CHECK(elem_sum_it != elem_sum_.end());
std::get<0>(ret) = *elem_sum_it - idx_in_seg;
int64_t bn_idx = elem_sum_it - elem_sum_.begin();
int64_t idx_in_blob = idx_in_seg;
if (bn_idx > 0) { idx_in_blob -= elem_sum_[bn_idx - 1]; }
Blob* blob = BnInOp2Blob_(bns_->Get(bn_idx));
std::get<1>(ret) = blob->mut_dptr<char>()
+ (seg_idx * blob->shape().Count(axis_) + idx_in_blob)
* GetSizeOfDataType(blob->data_type());
return ret;
}
private:
std::function<Blob*(const std::string&)> BnInOp2Blob_;
int64_t seg_num_;
std::vector<int64_t> elem_sum_;
const PbRpf<std::string>* bns_;
int32_t axis_;
};
void ConcatSplitDataContent(DeviceCtx* ctx, std::function<Blob*(const std::string&)> BnInOp2Blob,
const PbRpf<std::string>& concat_bns, int32_t concat_axis,
const PbRpf<std::string>& split_bns, int32_t split_axis) {
......
......@@ -377,6 +377,7 @@ class FieldIterator {
virtual char* GetMutPtr(Blob* blob) = 0;
virtual size_t GetSizeOfField(Blob* blob) const = 0;
private:
std::function<Blob*(const std::string&)> BnInOp2Blob_;
const PbRpf<std::string>* bns_;
int32_t bn_idx_;
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment