#include <cmath>
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/utils/singleton.h"

// Uniform random integer in [0, x).
#define Random(x) (rand() % (x))
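
// Sketch of the symmetric int8 scheme this test exercises (my reading of
// the ops used below, not stated in the gist itself):
//   quantize:   q = round(x / s * 127),  with s = max|x|
//   dequantize: x ~= q * s / 127
// A conv over a quantized input (scale s_in) with quantized weights
// (scale s_w) therefore needs its output rescaled by s_in * s_w / 127^2,
// which fake_dequantize_max_abs expresses through
// max_range = 127 * 127 / s_w together with Scale = s_in.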
namespace paddle {
namespace inference {
namespace tensorrt {
// Fills |tensor| (already resized) with random floats in [0, 0.1) on CPU,
// then copies the result to |place|.
void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
                     const platform::DeviceContext& ctx) {
  auto dims = tensor->dims();
  size_t num_elements = analysis::AccuDims(dims, dims.size());
  PADDLE_ENFORCE_GT(num_elements, 0);
  platform::CPUPlace cpu_place;
  framework::LoDTensor temp_tensor;
  temp_tensor.Resize(dims);
  auto* temp_data = temp_tensor.mutable_data<float>(cpu_place);
  for (size_t i = 0; i < num_elements; i++) {
    *(temp_data + i) = Random(100) / 1000.0;
  }
  TensorCopySync(temp_tensor, place, tensor);
}
// Fills |tensor| with random floats in [-100, 100) on CPU and returns the
// max absolute value through |scale| (the per-tensor quantization scale).
void PrepareInt8Input(framework::LoDTensor* tensor, float* scale) {
  float max_r = 0;
  std::cout << "Prepare INPUT: " << std::endl;
  platform::CPUPlace cpu_place;
  auto dims = tensor->dims();
  size_t num_elements = analysis::AccuDims(dims, dims.size());
  PADDLE_ENFORCE_GT(num_elements, 0);
  auto* temp_data = tensor->mutable_data<float>(cpu_place);
  for (size_t i = 0; i < num_elements; i++) {
    float temp_d = Random(2000) / 10. - 100;
    if (std::abs(temp_d) > max_r) max_r = std::abs(temp_d);
    *(temp_data + i) = temp_d;
    std::cout << temp_d << std::endl;
  }
  *scale = max_r;
}
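
// Note: the scale returned above is wired into fake_quantize_range_abs_max
// as InScale; with is_test = true that op should use InScale as-is rather
// than updating its running-window statistics (as I understand the op's
// test-mode behavior; not verified here).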
// Fills |tensor| with random floats in [-1, 3) on CPU (used as the raw FC
// weight, which stays in float on both paths).
void PrepareParam(framework::LoDTensor* tensor) {
  std::cout << "Prepare FC param: " << std::endl;
  platform::CPUPlace cpu_place;
  auto dims = tensor->dims();
  size_t num_elements = analysis::AccuDims(dims, dims.size());
  PADDLE_ENFORCE_GT(num_elements, 0);
  auto* temp_data = tensor->mutable_data<float>(cpu_place);
  for (size_t i = 0; i < num_elements; i++) {
    float temp_d = Random(400) / 100. - 1.;
    *(temp_data + i) = temp_d;
    std::cout << temp_d << std::endl;
  }
}
// Transposes a row-major (c, k) weight matrix |src| into the row-major
// (k, c) layout |dst| that TensorRT's fully connected layer expects.
void PrepareFcParam(float* src, float* dst, int c, int k) {
  std::cout << "Prepare FC param (transposed): " << std::endl;
  for (int h = 0; h < k; h++) {
    for (int w = 0; w < c; w++) {
      dst[h * c + w] = src[w * k + h];
      std::cout << dst[h * c + w] << std::endl;
    }
  }
}
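
// Worked example of the transpose above with c = k = 2: a mul weight
// stored row-major as {w00, w01, w10, w11} (shape (c, k)) becomes
// {w00, w10, w01, w11} (shape (k, c)), matching TensorRT's
// fully-connected kernel layout of (nbOutputs, nbInputs).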
// Fills the first output channel of |tensor| with random floats in [-1, 3),
// writes the int8-quantized counterpart into |tensor_int8| (stored as float),
// and replicates channel 0 into the remaining channels. Returns max|w|
// through |scale|.
void PrepareInt8Param(framework::LoDTensor* tensor,
                      framework::LoDTensor* tensor_int8, float* scale) {
  std::cout << "Prepare PARAM: " << std::endl;
  float max_r = 0;
  platform::CPUPlace cpu_place;
  auto dims = tensor->dims();
  size_t num_elements = analysis::AccuDims(dims, dims.size());
  PADDLE_ENFORCE_GT(num_elements, 0);
  int c_out = dims[0];
  size_t shape_s = num_elements / c_out;
  auto* temp_data = tensor->mutable_data<float>(cpu_place);
  auto* temp_data_int8 = tensor_int8->mutable_data<float>(cpu_place);
  // Random float weights for one output channel; track max|w|.
  for (size_t i = 0; i < shape_s; i++) {
    float temp_d = Random(400) / 100. - 1.;
    if (std::abs(temp_d) > max_r) max_r = std::abs(temp_d);
    *(temp_data + i) = temp_d;
    std::cout << temp_d << std::endl;
  }
  // Symmetric int8 quantization: q = round(w / max|w| * 127).
  for (size_t i = 0; i < shape_s; i++) {
    float temp_d = *(temp_data + i);
    *(temp_data_int8 + i) = std::round(temp_d / max_r * 127);
    std::cout << *(temp_data_int8 + i) << std::endl;
  }
  // Replicate channel 0 so every output channel shares the same weights.
  for (size_t i = shape_s; i < num_elements; i++) {
    *(temp_data + i) = *(temp_data + (i % shape_s));
    *(temp_data_int8 + i) = *(temp_data_int8 + (i % shape_s));
  }
  *scale = max_r;
}
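
// Worked example of the quantization above: weights {0.5, -1.2, 0.3} have
// max|w| = 1.2, so the int8 copy holds round(w / 1.2 * 127) =
// {53, -127, 32} and *scale is set to 1.2.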
// Declares |name| in |scope| as a LoDTensor of shape |dim_vec|, filled with
// random data on the GPU.
void DeclVar(const std::string& name, const std::vector<int> dim_vec,
             framework::Scope& scope) {
  platform::CUDAPlace place;
  platform::CUDADeviceContext ctx(place);
  auto* x = scope.Var(name);
  auto* x_tensor = x->GetMutable<framework::LoDTensor>();
  x_tensor->Resize(framework::make_ddim(dim_vec));
  RandomizeTensor(x_tensor, place, ctx);
}

// Declares |name| in |scope| as a LoDTensor initialized from |tensor|,
// copied to the GPU.
void DeclVar(const std::string& name, const framework::LoDTensor& tensor,
             framework::Scope& scope) {
  platform::CUDAPlace place;
  auto* x = scope.Var(name);
  auto dims = tensor.dims();
  auto* x_tensor = x->GetMutable<framework::LoDTensor>();
  x_tensor->Resize(dims);
  TensorCopySync(tensor, place, x_tensor);
}
// Placeholder: the TensorRT half of the comparison is built inline at the
// end of PrepareFluid() below.
void PrepareTRT() {}
void PrepareFluid() {
  // Prepare the fluid int8 program.
  platform::CUDAPlace cuda_place(0);
  platform::CUDADeviceContext cuda_ctx(cuda_place);
  framework::Executor executor(cuda_place);
  framework::Scope scope;
  framework::ProgramDesc program_desc;
  framework::BlockDesc* main_block =
      program_desc.MutableBlock(framework::kRootBlockIndex);

  // Random 1x1x2x2 input plus its quantization scale (max |x|).
  float quant_scale;
  platform::CPUPlace cpu_place;
  framework::LoDTensor temp_tensor;
  temp_tensor.Resize(framework::make_ddim({1, 1, 2, 2}));
  temp_tensor.mutable_data<float>(cpu_place);
  PrepareInt8Input(&temp_tensor, &quant_scale);

  framework::LoDTensor in_scale_tensor;
  in_scale_tensor.Resize(framework::make_ddim({1}));
  auto* temp_data = in_scale_tensor.mutable_data<float>(cpu_place);
  temp_data[0] = quant_scale;
  std::cout << "quant_scale: " << quant_scale << std::endl;
  // Conv filter: float original plus its int8-quantized copy, and the
  // weight scale.
  float weight_scale;
  framework::LoDTensor filter_tensor, filter_tensor_int8;
  filter_tensor.Resize(framework::make_ddim({2, 1, 2, 2}));
  filter_tensor.mutable_data<float>(cpu_place);
  filter_tensor_int8.Resize(framework::make_ddim({2, 1, 2, 2}));
  filter_tensor_int8.mutable_data<float>(cpu_place);
  PrepareInt8Param(&filter_tensor, &filter_tensor_int8, &weight_scale);

  // FC weight (stays in float on both paths).
  framework::LoDTensor fc_weight_tensor;
  fc_weight_tensor.Resize(framework::make_ddim({2, 2}));
  fc_weight_tensor.mutable_data<float>(cpu_place);
  PrepareParam(&fc_weight_tensor);
DeclVar("quant_x", temp_tensor, scope);
DeclVar("quant_x_scale", in_scale_tensor, scope);
DeclVar("quant_out", {1, 1, 2, 2}, scope);
DeclVar("quant_out_scale", in_scale_tensor, scope);
DeclVar("conv_filter", filter_tensor_int8, scope);
DeclVar("conv_output", {1, 2, 1, 1}, scope);
// DeclVar("dequant_x", {1, 2, 1, 1}, scope);
DeclVar("dequant_out", {1, 2, 1, 1}, scope);
DeclVar("fc_weight", fc_weight_tensor, scope);
DeclVar("fc_out", {1, 2}, scope);
  // fake_quant: quant_out = round(quant_x / scale * 127).
  auto* fake_q_desc = main_block->AppendOp();
  fake_q_desc->SetType("fake_quantize_range_abs_max");
  fake_q_desc->SetInput("X", {"quant_x"});
  fake_q_desc->SetInput("InScale", {"quant_x_scale"});
  fake_q_desc->SetOutput("Out", {"quant_out"});
  fake_q_desc->SetOutput("OutScale", {"quant_out_scale"});
  int window_size = 10000;
  int bit_length = 8;
  bool is_test = true;
  fake_q_desc->SetAttr("window_size", window_size);
  fake_q_desc->SetAttr("bit_length", bit_length);
  fake_q_desc->SetAttr("is_test", is_test);
  // conv2d on the quantized input with the int8-valued filter.
  auto* desc = main_block->AppendOp();
  desc->SetType("conv2d");
  desc->SetInput("Input", {"quant_out"});
  desc->SetInput("Filter", {"conv_filter"});
  desc->SetOutput("Output", {"conv_output"});
  const std::vector<int> strides({1, 1});
  const std::vector<int> paddings({0, 0});
  const std::vector<int> dilations({1, 1});
  const int groups = 1;
  desc->SetAttr("strides", strides);
  desc->SetAttr("paddings", paddings);
  desc->SetAttr("dilations", dilations);
  desc->SetAttr("groups", groups);
  // fake_dequant: rescale the conv output back to float.
  auto* fake_dq_desc = main_block->AppendOp();
  fake_dq_desc->SetType("fake_dequantize_max_abs");
  fake_dq_desc->SetInput("X", {"conv_output"});
  fake_dq_desc->SetInput("Scale", {"quant_x_scale"});
  fake_dq_desc->SetOutput("Out", {"dequant_out"});
  // need change here.
  float max_range = static_cast<float>(127 * 127) / weight_scale;
  std::cout << "max_range: " << max_range << std::endl;
  fake_dq_desc->SetAttr("max_range", max_range);
  // FC realized as a plain "mul" (X * Y) on the dequantized output.
  auto* fc_desc = main_block->AppendOp();
  fc_desc->SetType("mul");
  fc_desc->SetInput("X", {"dequant_out"});
  fc_desc->SetInput("Y", {"fc_weight"});
  fc_desc->SetOutput("Out", {"fc_out"});
  // Run the fluid program and dump the reference output.
  auto ctx = executor.Prepare(program_desc, 0);
  executor.RunPreparedContext(ctx.get(), &scope, false, true, true);
  // auto *var = scope.FindVar("quant_out");
  // auto *var = scope.FindVar("conv_filter");
  auto* var = scope.FindVar("fc_out");
  // auto *var = scope.FindVar("conv_output");
  // auto *var = scope.FindVar("dequant_out");
  auto out_tensor = var->GetMutable<framework::LoDTensor>();
  std::vector<float> fluid_out;
  framework::TensorToVector(*out_tensor, cuda_ctx, &fluid_out);
  std::cout << "fluid output: " << std::endl;
  size_t fluid_out_size = framework::product(out_tensor->dims());
  for (size_t i = 0; i < fluid_out_size; i++) {
    std::cout << fluid_out[i] << std::endl;
  }
  // Build the equivalent TensorRT int8 network.
  cudaStream_t stream;
  PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream), 0);
  TensorRTEngine engine(1, 1 << 20, true, nullptr, 0);
  engine.InitNetwork();
  engine.DeclareInput("conv_input", nvinfer1::DataType::kFLOAT,
                      nvinfer1::DimsCHW(1, 2, 2));

  auto* weight_data = filter_tensor.mutable_data<float>(cpu_place);
  std::cout << "filter tensor num: " << filter_tensor.numel() << std::endl;
  TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                static_cast<void*>(weight_data),
                                static_cast<size_t>(filter_tensor.numel())};
  TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};

  auto* X = engine.GetITensor("conv_input");
  engine.SetTensorDynamicRange(X, 38.3);
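  // TensorRT derives its int8 scale from the dynamic range as range / 127,
  // so for the TRT path to match the fluid path this range should equal
  // quant_scale (max|x| of the input); 38.3 looks like a value read off
  // from one concrete run.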
  nvinfer1::DimsHW nv_ksize(2, 2);
  nvinfer1::DimsHW nv_dilations(1, 1);
  nvinfer1::DimsHW nv_strides(1, 1);
  nvinfer1::DimsHW nv_paddings(0, 0);
  auto* layer = engine.network()->addConvolution(*X, 2, nv_ksize,
                                                 weight.get(), bias.get());
  layer->setDilation(nv_dilations);
  layer->setStride(nv_strides);
  layer->setPadding(nv_paddings);
  layer->setNbGroups(1);
  engine.SetITensor("conv_output", layer->getOutput(0));
  auto* iconv_output = layer->getOutput(0);
  iconv_output->setName("conv_output");
  // Hard-coded conv-output dynamic range, presumably also read off a run.
  engine.SetTensorDynamicRange(iconv_output, 9.86601);
  // For FC: transpose the mul weight into TRT's (k, c) layout first.
  framework::LoDTensor fc_weight_temp;
  fc_weight_temp.Resize(fc_weight_tensor.dims());
  fc_weight_temp.mutable_data<float>(cpu_place);
  auto* fc_weight_data = fc_weight_tensor.mutable_data<float>(cpu_place);
  auto* fc_weight_temp_data = fc_weight_temp.mutable_data<float>(cpu_place);
  PrepareFcParam(fc_weight_data, fc_weight_temp_data, 2, 2);
  TensorRTEngine::Weight fc_weight{nvinfer1::DataType::kFLOAT,
                                   static_cast<void*>(fc_weight_temp_data),
                                   static_cast<size_t>(fc_weight_temp.numel())};
  TensorRTEngine::Weight fc_bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
  auto* fc_layer = engine.network()->addFullyConnected(
      *iconv_output, 2, fc_weight.get(), fc_bias.get());
  engine.SetITensor("fc_output", fc_layer->getOutput(0));
  engine.SetTensorDynamicRange(fc_layer->getOutput(0), 10);
  engine.DeclareOutput("fc_output");
  engine.FreezeNetwork();
  // Bind buffers and run: reuse the fluid input, and overwrite out_tensor
  // with the TRT result for comparison against fluid_out.
  std::vector<void*> buffers(2);
  const int input_index = engine.engine()->getBindingIndex("conv_input");
  const int output_index = engine.engine()->getBindingIndex("fc_output");
  auto* input_var = scope.FindVar("quant_x");
  auto* input_tensor = input_var->GetMutable<framework::LoDTensor>();
  buffers[input_index] = static_cast<void*>(input_tensor->data<float>());
  buffers[output_index] =
      static_cast<void*>(out_tensor->mutable_data<float>(cuda_place));
  cudaStreamSynchronize(stream);
  std::cout << "start Executing..." << std::endl;
  engine.Execute(1, &buffers, stream);
  cudaStreamSynchronize(stream);

  std::vector<float> trt_out;
  framework::TensorToVector(*out_tensor, cuda_ctx, &trt_out);
  cudaStreamSynchronize(stream);
  std::cout << "trt output: " << std::endl;
  for (size_t i = 0; i < fluid_out_size; i++) {
    std::cout << trt_out[i] << std::endl;
  }
}
TEST(fake_int8, normal) {
  PrepareFluid();
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(mul);
USE_OP(conv2d);
USE_OP(fake_quantize_range_abs_max);
USE_OP(fake_dequantize_max_abs);