Created September 10, 2023 05:33
-
-
Save barron9/d6d889e5ebab92c1eb9e9d7bb33a7e1c to your computer and use it in GitHub Desktop.
attention
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include "attention.h"

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>
/**
 * Computes dot-product attention weights of a query against a set of keys.
 *
 * Each key is scored by its dot product with the query. The scores are then
 * turned into a probability distribution with softmax. The largest score is
 * subtracted before exponentiation, which leaves the result mathematically
 * unchanged. It keeps every exponent <= 0, so large scores cannot overflow
 * exp() to infinity and turn the weights into NaN (inf / inf).
 *
 * Only the first query.size() components of each key are used.
 *
 * @param query  Query vector.
 * @param keys   Key vectors; each must have at least query.size() elements.
 * @return One weight per key, each non-negative and together summing to 1.
 *         Returns an empty vector when keys is empty.
 * @throws std::invalid_argument if any key is shorter than the query.
 */
std::vector<double> computeAttentionWeights(const std::vector<double>& query, const std::vector< std::vector<double> >& keys) {
    const std::size_t numKeys = keys.size();
    std::vector<double> attentionWeights(numKeys, 0.0);
    if (numKeys == 0) {
        return attentionWeights;
    }

    // Pass 1: raw similarity scores (dot product of the query with each key).
    for (std::size_t i = 0; i < numKeys; ++i) {
        // Reading keys[i][j] past the end of a short key would be undefined behavior.
        if (keys[i].size() < query.size()) {
            throw std::invalid_argument("computeAttentionWeights: key is shorter than query");
        }
        double similarity = 0.0;
        for (std::size_t j = 0; j < query.size(); ++j) {
            similarity += query[j] * keys[i][j];
        }
        attentionWeights[i] = similarity;
    }

    // Pass 2: softmax, shifted by the max score for numerical stability.
    // The max element contributes exp(0) == 1, so totalWeight >= 1 and the
    // normalisation below can never divide by zero.
    const double maxScore = *std::max_element(attentionWeights.begin(), attentionWeights.end());
    double totalWeight = 0.0;
    for (double& weight : attentionWeights) {
        weight = std::exp(weight - maxScore);
        totalWeight += weight;
    }
    for (double& weight : attentionWeights) {
        weight /= totalWeight;
    }
    return attentionWeights;
}
int main() { | |
std::vector<double> query = {0, 0.1, 0}; | |
std::vector<std::vector<double> > keys = {{0.8, 0.2, 0.3}, {0.1, 0.7, 0.5}, {0.1, 0.7, 0.5}}; | |
// Compute attention weights | |
std::vector<double> attentionWeights = computeAttentionWeights(query, keys); | |
std::double_t totlaattweight = 0; | |
// Display attention weights | |
for (int i = 0; i < attentionWeights.size(); ++i) { | |
std::cout << "Attention Weight " << i << ": " << attentionWeights[i] << std::endl; | |
totlaattweight += attentionWeights[i]; | |
} | |
std::cout << "sum check : " << totlaattweight << std::endl; | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef ATTENTION_H
#define ATTENTION_H

#include <vector>

/**
 * Computes softmax-normalised dot-product attention weights.
 *
 * Each key is scored by the dot product of the query with that key. The
 * scores are exponentiated and divided by their total, so the returned
 * weights sum to 1.
 *
 * @param query  Query vector.
 * @param keys   Key vectors to score against the query.
 * @return One weight per key, in the same order as keys.
 */
std::vector<double> computeAttentionWeights(const std::vector<double>& query, const std::vector<std::vector<double>>& keys);

#endif // ATTENTION_H
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment