Skip to content

Instantly share code, notes, and snippets.

@cyphunk
Forked from vladimir-ivanov/softmax.js
Last active April 6, 2024 11:21
Show Gist options
  • Save cyphunk/6c255fa05dd30e69f438a930faeb53fe to your computer and use it in GitHub Desktop.
Save cyphunk/6c255fa05dd30e69f438a930faeb53fe to your computer and use it in GitHub Desktop.
softmax function implementation in js
// Fork & examples for the one-line version by @vladimir-ivanov:
//let softmax = (arr) => (index) => Math.exp(arr[index]) / arr.map(y => Math.exp(y)).reduce((a, b) => a + b);
//
// Also see comments for improvements
function softmax(arr) {
return arr.map(function(value,index) {
return Math.exp(value) / arr.map( function(y /*value*/){ return Math.exp(y) } ).reduce( function(a,b){ return a+b })
})
}
example1=[ 0.9780449271202087,
0.01590355671942234,
0.0019390975357964635,
0.0015482910675927997,
0.0012942816829308867,
0.0006004497990943491,
0.0004827099328394979,
0.0001868270628619939 ]
softmax1=softmax(example1)
example2= [
{ prob: 0.32289665937423706, cat: '25_32' },
{ prob: 0.15404804050922394, cat: '38_43' },
{ prob: 0.03673655539751053, cat: '4_6' },
{ prob: 0.01545996405184269, cat: '48_53' },
{ prob: 0.011709162034094334, cat: '15_20' },
{ prob: 0.008010754361748695, cat: '8_13' },
{ last: true, prob: 0.0054732030257582664, cat: '60+' } ].map(function(v){return v.prob})
softmax2=softmax(example2)
example3=[ { prob: 0.125, cat: '25_32' },
{ prob: 0.125, cat: '38_43' },
{ prob: 0.125, cat: '15_20' },
{ prob: 0.125, cat: '8_13' },
{ prob: 0.125, cat: '4_6' },
{ prob: 0.125, cat: '48_53' },
{ prob: 0.125, cat: '60+' },
{ prob: 0.125, cat: '0_2' } ].map(function(v){return v.prob})
softmax3=softmax(example3)
@doug-ross
Copy link

Nice job. I was looking for a couple of practical examples and this helps!

@cyphunk
Copy link
Author

cyphunk commented Mar 26, 2017

elated to know it helped. thanks

@LeonardoCiaccio
Copy link

Nice job.

@enobufs
Copy link

enobufs commented Jul 27, 2018

I see two problems with the code.
(1) It is slow (repeating calculating the denominator for the same result)
(2) Math.exp(v) could return Infinity (when v is large), softmax would result in NaN

Here's a solution:

function softmax(arr) {
    const C = Math.max(...arr);
    const d = arr.map((y) => Math.exp(y - C)).reduce((a, b) => a + b);
    return arr.map((value, index) => { 
        return Math.exp(value - C) / d;
    })
}

@freud14
Copy link

freud14 commented Mar 1, 2021

Late to the party but here is my solution inspired by @enobufs's but without repeating the exponentiation.

function softmax(logits) {
    const maxLogit = Math.max(...logits);
    const scores = logits.map(l => Math.exp(l - maxLogit));
    const denom = scores.reduce((a, b) => a + b);
    return scores.map(s => s / denom);
}

@angrypie
Copy link

angrypie commented Feb 4, 2023

More efficient version.

  1. Avoid using Math.max(...arr) which will cause RangeError: Maximum call stack size exceeded
  2. No need in additional loop to calculate denominator.
// data could be both array or object
function softmax(data, from = 0, to = data.length) {
  let max = -Infinity; // Math.max(...data) vould crash on large array
  for (let id = from; id < to; id++) {
    if (max < data[id]) {
      max = data[id];
    }
  }
  // No need to use reduce, just sum the exps in the loop
  let sumOfExp = 0;
  const result = Array.isArray(data) ? [] : {};
  for (let id = from; id < to; id++) {
    result[id] = Math.exp(data[id] - max);
    sumOfExp += result[id];
  }
  // Finally divide by the sum of exps
  for (let id = from; id < to; id++) {
    result[id] = result[id] / sumOfExp;
  }

  return result;
}

If you still want to use prettier version you could just replace Math.max(...arr) with reduce.

function softmax(logits) {
  //find max in logits using reduce
  const maxLogit = logits.reduce((a, b) => Math.max(a, b), -Infinity);
  const scores = logits.map((l) => Math.exp(l - maxLogit));
  const denom = scores.reduce((a, b) => a + b);
  return scores.map((s) => s / denom);
}

@Dammmien
Copy link

Dammmien commented Apr 5, 2024

By any chance, does anyone have an example of the softmax derivative ? Or is able to tell me if this one is correct:

function derivative(arr) {
  const values = softmax(arr);

  return arr.map((x, i) => {
    return values[i] * (values[i] - (i === 0 ? 1 : 0));
  });
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment