Skip to content

Instantly share code, notes, and snippets.

@St-Maxwell
Created October 19, 2021 05:21
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save St-Maxwell/0a936b03ecf99e284a05d10dd994516e to your computer and use it in GitHub Desktop.
Save St-Maxwell/0a936b03ecf99e284a05d10dd994516e to your computer and use it in GitHub Desktop.
!> reference: https://github.com/KT19/automatic_differentiation
!> Minimal reverse-mode automatic differentiation on scalar real64 values.
!>
!> Every operator (+, *, exp) allocates a new `node` that stores its forward
!> value and, for each operand, a pointer to that operand together with the
!> local partial derivative d(result)/d(operand). Calling `backward()` on a
!> node seeds its gradient with 1 and pushes gradients to every node reachable
!> through `parents`, using the chain rule.
!>
!> NOTE(review): nodes are allocated by `ad_node` and by every operator and
!> are never deallocated by this module, so each expression leaks its nodes.
module backward
  use iso_fortran_env, only: real64
  implicit none

  !> A value in the computation graph.
  type :: node
    private
    !> Forward value of this node.
    real(kind=real64) :: val
    !> Gradient accumulated during backward propagation
    !> (d(head)/d(this), where head is the node backward() was called on).
    real(kind=real64) :: grad = 0._real64
    !> Operands this node was computed from; null for leaves made by ad_node.
    !> Default-initialized so that associated() is well defined on leaves.
    type(pair), dimension(:), pointer :: parents => null()
  contains
    procedure, pass :: get_val => node_get_v
    procedure, pass :: get_grad => node_get_g
    procedure, pass :: backward_head
    procedure, pass :: backward_next
    !> backward()     : start propagation from this node.
    !> backward(out)  : internal step receiving an upstream gradient `out`.
    generic :: backward => backward_head, backward_next
  end type node

  !> Edge from a result node to one operand, with the local partial
  !> derivative of the result with respect to that operand.
  type :: pair
    type(node), pointer :: node_ => null()
    real(kind=real64) :: local_grad
  end type pair

  !>============================================================================
  !> node + node
  interface operator(+)
    module procedure :: node_add
  end interface operator(+)

  !> real64 * node, node * node
  interface operator(*)
    module procedure :: num_mul_node
    module procedure :: node_mul_node
  end interface operator(*)

  !> Extends the intrinsic exp to nodes.
  interface exp
    module procedure :: exp_node
  end interface exp

contains

  !> Create a leaf node holding `val` (grad = 0, no parents).
  !> Returns a pointer to a freshly allocated node.
  function ad_node(val) result(new_node)
    real(kind=real64), intent(in) :: val
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = val
  end function ad_node

  !> Forward value of the node.
  function node_get_v(this) result(v)
    class(node), intent(in) :: this
    real(kind=real64) :: v

    v = this%val
  end function node_get_v

  !> Gradient accumulated in the node by backward().
  function node_get_g(this) result(g)
    class(node), intent(in) :: this
    real(kind=real64) :: g

    g = this%grad
  end function node_get_g

  !> result = node1 + node2; d(result)/d(node1) = d(result)/d(node2) = 1.
  !> The operands are captured by pointer. NOTE(review): actual arguments
  !> should be pointer targets that outlive the result (e.g. nodes from
  !> ad_node or another operator); a non-TARGET actual leaves the stored
  !> pointer undefined after return.
  function node_add(node1, node2) result(new_node)
    type(node), intent(in), target :: node1
    type(node), intent(in), target :: node2
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = node1%val + node2%val
    allocate(new_node%parents(2))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = 1._real64
    new_node%parents(2)%node_ => node2
    new_node%parents(2)%local_grad = 1._real64
  end function node_add

  !> result = num * node1; d(result)/d(node1) = num.
  !> The scalar `num` is a constant and receives no gradient.
  function num_mul_node(num, node1) result(new_node)
    real(kind=real64), intent(in) :: num
    type(node), intent(in), target :: node1
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = num * node1%val
    allocate(new_node%parents(1))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = num
  end function num_mul_node

  !> result = node1 * node2; d/d(node1) = node2, d/d(node2) = node1
  !> (product rule, evaluated at the forward values).
  function node_mul_node(node1, node2) result(new_node)
    type(node), intent(in), target :: node1
    type(node), intent(in), target :: node2
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = node1%val * node2%val
    allocate(new_node%parents(2))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = node2%val
    new_node%parents(2)%node_ => node2
    new_node%parents(2)%local_grad = node1%val
  end function node_mul_node

  !> result = exp(node1); d(result)/d(node1) = exp(node1) = result value.
  function exp_node(node1) result(new_node)
    type(node), intent(in), target :: node1
    type(node), pointer :: new_node

    allocate(new_node)
    ! Real64 argument: resolves to the intrinsic exp, not this generic.
    new_node%val = exp(node1%val)
    allocate(new_node%parents(1))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = new_node%val
  end function exp_node

  !> Start reverse propagation from this node: set d(this)/d(this) = 1 and
  !> push each parent's local derivative (times 1) to that parent.
  !> NOTE: gradients of the other reachable nodes are accumulated, not reset,
  !> so calling backward() repeatedly adds to them.
  subroutine backward_head(this)
    class(node), intent(inout) :: this
    integer :: i

    this%grad = 1._real64
    if (associated(this%parents)) then
      do i = 1, size(this%parents)
        call this%parents(i)%node_%backward(this%parents(i)%local_grad)
      end do
    end if
  end subroutine backward_head

  !> Receive the upstream gradient `out` = d(head)/d(this) along one path.
  !> Add it to this node, then forward out * local_grad to each parent.
  !> Each path from the head is followed separately (no topological order and
  !> no visited marking), so a node reached by several paths is visited once
  !> per path. Its contributions are summed into grad, which gives the correct
  !> total derivative.
  recursive subroutine backward_next(this, out)
    class(node), intent(inout) :: this
    real(kind=real64), intent(in) :: out
    integer :: i

    ! Only the incoming gradient belongs to this node. The products passed to
    ! the parents are their gradients, not ours.
    this%grad = this%grad + out
    if (associated(this%parents)) then
      do i = 1, size(this%parents)
        call this%parents(i)%node_%backward(out * this%parents(i)%local_grad)
      end do
    end if
  end subroutine backward_next

end module backward
!> Demo: build y = (a + b*b) * exp(c) at a = 2, b = 1, c = 0, print y, then
!> run the reverse pass and print the gradients stored in a, b and c.
program main
  use backward, only: node, ad_node, real64, operator(+), operator(*), exp
  implicit none
  type(node), pointer :: a, b, c, y

  ! Leaf inputs of the expression.
  a => ad_node(2.0_real64)
  b => ad_node(1.0_real64)
  c => ad_node(0.0_real64)

  ! Forward pass: each operator returns a newly allocated graph node.
  y => (a + b * b) * exp(c)
  print *, "y = ", y%get_val()

  ! Reverse pass: propagate gradients from y back to the leaves.
  call y%backward()
  print *, "y/a = ", a%get_grad()
  print *, "y/b = ", b%get_grad()
  print *, "y/c = ", c%get_grad()
end program main
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment