Commit a610d506 authored by 叶剑武's avatar 叶剑武
Browse files

Merge branch 'lp_normalization' into 'master'

add lpnorm、mvnorm op for caffe, enhance biasadd、reshape op

See merge request !1224
parents 8a0abcac 72a3751f
Showing with 1212 additions and 58 deletions
+1212 -58
......@@ -90,7 +90,9 @@ MemoryBlock MemoryOptimizer::CreateMemoryBlock(
if (shape.size() == 2) {
shape = {shape[0], 1, 1, shape[1]};
} else {
MACE_CHECK(shape.size() == 4) << "GPU only support 2D/4D input";
MACE_CHECK(shape.size() == 4) << "GPU only support 2D/4D input, "
<< "op name: " << op_def->name() << ", "
<< MakeString(shape);
}
OpenCLUtil::CalImage2DShape(shape, buffer_type, &image_shape);
block.set_x(image_shape[0]);
......
......@@ -62,34 +62,62 @@ void BiasAdd::AddBias(const OpContext *context,
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t c = start1; c < end1; c += step1) {
const index_t offset = (b * channels + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[c];
float32x4_t vbias = vdupq_n_f32(bias);
if (bias->dim_size() == 1) {
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
const index_t b_offset = b * channels;
for (index_t c = start1; c < end1; c += step1) {
const index_t offset = (b_offset + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[c];
float32x4_t vbias = vdupq_n_f32(bias);
for (index_t i = 0; i < block_count; ++i) {
float32x4_t v = vld1q_f32(input_ptr);
v = vaddq_f32(v, vbias);
vst1q_f32(output_ptr, v);
for (index_t i = 0; i < block_count; ++i) {
float32x4_t v = vld1q_f32(input_ptr);
v = vaddq_f32(v, vbias);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
input_ptr += 4;
output_ptr += 4;
}
for (index_t i = 0; i < remain; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
}
}
for (index_t i = 0; i < remain; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
}
}, 0, batch, 1, 0, channels, 1);
} else {
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
const index_t b_offset = b * channels;
for (index_t c = start1; c < end1; c += step1) {
const index_t offset = (b_offset + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[b * channels + c];
float32x4_t vbias = vdupq_n_f32(bias);
for (index_t i = 0; i < block_count; ++i) {
float32x4_t v = vld1q_f32(input_ptr);
v = vaddq_f32(v, vbias);
vst1q_f32(output_ptr, v);
input_ptr += 4;
output_ptr += 4;
}
for (index_t i = 0; i < remain; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
}
}
}
}
}, 0, batch, 1, 0, channels, 1);
}, 0, batch, 1, 0, channels, 1);
}
}
} // namespace fp32
} // namespace arm
} // namespace ops
} // namespace mace
......@@ -49,15 +49,18 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
const Tensor *bias = this->Input(1);
MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
bias->dim_size());
Tensor *output = this->Output(0);
if (input->dim_size() == 4 && has_data_format_) {
if (input->dim_size() == 4 && (has_data_format_
|| input->data_format() == DataFormat::NCHW)) { // NCHW
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1-dimensional or n*c for caffee.",
MakeString(bias->shape()));
bias_add_delegator_.Compute(context, input, bias, output);
} else {
} else { // NHWC
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1 or 2 dimensionals for caffee.",
bias->dim_size(), MakeString(bias->shape()));
// TODO(liyin): remove it and tranform bias to add (eltwise)
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
......@@ -70,16 +73,40 @@ class BiasAddOp<DeviceType::CPU, float> : public Operation {
float *output_ptr = output->mutable_data<float>();
const std::vector<index_t> &shape = input->shape();
const index_t fused_batch = std::accumulate(
shape.begin(), shape.end() - 1, 1, std::multiplies<index_t>());
const index_t channels = *shape.rbegin();
for (index_t n = 0; n < fused_batch; ++n) {
index_t pos = n * channels;
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
++pos;
}
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
if (bias->dim_size() == 1) {
const index_t fused_batch = std::accumulate(
shape.begin(), shape.end() - 1, 1, std::multiplies<index_t>());
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t n = start; n < end; n += step) {
index_t pos = n * channels;
for (index_t c = 0; c < channels; ++c) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[c];
++pos;
}
}
}, 0, fused_batch, 1);
} else { // bias is 2d
const auto n = shape[0];
MACE_CHECK(n == bias->shape()[0]);
const index_t fused_hw = std::accumulate(
shape.begin() + 1, shape.end() - 1, 1, std::multiplies<index_t>());
const auto ch_size = bias->shape()[1];
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
auto offset = i * fused_hw;
auto bias_offset = i * ch_size;
for (index_t j = start1; j < end1; j += step1) {
index_t pos = (offset + i) * channels;
for (index_t c = 0; c < channels; ++c, ++pos) {
output_ptr[pos] = input_ptr[pos] + bias_ptr[bias_offset + c];
}
}
}
}, 0, n, 1, 0, fused_hw, 1);
}
}
......@@ -109,21 +136,25 @@ class BiasAddOp<DeviceType::GPU, float> : public Operation {
} else {
MACE_NOT_IMPLEMENTED;
}
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
== MaceStatus::MACE_SUCCESS);
// for const bias tensor
if (context->workspace()->GetTensor(operator_def_->input(1)) != nullptr) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1, OpenCLBufferType::ARGUMENT, mem_type)
== MaceStatus::MACE_SUCCESS, "TransformFilter failed");
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0);
const Tensor *bias = this->Input(1);
MACE_CHECK(bias->dim_size() == 1, "bias must be 1-dimensional. ",
bias->dim_size());
Tensor *output = this->Output(0);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
MACE_CHECK(input->dim_size() == 4 && has_data_format_,
"gpu only support biasadd for 4-dimensional NHWC format tensor");
MACE_CHECK(bias->dim_size() == 1 || bias->dim_size() == 2,
"bias must be 1-dimensional or 2-dimensional for caffee. ",
MakeString(bias->shape()));
return kernel_->Compute(context, input, bias, output);
}
......@@ -151,6 +182,10 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) {
*op, "has_data_format", 0);
if (!has_data_format ||
op->output_shape(0).dims_size() != 4) {
LOG(INFO) << "BiasAdd only support cpu, has_data_format="
<< has_data_format
<< ", op->output_shape(0).dims_size()="
<< op->output_shape(0).dims_size();
return {DeviceType::CPU};
}
return {DeviceType::CPU, DeviceType::GPU};
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <memory>
#include "mace/core/operator.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/lpnorm.h"
#endif // MACE_ENABLE_OPENCL
/**
* LpNormOp is a Normalization OP which support L1 and L2, which is a custom op
* of caffe (not exist in official caffe), please reference:
* https://github.com/freesouls/caffe/blob/master/src/caffe/layers/normalization_layer.cpp #noqa
*/
namespace mace {
namespace ops {
template<DeviceType D, typename T>
class LpNormOp;
template<>
class LpNormOp<DeviceType::CPU, float> : public Operation {
public:
explicit LpNormOp(OpConstructContext *context)
: Operation(context),
p_(Operation::GetOptionalArg<int>("p", 2)),
axis_(Operation::GetOptionalArg<int>("axis", -1)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
if (axis_ < 0) {
axis_ += input->dim_size();
}
MACE_CHECK(axis_ < input->dim_size() && axis_ >= 0,
"The axis_ must be small than dim size");
const std::vector<index_t> &input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->Resize(input_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const auto *input_data = input->data<float>();
auto *output_data = output->mutable_data<float>();
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
auto outer_loop = std::accumulate(input_shape.begin(),
input_shape.begin() + axis_, 1,
std::multiplies<index_t>());
auto inner_loop = std::accumulate(input_shape.begin() + axis_,
input_shape.end(), 1,
std::multiplies<index_t>());
if (p_ == 1) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = std::abs(input_data[i]);
}
}, 0, input->size(), 1);
} else if (p_ == 2) {
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = input_data[i] * input_data[i];
}
}, 0, input->size(), 1);
} else {
LOG(FATAL) << "LpNorm's p should be 1 or 2, current p is: " << p_;
}
const float power = 1 / static_cast<float>(p_);
auto norm_buffer = context->device()->scratch_buffer();
norm_buffer->Rewind();
MACE_RETURN_IF_ERROR(norm_buffer->GrowSize(outer_loop * sizeof(float)));
float *norm_ptr = norm_buffer->mutable_data<float>();
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
auto output_data_base = output_data + inner_loop * i;
norm_ptr[i] = std::accumulate(output_data_base,
output_data_base + inner_loop, 0.0f);
norm_ptr[i] = std::pow(norm_ptr[i], power);
norm_ptr[i] += 1e-6;
}
}, 0, outer_loop, 1);
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] = input_data[offset + j] / norm_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
return MaceStatus::MACE_SUCCESS;
}
private:
int p_;
int axis_;
};
#ifdef MACE_ENABLE_OPENCL
template<>
class LpNormOp<DeviceType::GPU, float> : public Operation {
public:
explicit LpNormOp(OpConstructContext *context)
: Operation(context) {
const auto p = Operation::GetOptionalArg<int>("p", 2);
const auto axis = Operation::GetOptionalArg<int>("axis", -1);
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::LpNormKernel>(p, axis);
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
return kernel_->Compute(context, input, output);
}
private:
std::unique_ptr<OpenCLLpNormKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_OPENCL
void RegisterLpNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "LpNorm", LpNormOp,
DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "LpNorm", LpNormOp);
}
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <functional>
#include <memory>
#include "mace/core/operator.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/mvnorm.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace ops {
// Mean-Variance Normalization (MVN)
template<DeviceType D, typename T>
class MVNormOp;
template<>
class MVNormOp<DeviceType::CPU, float> : public Operation {
public:
explicit MVNormOp(OpConstructContext *context)
: Operation(context),
normalize_variance_(
Operation::GetOptionalArg<bool>("normalize_variance", true)),
across_channels_(
Operation::GetOptionalArg<bool>("across_channels", false)),
eps_(Operation::GetOptionalArg<float>("epsilon", 1e-9)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
MACE_CHECK(input->data_format() == DataFormat::NCHW,
"The MVN only suport NCHW");
const std::vector<index_t> &input_shape = input->shape();
MACE_RETURN_IF_ERROR(output->Resize(input_shape));
Tensor::MappingGuard guard_input(input);
Tensor::MappingGuard guard_output(output);
const auto *input_data = input->data<float>();
auto *output_data = output->mutable_data<float>();
const auto input_size = input->size();
const auto outer_loop =
across_channels_ ? input_shape[0] : input_shape[0] * input_shape[1];
const auto inner_loop = input_size / outer_loop;
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
Buffer mean_buffer(context->device()->allocator());
MACE_RETURN_IF_ERROR(mean_buffer.Allocate(outer_loop * sizeof(float)));
auto *mean_ptr = mean_buffer.mutable_data<float>();
// compute EX
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
const auto offset = inner_loop * i;
mean_ptr[i] = std::accumulate(input_data + offset,
input_data + offset + inner_loop, 0.0f);
mean_ptr[i] /= inner_loop;
}
}, 0, outer_loop, 1);
// compute (X - EX)
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] = input_data[offset + j] - mean_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
if (normalize_variance_) {
// compute (X - EX)^2
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
output_data[i] = output_data[i] * output_data[i];
}
}, 0, input_size, 1);
auto mean_v_buffer = context->device()->scratch_buffer();
mean_v_buffer->Rewind();
MACE_RETURN_IF_ERROR(
mean_v_buffer->GrowSize(outer_loop * sizeof(float)));
float *mean_v_ptr = mean_v_buffer->mutable_data<float>();
// compute E((X - EX)^2)^0.5 + eps_
thread_pool.Compute1D([=](index_t start, index_t end, index_t step) {
for (index_t i = start; i < end; i += step) {
auto output_data_base = output_data + inner_loop * i;
mean_v_ptr[i] = std::accumulate(output_data_base,
output_data_base + inner_loop, 0.0f);
mean_v_ptr[i] = std::pow(mean_v_ptr[i] / inner_loop, 0.5f) + eps_;
}
}, 0, outer_loop, 1);
// compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t i = start0; i < end0; i += step0) {
const auto offset = i * inner_loop;
for (index_t j = start1; j < end1; j += step1) {
output_data[offset + j] =
(input_data[offset + j] - mean_ptr[i]) / mean_v_ptr[i];
}
}
}, 0, outer_loop, 1, 0, inner_loop, 1);
}
return MaceStatus::MACE_SUCCESS;
}
private:
bool normalize_variance_;
bool across_channels_;
float eps_;
};
#ifdef MACE_ENABLE_OPENCL
template<>
class MVNormOp<DeviceType::GPU, float> : public Operation {
public:
explicit MVNormOp(OpConstructContext *context) : Operation(context) {
auto normalize_variance =
Operation::GetOptionalArg<bool>("normalize_variance", true);
auto across_channels =
Operation::GetOptionalArg<bool>("across_channels", false);
auto eps = Operation::GetOptionalArg<float>("epsilon", 1e-9);
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::MVNormKernel>(
normalize_variance, across_channels, eps);
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_RETURN_IF_ERROR(output->ResizeLike(input));
return kernel_->Compute(context, input, output);
}
private:
std::unique_ptr<OpenCLMVNormKernel> kernel_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#endif // MACE_ENABLE_OPENCL
void RegisterMVNorm(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "MVNorm", MVNormOp,
DeviceType::CPU, float);
MACE_REGISTER_GPU_OP(op_registry, "MVNorm", MVNormOp);
}
} // namespace ops
} // namespace mace
......@@ -2,25 +2,27 @@
// Supported data types: half/float
__kernel void bias_add(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__private const int input_height,
__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (ch_blk >= global_size_dim0 || w >= global_size_dim1
|| hb >= global_size_dim2) {
if (ch_blk >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int width = global_size_dim1;
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 bias_value = READ_IMAGET(bias, SAMPLER, (int2)(ch_blk, 0));
const int pos = mad24(ch_blk, width, width_idx);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
const int b_idx = select(0, hb_idx / input_height, input_height > 0);
DATA_TYPE4 bias_value = READ_IMAGET(bias, SAMPLER, (int2)(ch_blk, b_idx));
DATA_TYPE4 out = in + bias_value;
WRITE_IMAGET(output, (int2)(pos, hb), out);
WRITE_IMAGET(output, (int2)(pos, hb_idx), out);
}
#include <common.h>
DATA_TYPE4 compute_total(__read_only image2d_t input, const int hb_base,
const int chan_blks, const int width, const int height,
const int hb_idx, const int chan_blk_idx) {
DATA_TYPE4 total = 0.0f;
#if PARAM_AXIS == 1
const int wc_blks = mul24(width, chan_blks);
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int pos = 0; pos < wc_blks; ++pos) {
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
#if PARAM_P == 1
total += fabs(in_data);
#else
total = mad(in_data, in_data, total);
#endif
}
}
DATA_TYPE total_all = total.x + total.y + total.z + total.w;
total = (DATA_TYPE4){total_all, total_all, total_all, total_all};
#elif PARAM_AXIS == 2
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int w_idx = 0; w_idx < width; ++w_idx) {
int pos = mad24(chan_blk_idx, width, w_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
#if PARAM_P == 1
total = total + fabs(in_data);
#else
total = mad(in_data, in_data, total);
#endif
}
}
#elif PARAM_AXIS == 3
for (int w_idx = 0; w_idx < width; ++x) {
int pos = mad24(chan_blk_idx, width, w_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
#if PARAM_P == 1
total = total + fabs(in_data);
#else
total = mad(in_data, in_data, total);
#endif
}
#endif
return total;
}
__kernel void lpnorm(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int height,
__private const float eps,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
const int hb = global_size_dim2;
const int hb_base = mul24(hb_idx / height, height);
DATA_TYPE4 total = compute_total(input, hb_base, chan_blks, width, height,
hb_idx, chan_blk_idx);
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
#if PARAM_P == 1
in_data = in_data / (total + eps);
#else
in_data = in_data / (sqrt(total) + eps);
#endif
WRITE_IMAGET(output, (int2)(pos, hb_idx), in_data);
}
#include <common.h>
DATA_TYPE4 compute_mean_image(image2d_t input, const int width_idx,
const int hb_idx, const int chan_blks,
const int height, const int width) {
DATA_TYPE4 total = 0.0f;
DATA_TYPE4 mean = 0.0f;
const int hb_base = mul24(hb_idx / height, height);
const int wc_blks = mul24(width, chan_blks);
#ifdef ACROSS_CHANNELS
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int pos = 0; pos < wc_blks; ++pos) {
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
total += in_data;
}
}
DATA_TYPE total_value = total.x + total.y + total.z + total.w;
DATA_TYPE mean_value = total_value / (DATA_TYPE)(mul24(mul24(height, wc_blks), 4));
mean = (DATA_TYPE4){mean_value, mean_value, mean_value, mean_value};
#else
for (int h_idx = hb_base; h_idx < hb_base + height; ++h_idx) {
for (int w_idx = 0; w_idx < width; ++w_idx) {
int pos = mad24(w_idx, chan_blks, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, h_idx));
total += in_data;
}
}
mean = total / mul24(height, width);
#endif
return mean;
}
__kernel void mvnorm_mean(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int height,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
DATA_TYPE4 mean = compute_mean_image(input, width_idx,
hb_idx, chan_blks, height, width);
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
in_data -= mean;
WRITE_IMAGET(output, (int2)(pos, hb_idx), in_data);
}
__kernel void mvnorm_vn_step1(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__write_only image2d_t mean_image, // E(X)
__write_only image2d_t square_image, // (X - EX)^2
__private const int height) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
DATA_TYPE4 mean = compute_mean_image(input, width_idx,
hb_idx, chan_blks, height, width);
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
in_data = in_data - mean;
DATA_TYPE4 pow_data = in_data * in_data;
if (hb_idx == 0 && width_idx == 0) {
WRITE_IMAGET(mean_image, (int2)(chan_blk_idx, 0), mean);
}
WRITE_IMAGET(square_image, (int2)(pos, hb_idx), pow_data);
}
__kernel void mvnorm_vn_step2(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t mean_image, // E(X)
__read_only image2d_t square_image, // (X - EX)^2
__private const int height,
__private const float eps,
__write_only image2d_t output) {
const int chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (chan_blk_idx >= global_size_dim0 || width_idx >= global_size_dim1
|| hb_idx >= global_size_dim2) {
return;
}
#endif
const int chan_blks = global_size_dim0;
const int width = global_size_dim1;
DATA_TYPE4 mean = READ_IMAGET(mean_image, SAMPLER, (int2)(chan_blk_idx, 0));
const int pos = mad24(chan_blk_idx, width, width_idx);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(pos, hb_idx));
in_data = in_data - mean;
DATA_TYPE4 mean_v = compute_mean_image(square_image, width_idx,
hb_idx, chan_blks, height, width);
DATA_TYPE4 norm_data = in_data / (sqrt(mean_v) + eps);
WRITE_IMAGET(output, (int2)(pos, hb_idx), norm_data);
}
......@@ -56,6 +56,7 @@ MaceStatus BiasAddKernel::Compute(
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, static_cast<int>(bias->dim_size() > 1 ? height : 0));
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, *(bias->opencl_image()));
kernel_.setArg(idx++, *(output->opencl_image()));
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/opencl/image/lpnorm.h"
#include <set>
#include <string>
#include <vector>
namespace mace {
namespace ops {
namespace opencl {
namespace image {
LpNormKernel::LpNormKernel(const int p, const int axis) : p_(p), axis_(axis) {
MACE_CHECK(p_ == 1 || p_ == 2, "Current p is: ", p);
}
MaceStatus LpNormKernel::Compute(OpContext *context,
const Tensor *input, Tensor *output) {
if (axis_ < 0) {
axis_ += input->dim_size();
}
MACE_CHECK(axis_ == 1 || axis_ == 2 || axis_ == 3,
"Current axis is: ", axis_);
const auto batch = input->dim(0);
const auto height = input->dim(1);
const auto width = input->dim(2);
const auto channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("lpnorm");
built_options.emplace("-Dlpnorm=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
std::stringstream param_p;
param_p << "-DPARAM_P=" << p_;
built_options.emplace(param_p.str());
std::stringstream param_axis;
param_axis << "-DPARAM_AXIS=" << axis_;
built_options.emplace(param_axis.str());
MACE_RETURN_IF_ERROR(runtime->BuildKernel("lpnorm", kernel_name,
built_options, &kernel_));
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
MACE_OUT_OF_RANGE_INIT(kernel_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<int>(height));
kernel_.setArg(idx++, static_cast<float>(1e-6));
kernel_.setArg(idx++, *(output->opencl_image()));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_);
std::string tuning_key =
Concat("lpnorm_opencl_kernel", batch, height, width, channels, p_, axis_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_LPNORM_H_
#define MACE_OPS_OPENCL_IMAGE_LPNORM_H_
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/helper.h"
#include "mace/ops/opencl/lpnorm.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
class LpNormKernel : public OpenCLLpNormKernel {
public:
explicit LpNormKernel(const int p, const int axis);
~LpNormKernel() = default;
MaceStatus Compute(
OpContext *context, const Tensor *input, Tensor *output) override;
private:
int p_;
int axis_;
cl::Kernel kernel_;
uint32_t kwg_size_;
};
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_LPNORM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/opencl/image/mvnorm.h"
#include <memory>
#include <set>
#include <string>
#include <vector>
namespace mace {
namespace ops {
namespace opencl {
namespace image {
namespace {
MaceStatus BuildMVNKernel(OpenCLRuntime *runtime, cl::Kernel *kernel,
const char *kernel_name,
std::set<std::string> *built_options,
bool across_channel) {
std::stringstream micro_name;
micro_name << "-Dmvnorm=" << kernel_name;
built_options->emplace(micro_name.str());
built_options->emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
built_options->emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
if (across_channel) {
built_options->emplace("-DACROSS_CHANNELS");
}
MACE_RETURN_IF_ERROR(runtime->BuildKernel("mvnorm", kernel_name,
*built_options, kernel));
return MaceStatus::MACE_SUCCESS;
}
std::unique_ptr<Image> CreateImage(
OpContext *context, const DataType dt,
const std::vector<index_t> &buffer_shape) {
std::unique_ptr<Image> image =
make_unique<Image>(context->device()->allocator());
std::vector<size_t> shape;
OpenCLUtil::CalImage2DShape(
buffer_shape, OpenCLBufferType::IN_OUT_CHANNEL, &shape);
MACE_CHECK(image->Allocate(shape, dt) == MaceStatus::MACE_SUCCESS);
VLOG(1) << "MVNormKernel::CreateImage allocate image_:" << MakeString(shape);
return image;
}
} // namespace
MVNormKernel::MVNormKernel(bool normalize_variance,
bool across_channels, float eps)
: normalize_variance_(normalize_variance),
across_channels_(across_channels),
eps_(eps) {}
void MVNormKernel::CheckImage(OpContext *context, const DataType dt,
const std::vector<index_t> &square_shape,
const std::vector<index_t> &mean_shape) {
if (square_image_ == nullptr) {
square_image_ = CreateImage(context, dt, square_shape);
}
if (mean_image_ == nullptr) {
mean_image_ = CreateImage(context, dt, mean_shape);
}
}
MaceStatus MVNormKernel::Compute(OpContext
*context,
const Tensor *input, Tensor
*output) {
const auto batch = input->dim(0);
const auto height = input->dim(1);
const auto width = input->dim(2);
const auto channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
if (normalize_variance_) {
const std::vector<index_t> &square_shape = input->buffer_shape();
const std::vector<index_t> mean_shape = {1, 1, 1, channels};
CheckImage(context, input->dtype(), square_shape, mean_shape);
// compute the (X - EX)^2
MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep1Kernel(
context, runtime, gws, input));
// compute the compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
MACE_RETURN_IF_ERROR(ExecuteVarianceNormStep2Kernel(
context, runtime, gws, input, output));
} else {
MACE_RETURN_IF_ERROR(ExecuteMeanNormKernel(
context, runtime, gws, input, output));
}
return
MaceStatus::MACE_SUCCESS;
}
MaceStatus MVNormKernel::ExecuteMeanNormKernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output) {
const auto height = input->dim(1);
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step1_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_, "mvnorm_mean",
&built_options, across_channels_));
kwg_size_step1_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
}
MACE_OUT_OF_RANGE_INIT(kernel_step1_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
kernel_step1_.setArg(idx++, *(input->opencl_image()));
kernel_step1_.setArg(idx++, static_cast<int>(height));
kernel_step1_.setArg(idx++, *(output->opencl_image()));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
std::string
tuning_key = Concat("mvnorm_mean_opencl_kernel", gws[0], gws[1], gws[2],
normalize_variance_, across_channels_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
// The first step of compute Variance Norm, compute the (X - EX)^2
// store them into the square_image_
MaceStatus MVNormKernel::ExecuteVarianceNormStep1Kernel(
OpContext *context, OpenCLRuntime *runtime,
const uint32_t (&gws)[3], const Tensor *input) {
const auto height = input->dim(1);
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step1_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step1_,
"mvnorm_vn_step1",
&built_options, across_channels_));
kwg_size_step1_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step1_));
}
MACE_OUT_OF_RANGE_INIT(kernel_step1_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step1_);
MACE_SET_3D_GWS_ARGS(kernel_step1_, gws);
kernel_step1_.setArg(idx++, *(input->opencl_image()));
cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
kernel_step1_.setArg(idx++, *mean_image);
cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
kernel_step1_.setArg(idx++, *square_image);
kernel_step1_.setArg(idx++, static_cast<int>(height));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step1_);
std::string
tuning_key = Concat("mvnorm_v_step1_opencl_kernel", gws[0], gws[1],
gws[2], normalize_variance_,
across_channels_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step1_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
// The second step of compute Variance Norm, read the (X - EX)^2 from
// square_image_ and compute (X - EX) / (E((X - EX)^2)^0.5 + eps_)
MaceStatus MVNormKernel::ExecuteVarianceNormStep2Kernel(
OpContext *context, OpenCLRuntime *runtime, const uint32_t (&gws)[3],
const Tensor *input, Tensor *output) {
const auto height = input->dim(1);
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_step2_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
MACE_RETURN_IF_ERROR(BuildMVNKernel(runtime, &kernel_step2_,
"mvnorm_vn_step2",
&built_options, across_channels_));
kwg_size_step2_ = static_cast<uint32_t>(
runtime->GetKernelMaxWorkGroupSize(kernel_step2_));
}
MACE_OUT_OF_RANGE_INIT(kernel_step2_);
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_step2_);
MACE_SET_3D_GWS_ARGS(kernel_step2_, gws);
kernel_step2_.setArg(idx++, *(input->opencl_image()));
cl::Image *mean_image = static_cast<cl::Image *>(mean_image_->buffer());
kernel_step2_.setArg(idx++, *mean_image);
cl::Image *square_image = static_cast<cl::Image *>(square_image_->buffer());
kernel_step2_.setArg(idx++, *square_image);
kernel_step2_.setArg(idx++, static_cast<int>(height));
kernel_step2_.setArg(idx++, static_cast<float>(eps_));
kernel_step2_.setArg(idx++, *(output->opencl_image()));
std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_step2_);
std::string
tuning_key = Concat("mvnorm_v_step2_opencl_kernel", gws[0], gws[1],
gws[2], normalize_variance_,
across_channels_);
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_step2_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_MVNORM_H_
#define MACE_OPS_OPENCL_IMAGE_MVNORM_H_
#include <memory>
#include <vector>
#include "mace/core/op_context.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/helper.h"
#include "mace/ops/opencl/mvnorm.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
class MVNormKernel : public OpenCLMVNormKernel {
public:
explicit MVNormKernel(bool normalize_variance_,
bool across_channels, float eps);
~MVNormKernel() = default;
MaceStatus Compute(
OpContext *context, const Tensor *input, Tensor *output) override;
private:
void CheckImage(OpContext *context, const DataType dt,
const std::vector<index_t> &square_shape,
const std::vector<index_t> &mean_shape);
MaceStatus ExecuteMeanNormKernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output);
MaceStatus ExecuteVarianceNormStep1Kernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input);
MaceStatus ExecuteVarianceNormStep2Kernel(OpContext *context,
OpenCLRuntime *runtime,
const uint32_t (&gws)[3],
const Tensor *input,
Tensor *output);
private:
bool normalize_variance_;
bool across_channels_;
float eps_;
cl::Kernel kernel_step1_;
uint32_t kwg_size_step1_;
cl::Kernel kernel_step2_;
uint32_t kwg_size_step2_;
// the cache of (X - EX)^2
std::unique_ptr<Image> square_image_;
// the cache of EX
std::unique_ptr<Image> mean_image_;
};
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_MVNORM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_LPNORM_H_
#define MACE_OPS_OPENCL_LPNORM_H_
#include "mace/public/mace.h"
#include "mace/utils/math.h"
namespace mace {
class OpContext;
class Tensor;
namespace ops {
class OpenCLLpNormKernel {
public:
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLLpNormKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_LPNORM_H_
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_MVNORM_H_
#define MACE_OPS_OPENCL_MVNORM_H_
#include "mace/public/mace.h"
#include "mace/utils/math.h"
namespace mace {
class OpContext;
class Tensor;
namespace ops {
class OpenCLMVNormKernel {
public:
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLMVNormKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_MVNORM_H_
......@@ -56,12 +56,13 @@ void BiasAdd::AddBias(const OpContext *context,
const index_t width = output->dim(3);
const index_t image_size = height * width;
auto bias_b = bias->dim_size() == 1 ? 0 : bias->shape()[1];
for (index_t b = 0; b < batch; ++b) {
for (index_t c = 0; c < channels; ++c) {
const index_t offset = (b * channels + c) * image_size;
auto input_ptr = input_data + offset;
auto output_ptr = output_data + offset;
const float bias = bias_data[c];
const float bias = bias_data[bias_b * channels + c];
for (index_t i = 0; i < image_size; ++i) {
(*output_ptr++) = (*input_ptr++) + bias;
......
......@@ -46,8 +46,10 @@ extern void RegisterDelay(OpRegistryBase *op_registry);
extern void RegisterInferConv2dShape(OpRegistryBase *op_registry);
extern void RegisterKaldiBatchNorm(OpRegistryBase *op_registry);
extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
extern void RegisterLpNorm(OpRegistryBase *op_registry);
extern void RegisterLSTMNonlinear(OpRegistryBase *op_registry);
extern void RegisterMatMul(OpRegistryBase *op_registry);
extern void RegisterMVNorm(OpRegistryBase *op_registry);
extern void RegisterOneHot(OpRegistryBase *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry);
extern void RegisterPadContext(OpRegistryBase *op_registry);
......@@ -121,8 +123,10 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterInferConv2dShape(this);
ops::RegisterKaldiBatchNorm(this);
ops::RegisterLocalResponseNorm(this);
ops::RegisterLpNorm(this);
ops::RegisterLSTMNonlinear(this);
ops::RegisterMatMul(this);
ops::RegisterMVNorm(this);
ops::RegisterOneHot(this);
ops::RegisterPad(this);
ops::RegisterPadContext(this);
......
......@@ -159,8 +159,10 @@ void RegisterReshape(OpRegistryBase *op_registry) {
auto tensor_shape_info = context->tensor_shape_info();
const std::string &input_0 = op->input(0);
if (4 == op->output_shape(0).dims_size() &&
4 == tensor_shape_info->at(input_0).size()) {
const auto out_dims_size =
op->output_shape(0).dims_size();
if (4 == tensor_shape_info->at(input_0).size()
&& (out_dims_size == 4 || out_dims_size == 2)) {
return {DeviceType::CPU, DeviceType::GPU};
}
return {DeviceType::CPU};
......
......@@ -82,12 +82,13 @@ class SoftmaxOp<DeviceType::CPU, float> : public Operation {
index_t batch_stride = class_size;
index_t batch_size = batch_stride * input->dim(0);
Buffer cache_buffer(context->device()->allocator());
MACE_RETURN_IF_ERROR(cache_buffer.Allocate(hw_size * sizeof(float)));
auto cache_buffer = context->device()->scratch_buffer();
cache_buffer->Rewind();
MACE_RETURN_IF_ERROR(cache_buffer->GrowSize(hw_size * sizeof(float)));
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
float std_lowest = std::numeric_limits<float>::lowest();
float *cache_ptr = cache_buffer.mutable_data<float>();
float *cache_ptr = cache_buffer->mutable_data<float>();
for (index_t b_offset = 0;
b_offset < batch_size; b_offset += batch_stride) {
......
......@@ -84,7 +84,7 @@ def encrypt_opencl_codegen(cl_kernel_dir, output_path):
for file_name in os.listdir(cl_kernel_dir):
file_path = os.path.join(cl_kernel_dir, file_name)
module_key = get_module_key(file_name)
if len(module_key) > 0:
if module_key is not None and len(module_key) > 0:
with open(file_path, "r") as f:
code_str = ""
headers = []
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment