Last active
April 18, 2021 21:53
-
-
Save AntsiferovMaxim/35ed5908e9cd064c443b7c7644d1f9f7 to your computer and use it in GitHub Desktop.
Hungarian
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Builds the pairwise cosine-similarity matrix of `vectors`.
 *
 * Entry `[i][j]` is `dot(vectors[i], vectors[j]) / |vectors[i]| / |vectors[j]|`. Rows and
 * columns follow the input order.
 *
 * A vector with zero magnitude has no direction, so its similarity with every vector
 * (itself included) is reported as `0` rather than the `NaN` that `0 / 0` would produce.
 * A `NaN` would otherwise propagate through any downstream `Math.max` or comparison.
 *
 * @param vectors - equally long numeric vectors.
 * @returns an `n x n` matrix for `n` input vectors (an empty array for no vectors).
 */
export function cosineMeasure(vectors: number[][]): number[][] {
  // Each norm is needed once per row and once per column; compute it only once.
  const norms = vectors.map(vector => Math.sqrt(scalarProduct(vector, vector)));
  return vectors.map((a, i) =>
    vectors.map((b, j) => {
      if (norms[i] === 0 || norms[j] === 0) {
        return 0;
      }
      // Same division order as before, so non-degenerate values are unchanged bit-for-bit.
      return scalarProduct(a, b) / norms[i] / norms[j];
    }),
  );
}
/**
 * Dot product of `a` and `b`: the sum of `a[k] * b[k]` over the indices of `a`.
 *
 * Extra trailing elements of `b` are ignored. A `b` shorter than `a` multiplies by
 * `undefined`, which makes the result `NaN`.
 */
