Skip to content

Instantly share code, notes, and snippets.

@barron9
Created September 10, 2023 05:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save barron9/d6d889e5ebab92c1eb9e9d7bb33a7e1c to your computer and use it in GitHub Desktop.
Save barron9/d6d889e5ebab92c1eb9e9d7bb33a7e1c to your computer and use it in GitHub Desktop.
attention
#include "attention.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>
// Computes softmax-normalized attention weights of `query` against each key.
//
// For each key k_i the raw score is the dot product <query, k_i>, taken over
// the first query.size() components of the key. The returned weights are
// softmax(scores): non-negative and summing to 1 when `keys` is non-empty.
//
// @param query  Query vector.
// @param keys   Key vectors; each must have at least query.size() elements
//               (extra trailing elements are ignored).
// @return       One weight per key, in the same order as `keys`; empty if
//               `keys` is empty.
// @throws std::invalid_argument if any key is shorter than `query`.
std::vector<double> computeAttentionWeights(const std::vector<double>& query, const std::vector< std::vector<double> >& keys) {
    const std::size_t numKeys = keys.size();
    std::vector<double> attentionWeights(numKeys, 0.0);
    if (numKeys == 0) {
        return attentionWeights;
    }

    // Raw similarity scores (dot products between the query and each key).
    for (std::size_t i = 0; i < numKeys; ++i) {
        // Guard against reading past the end of a short key.
        if (keys[i].size() < query.size()) {
            throw std::invalid_argument("computeAttentionWeights: key dimension is smaller than query dimension");
        }
        double similarity = 0.0;
        for (std::size_t j = 0; j < query.size(); ++j) {
            similarity += query[j] * keys[i][j];
        }
        attentionWeights[i] = similarity;
    }

    // Softmax is invariant to subtracting a constant from every score.
    // Shifting by the maximum keeps every exponent <= 0, so exp() cannot
    // overflow to inf (which would otherwise yield inf/inf = NaN weights).
    const double maxScore = *std::max_element(attentionWeights.begin(), attentionWeights.end());
    double totalWeight = 0.0;
    for (double& weight : attentionWeights) {
        weight = std::exp(weight - maxScore);
        totalWeight += weight;
    }

    // For finite scores the maximum contributes exp(0) == 1, so totalWeight >= 1
    // and this division is safe.
    for (double& weight : attentionWeights) {
        weight /= totalWeight;
    }
    return attentionWeights;
}
// Demo driver: computes attention weights for a fixed query against three
// keys, prints each weight, then prints their sum as a sanity check
// (softmax-normalized weights should sum to ~1).
int main() {
    const std::vector<double> query = {0, 0.1, 0};
    // The last two keys are identical, so they should receive equal weights.
    const std::vector<std::vector<double> > keys = {{0.8, 0.2, 0.3}, {0.1, 0.7, 0.5}, {0.1, 0.7, 0.5}};

    // Compute attention weights
    const std::vector<double> attentionWeights = computeAttentionWeights(query, keys);

    // Display attention weights and accumulate their sum. A plain double is
    // used on purpose: std::double_t may be wider (e.g. long double).
    double totalAttentionWeight = 0.0;
    for (std::size_t i = 0; i < attentionWeights.size(); ++i) {
        std::cout << "Attention Weight " << i << ": " << attentionWeights[i] << '\n';
        totalAttentionWeight += attentionWeights[i];
    }
    std::cout << "sum check : " << totalAttentionWeight << '\n';
    return 0;
}
#if !defined(ATTENTION_H)
#define ATTENTION_H

#include <vector>

/// @brief Computes softmax-normalized attention weights of `query` against
///        each vector in `keys`, scoring each key by its dot product with
///        the query.
/// @param query Query vector.
/// @param keys  Key vectors; one weight is produced per key.
/// @return One attention weight per key, in the same order as `keys`.
std::vector<double> computeAttentionWeights(
    const std::vector<double>& query,
    const std::vector< std::vector<double> >& keys);

#endif  /* ATTENTION_H */
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment