@corba777
Created October 30, 2019 06:16
Differentiable argmax
#!/usr/bin/env julia
# In[1]:
using Pkg
# In[ ]:
Pkg.add("ForwardDiff")
# In[3]:
using ForwardDiff
# In[4]:
# sin(pi/x), guarded at x == 0: pi/0 is Inf, and sin(Inf) throws a DomainError.
function example(x)
    if x == 0
        return x
    end
    return sin(pi / x)
end
# In[5]:
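# Returns the element of x (not its index) at which f is largest.
# Note: this shadows Base's exported argmax inside Main.
function argmax(f, x)
    best = x[1]            # current maximizer, an element of x
    bestval = f(best)      # its score, cached so f is evaluated once per element
    for el in x[2:end]
        v = f(el)
        if v > bestval     # compare scores, but keep the element itself
            best, bestval = el, v
        end
    end
    return best
end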
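# In[ ]:
# Cross-check against Base. This assumes Julia >= 1.7, where
# Base.argmax(f, domain) returns the element of `domain` that maximizes f
# (the behavior the hand-written argmax is meant to have), and
# findmax(f, domain) returns (largest value of f, index of the maximizer).
# `example_check` is a hypothetical copy of `example` so this cell runs alone.
example_check(x) = x == 0 ? x : sin(pi / x)
xs = [1.0, 2.0, 3.0, 4.0]
best = Base.argmax(example_check, xs)   # maximizing element, not its index
fval, idx = findmax(example_check, xs)  # score of the maximizer and its index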
# In[6]:
# The largest score is sin(pi/2) = 1, so the maximizer is 2.0.
argmax(example, [1.0, 2.0, 3.0, 4.0])
# In[7]:
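# The result is an element of the input, so the gradient is one-hot at the
# selected position; for this input that is [0.0, 1.0, 0.0, 0.0].
ForwardDiff.gradient(x -> argmax(example, x), [1.0, 2.0, 3.0, 4.0])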
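# In[ ]:
# Why the ForwardDiff gradient above is one-hot: a minimal forward-mode sketch.
# Assumptions: `SketchDual`, `example_sketch`, `argmax_sketch` and
# `dual_gradient` are hypothetical names for this illustration only. They
# mimic the idea behind forward-mode AD (a value plus a derivative part) and
# are not ForwardDiff's API. Branches are decided on primal values only, so
# argmax returns one of its inputs unchanged, derivative part included.

struct SketchDual
    val::Float64   # primal value
    der::Float64   # derivative with respect to the seeded input
end

# Comparisons ignore the derivative part (Base's `>` falls back to `<`).
Base.:<(a::SketchDual, b::SketchDual) = a.val < b.val
Base.:(==)(a::SketchDual, c::Real) = a.val == c

# Chain rule for the two operations `example` uses: c / x and sin(x).
Base.:/(c::Real, d::SketchDual) = SketchDual(c / d.val, -c * d.der / d.val^2)
Base.sin(d::SketchDual) = SketchDual(sin(d.val), cos(d.val) * d.der)

# Local copies of the notebook's `example` and `argmax`, so this cell runs alone.
example_sketch(x) = x == 0 ? x : sin(pi / x)

function argmax_sketch(f, x)
    best = x[1]
    for el in x[2:end]
        if f(el) > f(best)
            best = el
        end
    end
    return best
end

# Seed one input at a time with derivative 1 (others 0) and read off
# d(output)/d(x[k]) from the derivative part of the result.
function dual_gradient(g, x::Vector{Float64})
    n = length(x)
    grad = zeros(n)
    for k in 1:n
        xd = [SketchDual(x[i], i == k ? 1.0 : 0.0) for i in 1:n]
        grad[k] = g(xd).der
    end
    return grad
end

grad_sketch = dual_gradient(x -> argmax_sketch(example_sketch, x), [1.0, 2.0, 3.0, 4.0])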
# In[ ]:
Pkg.add("ReverseDiff")
# In[9]:
using ReverseDiff
# In[10]:
# Reverse mode should give the same one-hot gradient as ForwardDiff.
ReverseDiff.gradient(x -> argmax(example, x), [1.0, 2.0, 3.0, 4.0])