function scalarProduct(a: number[], b: number[]) {
  let sum = 0;
  for (let k = 0; k < a.length; k++) {
    sum += a[k] * b[k];
  }
  return sum;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Encodes a group member as a numeric feature vector for similarity scoring.
 *
 * Layout: `[age, sex, ...interestFlags]`, where:
 * - `age` is the member's completed years since `details.dob`;
 * - `sex` is `0` when `details.sex === 'f'` and `1` for any other value;
 * - `interestFlags[k]` is `1` when the member has `interests[k]` (matched by `id`), else `0`.
 *
 * The flag order follows `interests`, so every member encoded with the same `interests`
 * list produces vectors of the same length and layout.
 *
 * NOTE(review): raw age (tens) sits next to 0/1 flags, so it likely dominates any cosine
 * similarity computed from these vectors. Confirm whether age should be scaled.
 */
function memberToVector(interests: InterestsEntity[], member: GroupMembersEntity): Array<number> {
  const { details, interests: memberInterests } = member.user;
  // Completed years (dayjs truncates), so a birthday later this year is not counted early.
  const age = dayjs().diff(dayjs(details.dob), 'year');
  const sex = details.sex === 'f' ? 0 : 1;
  // Set lookup keeps the encoding linear instead of scanning the member's list per interest.
  const ownedIds = new Set(memberInterests.map(item => item.id));
  const flags = interests.map(interest => (ownedIds.has(interest.id) ? 1 : 0));
  return [age, sex, ...flags];
}
// NOTE(review): this appears to be a fragment of an async method; `this`, `groupId` and the
// two services come from an enclosing scope that is not visible here.
// Load every known interest (it fixes the interest-flag layout of each vector) and the group.
const interests = await this.interestsService.getAll(); | |
const group = await this.groupsService.getGroup(groupId); | |
// One feature vector per member, in group.members order, so index k in the similarity
// matrix (both row and column) refers to group.members[k].
const vectors = group.members | |
.map(item => memberToVector(interests, item)); | |
// Pair members by maximising total cosine similarity. `intersection: true` nulls the
// diagonal so a member is steered away from being paired with themselves.
// NOTE(review): despite its name, `weights` holds hungarian()'s { x, y } index pairs
// (x = column index, y = row index), not similarity values.
// NOTE(review): check how hungarian() behaves when the group has zero or one member.
const weights = hungarian(cosineMeasure(vectors), { | |
mode: HungarianMode.MAXIMUM, | |
intersection: true, | |
}); | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/** A flat list of numbers: one matrix row, or one of the solver's per-row/per-column work arrays. */
export type Vector = number[];
/**
 * A grid of numbers indexed as `matrix[row][column]`.
 * hungarian() takes the width from the first row, so rows are expected to be equally long.
 */
export type Matrix = Vector[];
/** Optimisation direction for hungarian(). */
export enum HungarianMode {
  /** Choose the assignment with the largest total. */
  MAXIMUM = 0,
  /** Choose the assignment with the smallest total (also what happens when no mode is given). */
  MINIMUM = 1,
}

/** Optional switches accepted by hungarian(). */
export type HungarianOptions = {
  /** Optimisation direction; any value other than `HungarianMode.MAXIMUM` minimises. */
  mode?: HungarianMode;
  /** When exactly `true`, the main diagonal (row i, column i) is nulled out before solving. */
  intersection?: boolean;
};
/**
 * Returns the element at index 0 of `arr`.
 * For an empty array this yields `undefined`, even though the declared return type is `T`.
 */
function first<T>(arr: T[]): T {
  const [head] = arr;
  return head;
}
/**
 * Converts a profit matrix into a cost matrix, so a minimising solver finds the maximum assignment.
 *
 * Each cell becomes `rowMax - cell`, where `rowMax` is `Math.max` over that row. `null` cells
 * are left as `null`. Subtracting from the row maximum turns the largest profit into the
 * smallest cost and keeps costs non-negative. Because every row contributes exactly one cell
 * to an assignment, the per-row offset does not change which assignment is optimal.
 *
 * Each row is transformed over its own length. A new matrix is returned and the input is
 * left untouched; previously the caller's matrix was overwritten in place.
 */
function prepare(matrix: Matrix): Matrix {
  return matrix.map(row => {
    const rowMax = Math.max(...row);
    return row.map(cell => (cell === null ? cell : rowMax - cell));
  });
}
/**
 * Returns a copy of `matrix` in which every main-diagonal cell (row index equal to column
 * index) is replaced by `null`. All other cells are copied unchanged and the input is not modified.
 * hungarian() treats `null` cells as unusable when relaxing column slacks.
 */
function removeIntersection(matrix: Matrix): Array<Array<number | null>> {
  const masked: Array<Array<number | null>> = [];
  for (let row = 0; row < matrix.length; row++) {
    masked.push(matrix[row].map((cell, column) => (row !== column ? cell : null)));
  }
  return masked;
}
/**
 * Solves the assignment problem with the Hungarian (Kuhn–Munkres) algorithm using row/column
 * potentials, in O(n²·m) time for n rows and m columns.
 *
 * Every row is paired with a distinct column so that the total of the chosen cells is minimal.
 * With `mode: HungarianMode.MAXIMUM` the total is maximal instead, via `prepare`.
 * With `intersection: true` the main diagonal is nulled via `removeIntersection`. `null`
 * cells are never used to relax a column's slack, which steers row i away from column i.
 *
 * @param matrix - cost (or profit, in MAXIMUM mode) matrix indexed `[row][column]`. It is
 *   expected to be rectangular with no more rows than columns.
 * @param options - optimisation direction and diagonal masking.
 * @returns one `{ x, y }` pair per row, where `x` is the column index and `y` the row index,
 *   ordered by column. An empty matrix yields an empty list.
 * @throws Error when the matrix has more rows than columns, since no complete assignment exists.
 */
export function hungarian(matrix: Matrix, options: HungarianOptions = {}): Array<{ x: number, y: number }> {
  const height = matrix.length;
  if (height === 0) {
    return [];
  }
  const width = first(matrix).length;
  if (height > width) {
    throw new Error(
      `hungarian: expected no more rows than columns, got ${height} rows and ${width} columns`,
    );
  }

  const prepared = options.mode === HungarianMode.MAXIMUM ? prepare(matrix) : matrix;
  // `null` marks a forbidden cell; typed explicitly because removeIntersection introduces nulls.
  const cost: Array<Array<number | null>> =
    options.intersection === true ? removeIntersection(prepared) : prepared;

  const u: Vector = new Array(height).fill(0); // row potentials
  const v: Vector = new Array(width).fill(0); // column potentials
  // markIndices[column] = row currently assigned to that column, or -1 while it is free.
  const markIndices: Vector = new Array(width).fill(-1);

  for (let i = 0; i < height; i++) {
    // links[column] = previous column on the alternating path (-1 = the path starts at row i).
    const links: Vector = new Array(width).fill(-1);
    // mins[column] = smallest reduced cost reaching that column so far (its slack).
    const mins: Vector = new Array(width).fill(Number.MAX_SAFE_INTEGER);
    const visited: boolean[] = new Array(width).fill(false);
    let markedI = i; // row whose edges are relaxed in this step
    let markedJ = -1; // column through which markedI was reached (-1 = virtual start column)
    let j = -1;

    // Grow the alternating tree until it reaches a free column.
    while (markedI !== -1) {
      j = -1;
      for (let j1 = 0; j1 < width; j1++) {
        if (visited[j1]) {
          continue;
        }
        const cell = cost[markedI][j1];
        if (cell !== null) {
          const reduced = cell - u[markedI] - v[j1];
          if (reduced < mins[j1]) {
            mins[j1] = reduced;
            links[j1] = markedJ;
          }
        }
        if (j === -1 || mins[j1] < mins[j]) {
          j = j1;
        }
      }

      // Shift potentials by the smallest slack so column j becomes tight.
      const delta = mins[j];
      for (let j1 = 0; j1 < width; j1++) {
        if (visited[j1]) {
          u[markIndices[j1]] += delta;
          v[j1] -= delta;
        } else {
          mins[j1] -= delta;
        }
      }
      u[i] += delta; // the virtual start column is always in the tree and matched to row i

      visited[j] = true;
      markedJ = j;
      markedI = markIndices[j];
    }

    // Augment: walk back along the path, shifting each column's row to the next column.
    for (; links[j] !== -1; j = links[j]) {
      markIndices[j] = markIndices[links[j]];
    }
    markIndices[j] = i;
  }

  const result: Array<{ x: number, y: number }> = [];
  for (let column = 0; column < width; column++) {
    if (markIndices[column] !== -1) {
      result.push({ x: column, y: markIndices[column] });
    }
  }
  return result;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment