Skip to content

Instantly share code, notes, and snippets.

@Codelaby
Last active October 25, 2023 15:49
Show Gist options
  • Save Codelaby/4dc54511809f3e4cc7c80e6cc937cf3d to your computer and use it in GitHub Desktop.
Save Codelaby/4dc54511809f3e4cc7c80e6cc937cf3d to your computer and use it in GitHub Desktop.
BM25 algorithm in Swift
import Foundation
/// A piece of searchable text identified by `id`.
///
/// Identity is based on `id` only: two documents with the same `id` compare
/// equal and hash identically even when their `content` differs.
struct Document: Hashable {
    /// Identifier used for equality and hashing.
    let id: String
    /// Raw text of the document.
    let content: String

    /// Two documents are equal when their identifiers match.
    static func == (left: Document, right: Document) -> Bool {
        left.id == right.id
    }

    /// Feeds only `id` into the hasher, matching the equality rule above.
    func hash(into hasher: inout Hasher) {
        id.hash(into: &hasher)
    }
}
/// Scores every document against `query` with the Okapi BM25 ranking function.
///
/// The query is normalized with `normalizeText` and split on single spaces.
/// Empty terms (produced by an empty query, doubled spaces, or words made only
/// of punctuation) are discarded so they cannot match empty document tokens.
///
/// The inverse document frequency uses the `log(1 + (N - n + 0.5) / (n + 0.5))`
/// form, which is never negative. With the plain `log((N - n + 0.5) / (n + 0.5))`
/// form a term that occurs in more than half of the corpus gets a negative
/// weight, so documents containing the term rank below documents that do not.
///
/// - Parameters:
///   - query: Free-text search query.
///   - documents: Corpus to rank.
/// - Returns: A score per document. Documents that match no query term score `0`.
///   Because `Document` identity is its `id`, documents sharing an `id` share one entry.
func calculateBM25(query: String, documents: [Document]) -> [Document: Double] {
    // Nothing to rank.
    guard !documents.isEmpty else { return [:] }

    let k1 = 1.2   // term-frequency saturation
    let b = 0.75   // strength of document-length normalization

    let queryTerms = normalizeText(query)
        .components(separatedBy: " ")
        .filter { !$0.isEmpty }

    // Document frequencies are requested once per distinct term so that a term
    // repeated in the query cannot be counted more than once per document.
    let distinctTerms = Array(Set(queryTerms))
    let documentFrequencies = calculateDocumentFrequencies(documents: documents, queryTerms: distinctTerms)
    let averageDocumentLength = calculateAverageDocumentLength(documents: documents)
    let documentCount = Double(documents.count)

    // The IDF depends only on the term, so compute it once instead of once per document.
    var inverseDocumentFrequencies: [String: Double] = [:]
    for term in distinctTerms {
        let documentFrequency = Double(documentFrequencies[term] ?? 0)
        inverseDocumentFrequencies[term] =
            log(1 + (documentCount - documentFrequency + 0.5) / (documentFrequency + 0.5))
    }

    var scores: [Document: Double] = [:]
    for document in documents {
        // Lowercasing does not change the number of space-separated pieces, so it is skipped here.
        let documentLength = Double(document.content.components(separatedBy: " ").count)
        // Length-dependent part of the denominator; identical for every term of this document.
        let lengthFactor = k1 * ((1 - b) + b * (documentLength / averageDocumentLength))

        var score = 0.0
        for term in queryTerms {
            let termFrequency = calculateTermFrequency(term: term, document: document)
            // A term that is absent contributes nothing.
            guard termFrequency > 0 else { continue }
            let idf = inverseDocumentFrequencies[term] ?? 0
            let numerator = (k1 + 1) * termFrequency
            let denominator = lengthFactor + termFrequency
            score += idf * (numerator / denominator)
        }
        scores[document] = score
    }
    return scores
}
/// Counts, for each query term, how many documents contain it as a whole token.
///
/// Documents are tokenized the same way `calculateTermFrequency` tokenizes them
/// (`normalizeText`, then split on single spaces), so a term's document
/// frequency agrees with its per-document term frequency. A raw substring test
/// would count "34" inside "+1340" and would miss "espana" in "España".
///
/// - Parameters:
///   - documents: Corpus to inspect.
///   - queryTerms: Terms to look up, expected to be already normalized.
///     Duplicates are counted once and empty strings are ignored.
/// - Returns: Document frequency per term. Terms found in no document are absent.
func calculateDocumentFrequencies(documents: [Document], queryTerms: [String]) -> [String: Int] {
    // Deduplicate so a repeated query term cannot push its count above documents.count.
    let distinctTerms = Set(queryTerms).filter { !$0.isEmpty }
    guard !distinctTerms.isEmpty else { return [:] }

    var documentFrequencies: [String: Int] = [:]
    for document in documents {
        // Tokenize each document once, not once per term.
        let tokens = Set(normalizeText(document.content).components(separatedBy: " "))
        for term in distinctTerms where tokens.contains(term) {
            documentFrequencies[term, default: 0] += 1
        }
    }
    return documentFrequencies
}
/// Returns how many space-separated tokens of the normalized document text equal `term`.
///
/// - Parameters:
///   - term: Token to look for; compared verbatim against the normalized tokens.
///   - document: Document whose `content` is normalized with `normalizeText` and split on " ".
/// - Returns: The number of matching tokens, as a `Double`.
func calculateTermFrequency(term: String, document: Document) -> Double {
    let tokens = normalizeText(document.content).components(separatedBy: " ")
    var occurrences = 0
    for token in tokens where token == term {
        occurrences += 1
    }
    return Double(occurrences)
}
/// Produces the canonical form of `text` used for matching.
///
/// The text is lowercased, diacritics are folded away, and every run of
/// characters other than `a`-`z`, `0`-`9` and whitespace is removed (not
/// replaced by a space, so "a-b" becomes "ab").
///
/// - Parameter text: Arbitrary input text.
/// - Returns: The normalized text; whitespace is left as it was.
func normalizeText(_ text: String) -> String {
    let lowered = text.lowercased()
    let folded = lowered.folding(options: .diacriticInsensitive, locale: .current)
    return folded.replacingOccurrences(
        of: #"[^a-z0-9\s]+"#,
        with: "",
        options: .regularExpression
    )
}
/// Returns the mean document length, measured in space-separated pieces of `content`.
///
/// - Parameter documents: Corpus to measure.
/// - Returns: The average length, or `0` for an empty corpus (instead of the
///   NaN that dividing zero by zero would produce).
func calculateAverageDocumentLength(documents: [Document]) -> Double {
    guard !documents.isEmpty else { return 0 }
    // Letter case does not affect how many pieces a split on " " yields,
    // so the content is split directly without lowercasing first.
    let totalLength = documents.reduce(0) { $0 + $1.content.components(separatedBy: " ").count }
    return Double(totalLength) / Double(documents.count)
}
// Example usage
// Sample corpus: four airport descriptions. Every entry contains the phrase
// "Image Credit", which is the query run against this corpus below.
let documents = [
Document(id: "doc1", content: "This is the first document. ID: CKG, Name: Chongqing Jiangbei International Airport, City: Chongqing, City 2: Jiangbei, Country: China, Description: Opened in 1990, Chongqing Jiangbei International replaced the older Baishiyi Airport. Its three-letter code comes from the city’s former English name: Chungking. Image Credit: byeangel, Image Credit Link: https://www.flickr.com/photos/byeangel/. State: Yubei District."),
Document(id: "doc2", content: "This document is the second document. ID: LCG, Name: Aeroporto da Coruña-Alvedro, City: A Coruña, City 2: Galicia, Country: Spain, Description: Formerly known as Alvedro Airport, A Coruña Airport was inaugurated in 1963. Its airport code comes from the Spanish city of La Coruña, Galicia. Image Credit: Caneles, Image Credit Link: https://www.flickr.com/photos/94446676@N00/."),
Document(id: "doc3", content: "And this is the third document. ID: MAD, Name: Aeropuerto Adolfo Suárez Madrid-Barajas, City: Madrid, City 2: Barajas, Country: Spain, Description: Spain’s largest airport honors former Prime Minister Adolfo Suárez, but its airport code honors its home in the capital city of Madrid. Image Credit: Anh Dinh, Image Credit Link: https://www.flickr.com/photos/anhgemus-photography/. Name in English: Adolfo Suárez Madrid–Barajas Airport."),
Document(id: "doc4", content: "This is the fourth document. ID: BCN, Name: Aeropuerto de Barcelona-El Prat, City: Barcelona, City 2: El Prat de Llobregat, Country: Spain, Description: Barcelona’s first airfield was built in 1916, but a new location in El Prat was chosen in 1918. The airport now uses the code BCN which stands for Barcelona. Image Credit: Camilo Rueda López, Image Credit Link: https://www.flickr.com/photos/kozumel/."),
]
let query = "Image Credit"
let scores = calculateBM25(query: query, documents: documents)
// Dictionary iteration order is unspecified, so sort before printing:
// best score first, ties broken by id to keep the output reproducible.
let rankedScores = scores.sorted { first, second in
    if first.value != second.value {
        return first.value > second.value
    }
    return first.key.id < second.key.id
}
for (document, score) in rankedScores {
    print("Document ID: \(document.id), Score: \(score)")
}
import Foundation
/// A piece of searchable text identified by `id`.
///
/// Identity is based on `id` only: two documents with the same `id` compare
/// equal and hash identically even when their `content` differs.
struct Document: Hashable {
    /// Identifier used for equality and hashing.
    let id: String
    /// Raw text of the document.
    let content: String

