Skip to content
Snippets Groups Projects
Commit 15c7d9c1 authored by Joejiong's avatar Joejiong
Browse files

modify to new memref

parent c3ee52d8
No related branches found
No related tags found
No related merge requests found
...@@ -33,17 +33,13 @@ namespace ...@@ -33,17 +33,13 @@ namespace
} }
intptr_t sizesInput[4] = {1, 2, 3, 3}; intptr_t sizesInput[4] = {1, 2, 3, 3};
intptr_t stridesInput[4] = {18, 9, 3, 1};
intptr_t sizesFilter[4] = {2, 2, 2, 2}; intptr_t sizesFilter[4] = {2, 2, 2, 2};
intptr_t stridesFilter[4] = {8, 4, 2, 1};
intptr_t sizesOutput[4] = {1, 2, 2, 2}; intptr_t sizesOutput[4] = {1, 2, 2, 2};
intptr_t stridesOutput[4] = {8, 4, 2, 1};
// Create input, filter, and output. // Create input, filter, and output.
MemRef<float, 4> inputMemRef(2.0, sizesInput, stridesInput); MemRef<float, 4> inputMemRef(sizesInput, 2.0);
MemRef<float, 4> filterMemRef(3.0, sizesFilter, stridesFilter); MemRef<float, 4> filterMemRef(sizesFilter, 3.0);
MemRef<float, 4> outputMemRef(0.0, sizesOutput, stridesOutput); MemRef<float, 4> outputMemRef(sizesOutput, 0.0);
// Define benchmark function. // Define benchmark function.
void BM_Conv2DNchwFchw(benchmark::State &state) void BM_Conv2DNchwFchw(benchmark::State &state)
...@@ -68,13 +64,13 @@ BENCHMARK(BM_Conv2DNchwFchw)->Arg(4); ...@@ -68,13 +64,13 @@ BENCHMARK(BM_Conv2DNchwFchw)->Arg(4);
void printResult() void printResult()
{ {
// Clear the output memref. // Clear the output memref.
MemRef<float, 4> outputMemRef(0.0, sizesOutput, stridesOutput); MemRef<float, 4> outputMemRef(sizesOutput, 0.0);
// Run the mlir function. // Run the mlir function.
_mlir_ciface_conv_2d_nchw_fchw(&inputMemRef, &filterMemRef, _mlir_ciface_conv_2d_nchw_fchw(&inputMemRef, &filterMemRef,
&outputMemRef); &outputMemRef);
// Print the output. // Print the output.
std::cout << "Output: [ "; std::cout << "Output: [ ";
for (int i = 0; i < 8; ++i) for (int i = 0; i < 8; ++i)
std::cout << outputMemRef.aligned[i] << " "; std::cout << outputMemRef[i] << " ";
std::cout << "]" << std::endl; std::cout << "]" << std::endl;
} }
...@@ -33,17 +33,13 @@ namespace ...@@ -33,17 +33,13 @@ namespace
} }
intptr_t sizesInput[4] = {1, 3, 3, 2}; intptr_t sizesInput[4] = {1, 3, 3, 2};
intptr_t stridesInput[4] = {18, 6, 2, 1};
intptr_t sizesFilter[4] = {2, 2, 2, 2}; intptr_t sizesFilter[4] = {2, 2, 2, 2};
intptr_t stridesFilter[4] = {8, 4, 2, 1};
intptr_t sizesOutput[4] = {1, 2, 2, 2}; intptr_t sizesOutput[4] = {1, 2, 2, 2};
intptr_t stridesOutput[4] = {8, 4, 2, 1};
// Create input, filter, and output. // Create input, filter, and output.
MemRef<float, 4> inputMemRef(2.0, sizesInput, stridesInput); MemRef<float, 4> inputMemRef(sizesInput, 2.0);
MemRef<float, 4> filterMemRef(3.0, sizesFilter, stridesFilter); MemRef<float, 4> filterMemRef(sizesFilter, 3.0);
MemRef<float, 4> outputMemRef(0.0, sizesOutput, stridesOutput); MemRef<float, 4> outputMemRef(sizesOutput, 0.0);
// Define benchmark function. // Define benchmark function.
void BM_Conv2DNhwcHwcf(benchmark::State &state) void BM_Conv2DNhwcHwcf(benchmark::State &state)
...@@ -68,13 +64,13 @@ BENCHMARK(BM_Conv2DNhwcHwcf)->Arg(4); ...@@ -68,13 +64,13 @@ BENCHMARK(BM_Conv2DNhwcHwcf)->Arg(4);
void printResult() void printResult()
{ {
// Clear the output memref. // Clear the output memref.
MemRef<float, 4> outputMemRef(0.0, sizesOutput, stridesOutput); MemRef<float, 4> outputMemRef(sizesOutput, 0.0);
// Run the mlir function. // Run the mlir function.
_mlir_ciface_conv_2d_nhwc_hwcf(&inputMemRef, &filterMemRef, _mlir_ciface_conv_2d_nhwc_hwcf(&inputMemRef, &filterMemRef,
&outputMemRef); &outputMemRef);
// Print the output. // Print the output.
std::cout << "Output: [ "; std::cout << "Output: [ ";
for (int i = 0; i < 8; ++i) for (int i = 0; i < 8; ++i)
std::cout << outputMemRef.aligned[i] << " "; std::cout << outputMemRef[i] << " ";
std::cout << "]" << std::endl; std::cout << "]" << std::endl;
} }
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment