Created October 19, 2021 05:21
-
-
Save St-Maxwell/0a936b03ecf99e284a05d10dd994516e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
!> Minimal reverse-mode automatic differentiation on scalar nodes.
!> Reference: https://github.com/KT19/automatic_differentiation
!>
!> NOTE(review): every operator allocates a new node that is never
!> deallocated, so a graph leaks its intermediate nodes.
!> NOTE(review): each parent pointer stored in a node points at the actual
!> argument of the operator call. Per the standard this stays valid only if
!> that actual is a target that outlives the graph, e.g. a node returned by
!> ad_node or by one of these operators and held in a pointer variable.
module backward
  use iso_fortran_env, only: real64
  implicit none

  !> One scalar in the computation graph: its value, its accumulated
  !> gradient, and the operands of the operation that produced it.
  type :: node
    private
    real(kind=real64) :: val
    real(kind=real64) :: grad = 0._real64
    !> Operands of the producing operation; stays null for leaves created by
    !> ad_node. Default-initialised so that associated() is well-defined on
    !> every freshly allocated node.
    type(pair), dimension(:), pointer :: parents => null()
  contains
    procedure, pass :: get_val => node_get_v
    procedure, pass :: get_grad => node_get_g
    procedure, pass :: backward_head
    procedure, pass :: backward_next
    !> backward()    : start the reverse pass at this node.
    !> backward(out) : receive one upstream contribution and propagate it.
    generic :: backward => backward_head, backward_next
  end type node

  !> Edge from a result node to one operand, with d(result)/d(operand).
  type :: pair
    type(node), pointer :: node_ => null()
    real(kind=real64) :: local_grad
  end type pair

  !>============================================================================
  interface operator(+)
    module procedure :: node_add
  end interface

  interface operator(*)
    module procedure :: num_mul_node
    module procedure :: node_mul_node
  end interface

  !> Extends the intrinsic exp to nodes.
  interface exp
    module procedure :: exp_node
  end interface

contains

  !> Allocate a new leaf node holding `val`, with no parents and zero gradient.
  function ad_node(val) result(new_node)
    real(kind=real64), intent(in) :: val
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = val
  end function ad_node

  !> Value stored in the node.
  function node_get_v(this) result(v)
    class(node), intent(in) :: this
    real(kind=real64) :: v

    v = this%val
  end function node_get_v

  !> Gradient accumulated in the node by the last backward pass(es).
  function node_get_g(this) result(g)
    class(node), intent(in) :: this
    real(kind=real64) :: g

    g = this%grad
  end function node_get_g

  !> node1 + node2. Both local derivatives are 1.
  function node_add(node1, node2) result(new_node)
    type(node), intent(in), target :: node1
    type(node), intent(in), target :: node2
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = node1%val + node2%val
    allocate(new_node%parents(2))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = 1._real64
    new_node%parents(2)%node_ => node2
    new_node%parents(2)%local_grad = 1._real64
  end function node_add

  !> num * node1. The local derivative with respect to node1 is num.
  function num_mul_node(num, node1) result(new_node)
    real(kind=real64), intent(in) :: num
    type(node), intent(in), target :: node1
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = num * node1%val
    allocate(new_node%parents(1))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = num
  end function num_mul_node

  !> node1 * node2. Each local derivative is the other operand's value.
  function node_mul_node(node1, node2) result(new_node)
    type(node), intent(in), target :: node1
    type(node), intent(in), target :: node2
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = node1%val * node2%val
    allocate(new_node%parents(2))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = node2%val
    new_node%parents(2)%node_ => node2
    new_node%parents(2)%local_grad = node1%val
  end function node_mul_node

  !> exp(node1). The local derivative is exp(node1%val), i.e. the result value.
  function exp_node(node1) result(new_node)
    type(node), intent(in), target :: node1
    type(node), pointer :: new_node

    allocate(new_node)
    new_node%val = exp(node1%val)
    allocate(new_node%parents(1))
    new_node%parents(1)%node_ => node1
    new_node%parents(1)%local_grad = new_node%val
  end function exp_node

  !> Start the reverse pass at this node. Sets d(this)/d(this) = 1 and pushes
  !> each parent's local derivative into that parent.
  subroutine backward_head(this)
    class(node), intent(inout) :: this
    integer :: i

    this%grad = 1._real64
    if (associated(this%parents)) then
      do i = 1, size(this%parents)
        call this%parents(i)%node_%backward(this%parents(i)%local_grad)
      end do
    end if
  end subroutine backward_head

  !> Receive one upstream gradient contribution `out` (the derivative of the
  !> root along one path into this node). Add it to this node's gradient and
  !> forward out * local_grad to every parent (chain rule).
  !> Declared recursive because it re-enters itself through the generic.
  !> NOTE(review): every path is walked separately, so heavily shared
  !> subgraphs are revisited once per path.
  recursive subroutine backward_next(this, out)
    class(node), intent(inout) :: this
    real(kind=real64), intent(in) :: out
    integer :: i

    ! Only the upstream contribution belongs to this node's gradient. The
    ! scaled values sent to the parents are *their* contributions, so they
    ! must not be added here as well.
    this%grad = this%grad + out
    if (associated(this%parents)) then
      do i = 1, size(this%parents)
        call this%parents(i)%node_%backward(out * this%parents(i)%local_grad)
      end do
    end if
  end subroutine backward_next

end module backward
!> Demo: evaluate y = (a + b*b) * exp(c) at a = 2, b = 1, c = 0 and print
!> y together with its derivatives with respect to the leaves.
!> Expected: y = 3, dy/da = 1, dy/db = 2, dy/dc = 3.
program main
  use iso_fortran_env, only: real64
  use backward, only: node, ad_node, operator(+), operator(*), exp
  implicit none
  type(node), pointer :: a, b, c, y
  type(node), pointer :: b_sq, lhs, exp_c

  a => ad_node(2._real64)
  b => ad_node(1._real64)
  c => ad_node(0._real64)

  ! Build the graph one operation at a time. Every operand handed to an
  ! operator is then the target of a named pointer. The graph keeps pointers
  ! to its operands, so they must remain valid targets, not temporaries.
  b_sq => b * b
  lhs => a + b_sq
  exp_c => exp(c)
  y => lhs * exp_c

  write(*,*) "y = ", y%get_val()
  call y%backward()
  write(*,*) "dy/da = ", a%get_grad()
  write(*,*) "dy/db = ", b%get_grad()
  write(*,*) "dy/dc = ", c%get_grad()
end program main
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment