@saibotsivad
Last active February 9, 2024 14:18
Generate embeddings for a single file
from transformers import AutoTokenizer, AutoModel
import torch

def load_model(model_name="sentence-transformers/bert-base-nli-mean-tokens"):
    # Load the tokenizer and model from the Hugging Face hub
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model

def generate_embedding(text, tokenizer, model):
    # Tokenize and encode the text
    encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
    # Move the inputs to the same device as the model
    encoded_input = {key: val.to(model.device) for key, val in encoded_input.items()}
    # Compute token embeddings without tracking gradients
    with torch.no_grad():
        model_output = model(**encoded_input)
    # Perform pooling to get a single sentence embedding;
    # here, we take the attention-masked mean of the token embeddings
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
    return sentence_embeddings

def mean_pooling(model_output, attention_mask):
    # Zero out padding tokens, then average the remaining token embeddings
    token_embeddings = model_output.last_hidden_state
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
    # Clamp the divisor to avoid division by zero on all-padding inputs
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    return sum_embeddings / sum_mask

def main():
    # Load the model and tokenizer
    tokenizer, model = load_model()
    # Read the text file (change 'my-text.txt' to the path of your file)
    with open('my-text.txt', 'r') as file:
        text = file.read()
    # Generate the embedding
    embedding = generate_embedding(text, tokenizer, model)
    # Print the embedding
    print(embedding)

if __name__ == "__main__":
    main()
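
To actually use the embedding for something beyond printing it, the usual move is cosine similarity between two vectors. Here's a minimal sketch that reuses the functions above (the two example sentences are made up for illustration):

import torch.nn.functional as F

tokenizer, model = load_model()
emb_a = generate_embedding("The cat sat on the mat.", tokenizer, model)
emb_b = generate_embedding("A feline rested on the rug.", tokenizer, model)
# cosine_similarity compares along dim=1, so two [1, 768] tensors
# yield a single score; values near 1.0 mean "very similar"
print(F.cosine_similarity(emb_a, emb_b).item())

Mean pooling (rather than just taking the [CLS] token) is what the sentence-transformers NLI models were trained with, which is why the script averages the token embeddings.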
my-text.txt (the sample input read by main()):

Here is a really long file.
It is split into a few lines.

Here's what the console spits out; it's (approximately) what a normal embedding looks like:

tensor([[ 0.3598, -0.6431,  0.2248,  0.5941,  0.0356, -0.6295, -1.0597,  0.6563,
         -0.0319, -0.0967,  0.0996, -0.1298,  0.6388,  0.4989,  0.2152,  0.2226,
         -0.2080, -0.4710,  0.0835, -0.9761,  0.3886, -0.0234, -0.2117, -0.9203,
          0.1164, -0.1989, -0.3203, -1.7480, -0.8309,  0.1709,  0.1579, -1.0378,
          0.4107, -0.0329,  0.1906, -0.3012,  0.6966,  0.0679,  0.1450, -0.7598,
          1.6345,  0.0647,  0.0354,  0.1917, -0.3716, -0.5341,  0.0738, -0.4349,
         -0.3264, -1.1624, -0.7193,  0.0545,  0.7827,  0.8299, -0.2603,  0.2680,
          0.5260, -0.3287,  0.2574,  0.1228,  0.4074, -0.5628, -0.0300,  0.6071,
         -0.4191,  0.2277,  0.0277,  0.0777, -1.1525, -0.1893,  0.6827, -0.6947,
          0.0097,  0.0022, -0.2517, -0.0263,  0.0759,  0.3250,  0.2670,  0.2543,
         -0.2644,  0.2461,  0.4173,  0.2967, -0.5933,  0.0595, -0.2351,  0.0311,
         -0.6753,  0.8660,  0.6851, -0.4741,  0.4953, -1.0187, -0.0409, -0.7390,
         -0.2299,  0.0800, -0.2143,  0.3457, -1.0572, -0.7863,  0.4777,  0.0642,
         -0.3118,  0.5427, -0.0932,  1.8155,  0.9161,  0.0089, -0.0264,  0.2030,
         -0.0400, -0.4526,  0.1206,  0.0374, -0.4002,  0.3120,  0.4278,  0.4621,
         -0.3146,  0.8676,  0.0373, -0.5180,  0.3274, -0.6678,  1.1160, -0.3185,
         -0.1580, -0.6483, -0.0136,  0.7927,  0.3682,  0.0373, -0.5940, -0.0156,
          0.0127,  0.4357,  0.3584, -1.0232, -0.9389,  0.5583, -0.3182,  0.6521,
         -0.4210, -0.8418, -0.2137, -1.1463, -0.7520,  0.0458, -0.2448,  0.3654,
         -0.3619,  0.1821,  0.0998,  0.7311, -0.0192, -0.1821,  0.4720,  0.0330,
         -0.2695,  0.6548,  0.2653,  0.4125,  0.2893, -0.2815,  0.2906, -0.3365,
          0.0171, -0.0173,  0.3572, -0.1843,  0.6494, -0.5565, -1.6786,  0.2106,
          0.8416,  0.6087,  0.5923, -0.7594,  0.1102,  0.6505,  0.5894, -0.1603,
          0.1166, -0.2873,  0.9530,  0.0923, -0.1723, -0.6234,  0.1372, -0.7313,
         -0.5799, -0.5733,  0.1458,  0.6207, -0.1276,  0.7081,  0.2379,  0.4959,
          0.0982, -0.1029,  0.7806, -0.9553, -0.1870,  0.7831,  0.3079,  0.5296,
         -0.6406, -0.1469,  0.1059,  0.2250,  0.3847, -0.0069,  0.6547,  0.0683,
          0.9591, -0.9816,  1.0051, -0.0291,  1.4170, -0.6160,  0.2116,  1.4528,
         -0.3860,  0.8529,  0.6833,  0.2921, -0.2289, -0.6458,  0.0098,  0.0593,
         -0.4369,  0.4969, -0.0695, -0.7720, -0.1353, -0.1844,  0.8544,  0.0941,
          0.9051,  0.6603,  1.0518,  0.1834, -0.8135, -0.2445,  0.7994, -0.1304,
          0.0453,  0.2448,  0.2235, -0.2844,  0.4941, -0.4929,  0.1162, -0.2176,
          0.1957, -0.5253, -0.0646, -0.4443, -0.1799, -0.1503, -0.2053,  0.0781,
          0.0256,  0.7055, -0.6884, -0.1568, -0.0857,  0.4855,  0.0462, -0.2110,
         -0.6572, -0.5569,  0.1866, -0.4730,  1.1694, -0.3058, -0.2104,  0.1176,
          0.1918,  0.5162, -0.8293, -0.3402, -0.4512, -0.1945,  0.6235, -0.2549,
         -0.1693, -0.7550, -0.1642,  0.6426, -0.8561, -1.4076,  1.1521, -0.5170,
          0.8410, -0.7744, -0.3512,  0.5524, -0.2440,  0.2428, -0.5795, -0.1599,
         -0.3943, -0.0919, -0.5700,  0.4861, -2.1757, -0.1238, -1.3239, -0.2197,
         -1.1079, -0.2691, -0.6506,  0.2804, -0.9580, -1.0453, -0.1771, -0.5929,
         -0.8236,  0.3188,  0.6565,  0.3528, -0.1101,  1.3016,  0.4030,  0.4342,
         -0.7698, -0.5112,  0.2740, -0.5598,  0.5032,  1.3626,  0.0374,  0.4998,
         -0.6309,  0.1056,  0.4754,  0.4319, -0.1939, -0.2571, -0.3582, -1.2912,
          2.1165, -0.5386, -0.3432, -1.0393,  0.3424, -0.5556,  0.1276, -0.1274,
         -0.1488,  0.0243, -0.5651, -0.6598,  0.1229,  0.3170, -0.4232,  0.5345,
          0.0581, -0.3876, -0.3870, -0.1935,  0.5393,  0.5018,  0.3986,  0.2171,
          0.6141, -0.1721,  0.4467, -0.2666, -1.0618, -0.0161,  0.2183, -0.5612,
          0.5713,  0.2665,  0.1970, -0.0081,  0.9817, -0.3943,  0.5658, -0.4855,
          0.3874,  0.4209,  0.7225,  0.3442, -0.3855, -0.1206,  1.3287, -0.0738,
          0.1766, -1.3225, -0.1951, -0.8327,  0.1533,  0.9166, -0.4262,  0.0177,
         -0.6036,  0.6975,  0.6321, -0.8706, -0.1117,  0.7221,  0.6984,  0.0099,
         -0.0314, -0.6597, -1.2704, -0.4658,  0.0056, -0.1986, -0.5428,  0.5613,
         -0.0200,  0.0175,  0.1854,  0.9906, -0.4021, -0.0960,  0.2548, -0.3219,
         -0.4782,  0.7919,  0.2783,  0.1568, -0.6864, -0.0453, -0.0079, -0.6010,
         -0.2312,  1.4039,  0.7833, -0.4505,  0.1364, -1.1572, -0.0304, -0.6509,
          0.6879, -0.7163, -0.5838, -0.3496, -0.8199,  0.2916,  0.0761,  0.6127,
          0.1892,  0.0759,  0.2042,  0.7564,  0.1057,  0.6561,  0.2377,  0.4478,
          0.1524, -0.7142, -0.1219, -0.8800,  0.3853,  0.0939,  0.0619,  0.0540,
          0.5627,  1.0710, -0.6673, -0.4474, -0.0694, -0.1425,  0.1135,  0.0525,
          0.5041,  0.1305,  0.0592,  1.2962, -1.6540,  0.7567, -0.6972,  1.2791,
         -0.0326,  0.8996, -0.4286,  0.1917, -0.7624,  0.3261, -0.0059,  1.5572,
          0.5315, -0.0499,  0.6452,  0.1597, -1.0659, -0.5222, -1.2581,  1.8934,
         -1.1565,  0.6804,  0.5782,  0.3919,  0.6520,  0.3221,  0.6230, -1.0741,
          0.0556,  0.1437, -0.0798, -0.5364,  0.7755,  0.1634, -0.0766, -0.3711,
         -0.5369, -0.5326, -0.0944, -0.4604,  0.0791,  0.0845, -0.1091, -0.5540,
         -0.3457, -0.1702,  0.6618, -0.2121,  0.1459,  0.1407,  0.9729, -0.2937,
          0.7477, -0.2600, -0.3874, -0.0371, -0.5330,  0.9915,  0.1355, -1.0561,
         -0.1519, -0.5949,  0.6334, -0.9037, -0.0352,  0.6863,  0.6260, -0.2743,
         -0.2178,  0.3571,  0.2635, -0.7191,  0.1857,  0.1271, -0.0203,  0.4017,
          0.0278,  0.1912,  0.1783, -0.8482, -0.7319,  0.9125, -0.4378, -0.3579,
          0.3106,  0.2172, -0.1205,  0.8812, -0.0096, -0.8034,  0.3733, -0.2214,
         -0.2935, -0.7141,  0.2566, -0.4120,  1.2010,  0.9535, -0.6361,  0.8355,
          0.4190, -0.6999, -0.4658,  0.0761, -0.4421,  0.5789,  0.1998, -0.1864,
         -0.4985, -0.2501,  0.7810,  0.7239, -0.5065,  0.0432,  0.3928, -0.0579,
          0.0909,  0.7999, -0.1416,  0.4994, -0.1884,  0.5886, -0.5093,  0.0810,
         -0.5696, -0.7102, -1.1150, -1.3076,  0.3421, -0.3235, -0.1681,  0.2350,
         -0.0481, -0.4144,  0.6310, -1.3098,  0.6284, -0.2710, -0.0239, -0.1684,
          0.9439,  0.5446,  1.1166,  0.0858, -0.0602,  0.1464,  0.1149, -1.4083,
          0.6315,  0.1882,  0.7568, -0.0584,  0.3784,  0.4316,  0.3040,  0.0161,
         -0.8407,  0.2932,  0.1330,  0.1149, -0.9688,  0.5631,  0.7556,  0.0111,
         -1.0334, -1.2331, -0.0095, -0.7740,  0.1805, -0.1878,  0.4778,  0.1215,
          1.1213,  0.0308, -0.1802,  0.9744, -0.1826,  0.2991,  0.1994, -1.3867,
         -0.1509,  0.7145, -0.1478,  0.4977,  0.3294, -0.4410, -0.0707,  0.1742,
          0.5555,  0.1039, -0.1263, -0.1847,  0.3194,  0.4603, -0.8106, -0.1830,
          0.2711, -0.7334,  0.6616, -0.6917,  0.5321, -0.1577, -0.2860, -1.2315,
          0.2793, -0.6592, -0.2092, -0.0925,  0.0204,  0.1376,  0.0039, -0.0411,
          0.0550, -0.3038,  0.8242,  0.4755, -0.2893, -0.9805,  0.7563,  0.7917,
          0.0997, -0.2055,  0.5729, -1.1725, -1.4878,  0.1308, -0.9293, -0.8903,
         -0.5058,  0.8389, -0.5799, -0.1361,  0.1141,  0.1441,  0.2134,  0.4176,
         -0.1517,  0.5566, -0.6643,  0.3780, -0.0453,  0.4214,  0.2067, -0.1266,
         -0.8231,  0.5038, -0.4024, -0.9425, -0.2329, -0.5868,  0.1097, -0.0425,
         -1.1413, -0.3246, -0.3487, -0.4959, -0.0109,  0.1228,  0.4950,  0.0644,
         -0.1446,  0.0225, -1.0354,  0.9152,  0.1925,  0.2048,  0.1039, -0.0317,
         -0.4251, -0.0419, -0.1854, -0.1467, -0.7265,  0.0151, -0.9076, -0.8132,
          0.1587,  0.2093, -0.5014, -0.2831, -0.6342,  0.2830,  0.6996, -0.3354,
          0.3950, -0.6848, -0.4613, -0.0535,  0.1104,  0.3454,  0.2502,  0.1805]])
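
Note the shape: a single row of 768 values, which is the hidden size of BERT-base. Also note that generate_embedding tokenizes with truncation=True, so anything past the model's 512-token limit is silently dropped, and a genuinely long file only contributes its first ~512 tokens. One common workaround (not part of this gist, just a sketch) is to embed fixed-size chunks and average them:

def generate_embedding_chunked(text, tokenizer, model, chunk_size=510):
    # Tokenize without truncation, then split the ids into chunks small
    # enough to fit the 512-token window once [CLS]/[SEP] are re-added
    ids = tokenizer(text, truncation=False, add_special_tokens=False)['input_ids']
    chunks = [ids[i:i + chunk_size] for i in range(0, len(ids), chunk_size)]
    embeddings = []
    for chunk in chunks:
        # Decoding and re-encoding each chunk is slightly lossy,
        # but it keeps this sketch simple
        chunk_text = tokenizer.decode(chunk)
        embeddings.append(generate_embedding(chunk_text, tokenizer, model))
    # Average the per-chunk vectors into one [1, 768] embedding for the file
    return torch.mean(torch.cat(embeddings, dim=0), dim=0, keepdim=True)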