@stober
Created March 1, 2012 03:05
Softmax in Python
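For reference, the script below computes the following for a score vector w and a temperature t > 0. Every output is positive and the outputs sum to 1:

```latex
\operatorname{softmax}(w; t)_i = \frac{e^{w_i / t}}{\sum_{j} e^{w_j / t}}
```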
#! /usr/bin/env python
"""
Author: Jeremy M. Stober
Program: SOFTMAX.PY
Date: Wednesday, February 29 2012
Description: Simple softmax function.
"""
import numpy as np

npa = np.array


def softmax(w, t=1.0):
    """Return the softmax of w at temperature t (t > 0)."""
    e = np.exp(npa(w) / t)  # exponentiate the temperature-scaled scores
    dist = e / np.sum(e)    # normalize so the entries sum to 1
    return dist


if __name__ == '__main__':
    w = np.array([0.1, 0.2])
    print(softmax(w))
    w = np.array([-0.1, 0.2])
    print(softmax(w))
    w = np.array([0.9, -10])
    print(softmax(w))
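The `t` argument works as a temperature. Dividing the scores by a small `t` widens the gaps between them, so the distribution sharpens toward the largest score. A large `t` shrinks the gaps and pushes the result toward uniform. The sketch below repeats the gist's computation so that it runs on its own and sweeps a few temperatures. The chosen `t` values are arbitrary illustrations.

```python
import numpy as np


def softmax(w, t=1.0):
    """Softmax of w at temperature t (same computation as the gist)."""
    e = np.exp(np.asarray(w, dtype=float) / t)
    return e / np.sum(e)


w = np.array([0.1, 0.2])
for t in (0.1, 1.0, 10.0):
    # Lower t exaggerates differences between scores; higher t flattens them.
    print(t, softmax(w, t))
```

For `w = [0.1, 0.2]`, the larger entry receives about 0.73 at `t = 0.1` and about 0.50 at `t = 10`.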
@kigawas
kigawas commented Jan 30, 2017

Vote for @piyushbhardwaj

A clearer version with doctest:

import numpy as np


def softmax(x):
    '''Row-wise softmax of a 1-D or 2-D array (1-D input comes back with shape (1, n)).

    >>> res = softmax(np.array([0, 200, 10]))
    >>> float(np.sum(res))
    1.0
    >>> bool(np.all(np.abs(res - np.array([0, 1, 0])) < 0.0001))
    True
    >>> res = softmax(np.array([[0, 200, 10], [0, 10, 200], [200, 0, 10]]))
    >>> np.sum(res, axis=1)
    array([1., 1., 1.])
    >>> res = softmax(np.array([[0, 200, 10], [0, 10, 200]]))
    >>> np.sum(res, axis=1)
    array([1., 1.])
    '''
    if x.ndim == 1:
        x = x.reshape((1, -1))
    # Subtract each row's maximum so np.exp cannot overflow.
    max_x = np.max(x, axis=1).reshape((-1, 1))
    exp_x = np.exp(x - max_x)
    return exp_x / np.sum(exp_x, axis=1).reshape((-1, 1))
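Subtracting the row maximum (`max_x` above) leaves the result unchanged. For any constant c, e^(x_i - c) / Σ_j e^(x_j - c) = e^(-c) e^(x_i) / (e^(-c) Σ_j e^(x_j)), and the factor e^(-c) cancels. The shift does change the range: the largest exponent becomes e^0 = 1, so nothing can overflow. The comparison below is a minimal sketch of this effect. `naive_softmax` and `stable_softmax` are illustrative names that do not appear in the thread.

```python
import numpy as np


def naive_softmax(x):
    # Exponentiate the raw scores, as the original gist does.
    e = np.exp(x)
    return e / np.sum(e)


def stable_softmax(x):
    # Shift by the maximum first; mathematically the result is identical.
    e = np.exp(x - np.max(x))
    return e / np.sum(e)


x = np.array([1000.0, 1001.0])
with np.errstate(over="ignore", invalid="ignore"):
    # exp(1000) overflows float64 to inf, and inf / inf gives nan.
    print(naive_softmax(x))
print(stable_softmax(x))
```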

@AFAgarap
AFAgarap commented Sep 6, 2017

labels = [0, 0, 0, 0, 0.68, 0.32, 0, 0, 0, 0]

%timeit softmax = np.exp([element for element in labels]) / np.sum(np.exp([element for element in labels]))
The slowest run took 5.03 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 12.2 µs per loop
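The timed expression builds the list copy twice and calls `np.exp` twice. The sketch below times that form against one that converts the input once and reuses the exponentials. It uses the standard-library `timeit` module instead of the IPython `%timeit` magic so that it runs as a plain script. Both function names are hypothetical. No timings are claimed here, because the numbers depend on the machine and the NumPy version.

```python
import timeit

import numpy as np

labels = [0, 0, 0, 0, 0.68, 0.32, 0, 0, 0, 0]


def softmax_twice(values):
    # Mirrors the timed expression: two list copies and two np.exp calls.
    return np.exp([v for v in values]) / np.sum(np.exp([v for v in values]))


def softmax_once(values):
    # Convert once and reuse the exponentials for numerator and denominator.
    e = np.exp(np.asarray(values, dtype=float))
    return e / e.sum()


n = 10_000
for fn in (softmax_twice, softmax_once):
    seconds = timeit.timeit(lambda: fn(labels), number=n)
    print(f"{fn.__name__}: {seconds / n * 1e6:.2f} us per call")
```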

@rajans
rajans commented Mar 21, 2018

It can be a simple one-liner:

import numpy as np
def softmax(x):
    # Sums over axis 0 (the whole vector when x is 1-D).
    return np.exp(x) / np.sum(np.exp(x), axis=0)
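The one-liner divides by a sum over `axis=0`. That is correct for a 1-D vector. For a 2-D array whose rows are separate score vectors, it normalizes each column instead. Like the original gist, it can also overflow on large inputs. The sketch below makes the axis a parameter and subtracts the maximum along that axis. Defaulting to the last axis (one score vector per row) is an assumption about the caller's layout.

```python
import numpy as np


def softmax(x, axis=-1):
    # Normalize along `axis` (the last axis by default, i.e. per row of a batch).
    x = np.asarray(x, dtype=float)
    # keepdims=True keeps the reduced axis so the result broadcasts back onto x.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


scores = np.array([[1.0, 2.0, 3.0],
                   [1.0, 1.0, 1.0]])
print(softmax(scores).sum(axis=1))          # row sums, each approximately 1
print(softmax(scores, axis=0).sum(axis=0))  # column sums, as the one-liner normalizes
```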
