Freeze parameters in TorchScript Graph

Freeze 'weight', 'bias', buffers, etc. in TorchScript.

MKL-DNN requires a specific weight format to do convolution faster. By freezing the weight inside a TorchScript module, one can embed more information about the weight tensor into the graph and use constant propagation to move its content closer to its use, possibly avoiding the whole transformation computation at runtime.

Optimized passes for MKL-DNN-enabled ops

For example, we insert ops before aten::conv2d to transform the weight format in favour of MKL-DNN computation.

Say we start with this IR:

%30 : Float(*, *, *, *) = prim::GetAttr[name="weight"]
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30, %4, %611, %612, %613, %23)

We optimize it a little by inserting a reorder:

%30 : Float(*, *, *, *) = prim::GetAttr[name="weight"]
%30.weight: Float(*, *, *, *) = some::reorder(%30)
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30.weight, %4, %611, %612, %613, %23)

However, the runtime overhead of the transformation is still paid every time the graph is evaluated. We would like it to be:

%672.weight : Float(16, 3, 3, 3) = prim::Constant[value=<Tensor>]()
%30.weight: Float(*, *, *, *) = some::reorder(%672.weight)
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30.weight, %4, %611, %612, %613, %23)

After constant propagation (plus DCE) the graph would look like:

%30.weight : Tensor = prim::Constant[value=<Tensor>]()
%289 : Float(*, *, *, *) = aten::conv2d(%x.1, %30.weight, %4, %611, %612, %613, %23)

Here the constant Tensor is an opaque MKL-DNN tensor.
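
For reference, PyTorch exposes such opaque tensors through Tensor.to_mkldnn() (this requires a build with MKL-DNN enabled). A minimal sketch of what one looks like:

import torch

# A dense weight tensor and its opaque MKL-DNN counterpart.
w = torch.randn(16, 3, 3, 3)
w_mkldnn = w.to_mkldnn()
print(w_mkldnn.layout)               # torch._mkldnn
print(w_mkldnn.to_dense().equal(w))  # True: round-trips back to dense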

Detailed procedure to actually 'freeze params'

Unfortunately we can't freeze params in a standard registered pass, because we don't have a reference to the owning Module inside it. The provided code files expose a function called _jit_pass_freeze_params that changes prim::GetAttr to a Constant by poking inside a Module and grabbing the data out. You can freeze params like this:

ConvBnRelu = ScriptedCascadedConv2dBnRelu(3, 16, 32, kernel_size=3, stride=1)
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'weight')
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'bias')
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_mean')
freezer._jit_pass_freeze_params(ConvBnRelu._c, 'forward', 'running_var')

After running this code, ConvBnRelu.graph will treat the internal weight/bias/running_mean/running_var as Constants instead of GetAttr primitives, which allows the constant propagation pass to perform the optimizations described above.
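
To then fold the frozen constants, one can run the built-in passes directly. A sketch (the pass bindings under torch._C are internal and version-dependent):

import torch

graph = ConvBnRelu.graph
# With the params now Constants, the standard passes can fold them.
torch._C._jit_pass_constant_propagation(graph)
torch._C._jit_pass_dce(graph)
print(graph)  # GetAttr chains replaced by prim::Constant values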

#include <pybind11/pybind11.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/custom_operator.h>
#include <torch/csrc/jit/operator_options.h>
#include <torch/csrc/jit/pass_manager.h>
#include "weight_freeze.h"

using namespace torch::jit;

PYBIND11_MODULE(freezer, m) {
  m.def(
      "_jit_pass_freeze_params",
      [](const script::Module& moduleObj,
         const std::string& method_name,
         const std::string& param_name) {
        // Only the parameter/buffer names used by Conv2d/BatchNorm are
        // accepted; anything else is rejected early.
        if (param_name == std::string("weight")
            || param_name == std::string("bias")
            || param_name == std::string("running_mean")
            || param_name == std::string("running_var")) {
          freeze::FreezeParams(moduleObj, method_name, param_name);
        } else {
          TORCH_CHECK(false, "Invalid Param Name");
        }
      });
}
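
One way to compile this module is with torch.utils.cpp_extension; a sketch of a build script (the source file names are assumptions inferred from the includes above):

# build_freezer.py -- hypothetical build script; "freezer.cpp" and
# "weight_freeze.cpp" are assumed file names for the two sources shown here.
from torch.utils.cpp_extension import load

freezer = load(
    name="freezer",
    sources=["freezer.cpp", "weight_freeze.cpp"],
    verbose=True,
)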
#include <torch/csrc/jit/ir.h>
#include <torch/csrc/utils/memory.h>
#include <torch/csrc/jit/passes/alias_analysis.h>
#include <torch/csrc/jit/constants.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include "weight_freeze.h"
#include "graph_ext.h"

using namespace torch::jit;

namespace freeze {
//
// Just like the quant-dequant process on params.
// It should be put in utils, please :)
//
struct ParamValue {
  Value* definition;
  IValue slot;
};

//
// Recursively gather prim::GetAttr nodes and their corresponding IValues
//
static void gatherParams(
    const script::Module& module,
    Value* module_value,
    std::vector<ParamValue>& params) {
  for (const Use& u : module_value->uses()) {
    if (u.user->kind() != prim::GetAttr) {
      continue;
    }
    const std::string& field = u.user->s(::attr::name);
    if (const auto& sub = module.find_module(field)) {
      // A submodule: recurse into it.
      gatherParams(*sub, u.user->output(), params);
    } else if (auto slot = module.find_parameter(field)) {
      params.emplace_back(ParamValue{u.user->output(), slot->value()});
    } else if (auto slot
        // XXX: attribute??
        = const_cast<script::Module&>(module).find_attribute(field)) {
      // For running_mean and running_var, which are buffers, not parameters
      params.emplace_back(ParamValue{u.user->output(), slot->value()});
    }
  }
}
//
// We don't know when this pass will be called, so we check for both the
// plain PyTorch symbols and the MKL-DNN ones.
//
bool nodeOfInterest(const Node* node) {
  if (node->kind() == aten::conv2d
      || node->kind() == aten::batch_norm)
    return true;
  auto* nodeExt = reinterpret_cast<const NodeExt*>(node);
  return nodeExt->isConv2d() || nodeExt->isBatchNorm();
}
std::vector<ParamValue> getParamForFreeze(
    script::Method& method,
    const std::string& param_name) {
  std::vector<ParamValue> params;
  std::vector<ParamValue> ret;
  gatherParams(method.owner(), method.graph()->inputs().at(0), params);
  // Filter for the params that should be frozen
  for (const auto& param : params) {
    if (!param.definition->type()->isSubtypeOf(TensorType::get()))
      continue;
    for (const auto& u : param.definition->uses()) {
      if (nodeOfInterest(u.user)
          && u.user->schema().arguments().at(u.offset).name() == param_name) {
        //
        // XXX: detach from the param itself
        //
        ret.push_back({param.definition, param.slot.toTensor().detach()});
        break;
      }
    }
  }
  return ret;
}
std::vector<ParamValue> getFlagForFreeze(
    script::Method& method,
    const std::string& flag_name) {
  std::vector<ParamValue> params;
  std::vector<ParamValue> ret;
  gatherParams(method.owner(), method.graph()->inputs().at(0), params);
  // Filter for the flags that should be frozen
  for (const auto& param : params) {
    for (const auto& u : param.definition->uses()) {
      if (u.user->kind() == prim::If
          && u.user->input()->node()->s(::attr::name) == flag_name) {
        // We don't need the value, because we will change it to an arbitrary one.
        ret.push_back({param.definition, param.slot.toBool()});
        break;
      }
    }
  }
  return ret;
}
void insertConstantForParam(
    script::Method& method,
    const std::string& param_name) {
  const auto params = getParamForFreeze(method, param_name);
  auto g = method.graph();
  // Change all the accesses to constants
  for (const auto& param : params) {
    //
    // TODO: assert ref.defined()
    // to_mkldnn will disable inferring a complete tensor from this point;
    // we might not be able to fold BN at compile time by constant
    // propagation
    //
    WithInsertPoint guard(param.definition->node()->next());
    auto n = tryInsertConstant(*g, param.slot);
    if (n) {
      param.definition->replaceAllUsesWith(*n);
      auto v = n.value();
      v->setDebugName(v->debugName() + '.' + param_name);
    }
  }
}
void FreezeParams(
    const script::Module& moduleObj,
    const std::string& method_name,
    const std::string& param_name) {
  script::Method method = moduleObj.get_method(method_name);
  insertConstantForParam(method, param_name);
  // TODO: DCE? CP?
}
//
// TODO: implement partial redundancy elimination and move the possible
// reorder closer to its use to optimize further.
//
// XXX: Freezing a flag is not a common behavior; it's a workaround
//
TORCH_API void FreezeFlags(
    const script::Module& moduleObj,
    const std::string& method_name,
    const std::string& flag_name,
    bool value) {
  script::Method method = moduleObj.get_method(method_name);
  const auto params = getFlagForFreeze(method, flag_name);
  auto g = method.graph();
  // Change all the accesses to constants
  for (const auto& param : params) {
    //
    // TODO: assert ref.defined()
    // to_mkldnn will disable inferring a complete tensor from this point;
    // we might not be able to fold BN at compile time by constant
    // propagation
    //
    WithInsertPoint guard(param.definition->node()->next());
    auto n = tryInsertConstant(*g, IValue(value));
    if (n) {
      param.definition->replaceAllUsesWith(*n);
      auto v = n.value();
      v->setDebugName(v->debugName() + '.' + flag_name);
    }
  }
}
} // namespace freeze
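
Note that FreezeFlags is not exported by the pybind11 module above. If one added a binding analogous to _jit_pass_freeze_params (the name _jit_pass_freeze_flags below is hypothetical), freezing the training flag to fold away prim::If branches might look like:

# Hypothetical: assumes an added m.def("_jit_pass_freeze_flags", ...) wrapper
# around freeze::FreezeFlags; it is not part of the module shown above.
freezer._jit_pass_freeze_flags(ConvBnRelu._c, 'forward', 'training', False)
torch._C._jit_pass_constant_propagation(ConvBnRelu.graph)  # prune dead branches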