    /// Two documents are equal when their identifiers match.
    static func == (left: Document, right: Document) -> Bool {
        left.id == right.id
    }

    /// Feeds only `id` into the hasher, matching the equality rule above.
    func hash(into hasher: inout Hasher) {
        id.hash(into: &hasher)
    }
}
/// Scores every document against `query` with the Okapi BM25 ranking function.
///
/// The query is normalized with `normalizeText` and split on single spaces.
/// Empty terms (produced by an empty query, doubled spaces, or words made only
/// of punctuation) are discarded so they cannot match empty document tokens.
///
/// The inverse document frequency uses the `log(1 + (N - n + 0.5) / (n + 0.5))`
/// form, which is never negative. With the plain `log((N - n + 0.5) / (n + 0.5))`
/// form a term that occurs in more than half of the corpus gets a negative
/// weight, so documents containing the term rank below documents that do not.
///
/// - Parameters:
///   - query: Free-text search query.
///   - documents: Corpus to rank.
/// - Returns: A score per document. Documents that match no query term score `0`.
///   Because `Document` identity is its `id`, documents sharing an `id` share one entry.
func calculateBM25(query: String, documents: [Document]) -> [Document: Double] {
    // Nothing to rank.
    guard !documents.isEmpty else { return [:] }

    let k1 = 1.2   // term-frequency saturation
    let b = 0.75   // strength of document-length normalization

    let queryTerms = normalizeText(query)
        .components(separatedBy: " ")
        .filter { !$0.isEmpty }

    // Document frequencies are requested once per distinct term so that a term
    // repeated in the query cannot be counted more than once per document.
    let distinctTerms = Array(Set(queryTerms))
    let documentFrequencies = calculateDocumentFrequencies(documents: documents, queryTerms: distinctTerms)
    let averageDocumentLength = calculateAverageDocumentLength(documents: documents)
    let documentCount = Double(documents.count)

    // The IDF depends only on the term, so compute it once instead of once per document.
    var inverseDocumentFrequencies: [String: Double] = [:]
    for term in distinctTerms {
        let documentFrequency = Double(documentFrequencies[term] ?? 0)
        inverseDocumentFrequencies[term] =
            log(1 + (documentCount - documentFrequency + 0.5) / (documentFrequency + 0.5))
    }

    var scores: [Document: Double] = [:]
    for document in documents {
        // Lowercasing does not change the number of space-separated pieces, so it is skipped here.
        let documentLength = Double(document.content.components(separatedBy: " ").count)
        // Length-dependent part of the denominator; identical for every term of this document.
        let lengthFactor = k1 * ((1 - b) + b * (documentLength / averageDocumentLength))

        var score = 0.0
        for term in queryTerms {
            let termFrequency = calculateTermFrequency(term: term, document: document)
            // A term that is absent contributes nothing.
            guard termFrequency > 0 else { continue }
            let idf = inverseDocumentFrequencies[term] ?? 0
            let numerator = (k1 + 1) * termFrequency
            let denominator = lengthFactor + termFrequency
            score += idf * (numerator / denominator)
        }
        scores[document] = score
    }
    return scores
}
/// Counts, for each query term, how many documents contain it as a whole token.
///
/// Documents are tokenized the same way `calculateTermFrequency` tokenizes them
/// (`normalizeText`, then split on single spaces), so a term's document
/// frequency agrees with its per-document term frequency. A raw substring test
/// would count "34" inside "+1340" and would miss "espana" in "España".
///
/// - Parameters:
///   - documents: Corpus to inspect.
///   - queryTerms: Terms to look up, expected to be already normalized.
///     Duplicates are counted once and empty strings are ignored.
/// - Returns: Document frequency per term. Terms found in no document are absent.
func calculateDocumentFrequencies(documents: [Document], queryTerms: [String]) -> [String: Int] {
    // Deduplicate so a repeated query term cannot push its count above documents.count.
    let distinctTerms = Set(queryTerms).filter { !$0.isEmpty }
    guard !distinctTerms.isEmpty else { return [:] }

    var documentFrequencies: [String: Int] = [:]
    for document in documents {
        // Tokenize each document once, not once per term.
        let tokens = Set(normalizeText(document.content).components(separatedBy: " "))
        for term in distinctTerms where tokens.contains(term) {
            documentFrequencies[term, default: 0] += 1
        }
    }
    return documentFrequencies
}
/// Returns how many space-separated tokens of the normalized document text equal `term`.
///
/// - Parameters:
///   - term: Token to look for; compared verbatim against the normalized tokens.
///   - document: Document whose `content` is normalized with `normalizeText` and split on " ".
/// - Returns: The number of matching tokens, as a `Double`.
func calculateTermFrequency(term: String, document: Document) -> Double {
    let tokens = normalizeText(document.content).components(separatedBy: " ")
    var occurrences = 0
    for token in tokens where token == term {
        occurrences += 1
    }
    return Double(occurrences)
}
/// Produces the canonical form of `text` used for matching.
///
/// The text is lowercased, diacritics are folded away, and every run of
/// characters other than `a`-`z`, `0`-`9` and whitespace is removed (not
/// replaced by a space, so "a-b" becomes "ab").
///
/// - Parameter text: Arbitrary input text.
/// - Returns: The normalized text; whitespace is left as it was.
func normalizeText(_ text: String) -> String {
    let lowered = text.lowercased()
    let folded = lowered.folding(options: .diacriticInsensitive, locale: .current)
    return folded.replacingOccurrences(
        of: #"[^a-z0-9\s]+"#,
        with: "",
        options: .regularExpression
    )
}
/// Returns the mean document length, measured in space-separated pieces of `content`.
///
/// - Parameter documents: Corpus to measure.
/// - Returns: The average length, or `0` for an empty corpus (instead of the
///   NaN that dividing zero by zero would produce).
func calculateAverageDocumentLength(documents: [Document]) -> Double {
    guard !documents.isEmpty else { return 0 }
    // Letter case does not affect how many pieces a split on " " yields,
    // so the content is split directly without lowercasing first.
    let totalLength = documents.reduce(0) { $0 + $1.content.components(separatedBy: " ").count }
    return Double(totalLength) / Double(documents.count)
}
// Example usage
// Sample corpus: country records, each with name, dialCode, code and displayName fields.
// Originally proposed for trying a search by the code 'es'.
let documents = [
Document(id: "doc1", content: "name: Bangladesh, dialCode: +880, code: BD, displayName: Bangladés"),
Document(id: "doc2", content: "name: Estonia, dialCode: +372, code: EE, displayName: Estonia"),
Document(id: "doc3", content: "name: French Guiana, dialCode: +594, code: GF, displayName: Guayana Francesa"),
Document(id: "doc4", content: "name: French Polynesia, dialCode: +689, code: PF, displayName: Polinesia Francesa"),
Document(id: "doc5", content: "name: Guernsey, dialCode: +44, code: GG, displayName: Guernesey"),
Document(id: "doc6", content: "name: Indonesia, dialCode: +62, code: ID, displayName: Indonesia"),
Document(id: "doc7", content: "name: Lesotho, dialCode: +266, code: LS, displayName: Lesoto"),
Document(id: "doc8", content: "name: Maldives, dialCode: +960, code: MV, displayName: Maldivas"),
Document(id: "doc9", content: "name: Micronesia, Federated States of Micronesia, dialCode: +691, code: FM, displayName: Micronesia"),
Document(id: "doc10", content: "name: Netherlands, dialCode: +31, code: NL, displayName: Países Bajos"),
Document(id: "doc11", content: "name: Palestinian Territory, Occupied, dialCode: +970, code: PS, displayName: Territorios Palestinos"),
Document(id: "doc12", content: "name: Philippines, dialCode: +63, code: PH, displayName: Filipinas"),
Document(id: "doc13", content: "name: Saint Kitts and Nevis, dialCode: +1869, code: KN, displayName: San Cristóbal y Nieves"),
Document(id: "doc14", content: "name: Saint Vincent and the Grenadines, dialCode: +1784, code: VC, displayName: San Vicente y las Granadinas"),
Document(id: "doc15", content: "name: Seychelles, dialCode: +248, code: SC, displayName: Seychelles"),
Document(id: "doc16", content: "name: Slovakia, dialCode: +421, code: SK, displayName: Eslovaquia"),
Document(id: "doc17", content: "name: Slovenia, dialCode: +386, code: SI, displayName: Eslovenia"),
Document(id: "doc18", content: "name: Spain, dialCode: +34, code: ES, displayName: España"),
Document(id: "doc19", content: "name: Swaziland, dialCode: +268, code: SZ, displayName: Esuatini"),
Document(id: "doc20", content: "name: Timor-Leste, dialCode: +670, code: TL, displayName: Timor Oriental"),
Document(id: "doc21", content: "name: United Arab Emirates, dialCode: +971, code: AE, displayName: Emiratos Árabes Unidos"),
Document(id: "doc22", content: "name: United States, dialCode: +1, code: US, displayName: Estados Unidos"),
Document(id: "doc23", content: "name: Virgin Islands, British, dialCode: +1284, code: VG, displayName: Islas Vírgenes Británicas"),
Document(id: "doc24", content: "name: Virgin Islands, U.S., dialCode: +1340, code: VI, displayName: Islas Vírgenes de EE. UU."),
]
let query = "+34"
// The corpus above is declared as `documents`; the previous reference to
// `documents2` named a constant that does not exist.
let scores = calculateBM25(query: query, documents: documents)
// Dictionary iteration order is unspecified, so sort before printing:
// best score first, ties broken by id to keep the output reproducible.
let rankedScores = scores.sorted { first, second in
    if first.value != second.value {
        return first.value > second.value
    }
    return first.key.id < second.key.id
}
for (document, score) in rankedScores {
    print("Document ID: \(document.id), Score: \(score)")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment