IResNet.cpp
#include "IResNet.h"
IResNet::IResNet()
: m_width(112),
m_height(112)
{
}
IResNet::~IResNet() {}
void IResNet::initialize(std::string model_path)
{
if (!FileExists(model_path))
{
throw std::runtime_error("Model file does not exist");
}
m_net = cv::dnn::readNetFromONNX(model_path);
}
std::vector<float> IResNet::GetEmbedding(cv::Mat img){
cv::Mat preprocessedImage = Preproecess(img);
cv::Mat inputBlob = cv::dnn::blobFromImage(
preprocessedImage, 1.0, cv::Size(m_width, m_height), cv::Scalar(0, 0, 0), true);
m_net.setInput(inputBlob);
// std::vector<cv::String> output_names = {"1722"}; // test.onnx
std::vector<cv::String> output_names = {"2240"}; // iresbot100_frelu.onnx
std::vector<cv::Mat> out_blobs;
m_net.forward(out_blobs, output_names);
std::vector<float> embedding;
cv::Mat mat = out_blobs[0];
if (mat.isContinuous()) {
embedding.assign((float*)mat.data, (float*)mat.data + mat.total()*mat.channels());
}
else {
for (int i = 0; i < mat.rows; ++i) {
embedding.insert(embedding.end(), mat.ptr<float>(i), mat.ptr<float>(i)+mat.cols*mat.channels());
}
}
return embedding;
}
cv::Mat IResNet::Preproecess(cv::Mat input){
cv::Mat output;
cv::resize(input, output, cv::Size(m_width, m_height));
output.convertTo(output, CV_32FC3);
output = (output * (0.003921568627451) - 0.5) * 2.0;
return output;
}
IResNet.h
#ifndef IRESNET_H
#define IRESNET_H

#include <string>
#include <vector>
#include <fstream>

#include <opencv2/opencv.hpp>

class IResNet {
public:
    IResNet();
    ~IResNet();
    void initialize(std::string model_path);
    std::vector<float> GetEmbedding(cv::Mat img);

private:
    // Lightweight existence check via an ifstream handle.
    inline bool FileExists(const std::string &name)
    {
        std::ifstream fhandle(name.c_str());
        return fhandle.good();
    }
    cv::Mat Preprocess(cv::Mat input);

private:
    int m_width;
    int m_height;
    cv::dnn::Net m_net;
};

#endif // IRESNET_H
main.cpp
#include <opencv2/opencv.hpp>
#include <opencv2/dnn/dnn.hpp>

#include "IResNet.h"

using namespace std;
using namespace cv;

int main() {
    IResNet recognizer;
    string recognizer_model = "r100_.onnx";
    recognizer.initialize(recognizer_model);
    return 0;
}
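Usage sketch (not part of the original gist): a minimal example of comparing two face crops with the embeddings returned by IResNet::GetEmbedding. The image paths "face1.jpg"/"face2.jpg" and the CosineSimilarity helper are illustrative assumptions; ArcFace-style embeddings are typically compared with cosine similarity, but the decision threshold depends on the model and dataset.

#include <cmath>
#include <iostream>
#include <vector>

#include <opencv2/opencv.hpp>

#include "IResNet.h"

// Cosine similarity between two embeddings of equal length.
static float CosineSimilarity(const std::vector<float> &a, const std::vector<float> &b)
{
    float dot = 0.f, na = 0.f, nb = 0.f;
    for (size_t i = 0; i < a.size(); ++i) {
        dot += a[i] * b[i];
        na  += a[i] * a[i];
        nb  += b[i] * b[i];
    }
    return dot / (std::sqrt(na) * std::sqrt(nb) + 1e-12f);
}

int main()
{
    IResNet recognizer;
    recognizer.initialize("r100_.onnx");

    // Hypothetical aligned face crops; GetEmbedding resizes them to 112x112 internally.
    cv::Mat face1 = cv::imread("face1.jpg");
    cv::Mat face2 = cv::imread("face2.jpg");

    std::vector<float> emb1 = recognizer.GetEmbedding(face1);
    std::vector<float> emb2 = recognizer.GetEmbedding(face2);

    // Higher similarity suggests the same identity; the threshold must be tuned per model/dataset.
    std::cout << "cosine similarity: " << CosineSimilarity(emb1, emb2) << std::endl;
    return 0;
}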
mhsa.py
import torch
import torch.nn as nn

class MHSA(nn.Module):
    """Multi-head self-attention with 2D relative position encodings (BoTNet-style)."""
    def __init__(self, n_dims, width=14, height=14, heads=4):
        super(MHSA, self).__init__()
        self.heads = heads

        # 1x1 convolutions produce the query, key and value maps.
        self.query = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.key = nn.Conv2d(n_dims, n_dims, kernel_size=1)
        self.value = nn.Conv2d(n_dims, n_dims, kernel_size=1)

        # Learnable relative position embeddings, split into height and width parts.
        self.rel_h = nn.Parameter(torch.randn([1, heads, n_dims // heads, 1, height]), requires_grad=True)
        self.rel_w = nn.Parameter(torch.randn([1, heads, n_dims // heads, width, 1]), requires_grad=True)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n_batch, C, width, height = x.size()
        q = self.query(x).view(n_batch, self.heads, C // self.heads, -1)
        k = self.key(x).view(n_batch, self.heads, C // self.heads, -1)
        v = self.value(x).view(n_batch, self.heads, C // self.heads, -1)

        # Content-content term: query/key dot products over all spatial positions.
        content_content = torch.matmul(q.permute(0, 1, 3, 2), k)

        # Content-position term: relative position embeddings attended by the query.
        content_position = (self.rel_h + self.rel_w).view(1, self.heads, C // self.heads, -1).permute(0, 1, 3, 2)
        content_position = torch.matmul(content_position, q)

        energy = content_content + content_position
        attention = self.softmax(energy)

        out = torch.matmul(v, attention.permute(0, 1, 3, 2))
        out = out.view(n_batch, C, width, height)
        return out