Skip to content

Instantly share code, notes, and snippets.

@reikdas
Last active August 23, 2021 01:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save reikdas/ee0f2f3244afc8a6efba659232722b92 to your computer and use it in GitHub Desktop.
Save reikdas/ee0f2f3244afc8a6efba659232722b92 to your computer and use it in GitHub Desktop.
GSoC Report: INTEGRATE CUSTOM DERIVATIVES OF NUMERICAL COMPUTING ROUTINES LIKE BLAS INTO ENZYME

INTEGRATE CUSTOM DERIVATIVES OF NUMERICAL COMPUTING ROUTINES LIKE BLAS INTO ENZYME

Deliverables of the summer:

  • An LLVM pass that inlines definitions of functions into the LLVM IR source code.
  • Enzyme AdjointGenerator being able to use existing CBLAS functions to calculate the derivative of a CBLAS function.

BCPass

The BCPass is a non-Enzyme specific LLVM pass that currently resides in the Enzyme GitHub repository.

If the definition of a function isn't explicitly available in the LLVM IR, BCPass checks whether the definition is available in a .bc (bitcode) file that it knows about and, if so, tries to provide the definition of the function in the LLVM IR.

How is this relevant to Enzyme? Definitions of functions from libraries such as BLAS can be made available as bitcode files which can later be inlined.

Eg:

For a simple C file -

#include <cblas.h>

extern double __enzyme_autodiff(void *, double *, double *, double *,
                                 double *);

/* Function that Enzyme is asked to differentiate in this example: it returns
 * the square of the dot product of the 3-element arrays m and n. */
double g(double *m, double *n) {
  /* Dot product over 3 elements, stride 1 for both arrays. */
  double x = cblas_ddot(3, m, 1, n, 1);
  /* m is overwritten after cblas_ddot has already read it. */
  m[0] = 11.0;
  m[1] = 12.0;
  m[2] = 13.0;
  double y = x * x;
  return y;
}

/* Driver: requests the derivative of g through __enzyme_autodiff.
 * m1 and n1 are zero-initialised arrays passed right after m and n
 * (presumably the buffers that receive the gradients -- confirm against the
 * Enzyme calling convention). */
int main() {
  double m[3] = {1, 2, 3};
  double m1[3] = {0, 0, 0};
  double n[3] = {4, 5, 6};
  double n1[3] = {0, 0, 0};
  /* val is not used afterwards; the example only needs the call to exist. */
  double val = __enzyme_autodiff((void*)g, m, m1, n, n1);
  return 1;
}

The generated LLVM IR is -

target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@__const.main.m = private unnamed_addr constant [3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], align 16
@__const.main.n = private unnamed_addr constant [3 x double] [double 4.000000e+00, double 5.000000e+00, double 6.000000e+00], align 16

; Function Attrs: noinline nounwind optnone uwtable
; IR for g(): arguments and locals are kept in stack slots (optnone build).
define dso_local double @g(double* %m, double* %n) {
entry:
  %m.addr = alloca double*, align 8
  %n.addr = alloca double*, align 8
  %x = alloca double, align 8
  %y = alloca double, align 8
  store double* %m, double** %m.addr, align 8
  store double* %n, double** %n.addr, align 8
  ; x = cblas_ddot(3, m, 1, n, 1)
  %0 = load double*, double** %m.addr, align 8
  %1 = load double*, double** %n.addr, align 8
  %call = call double @cblas_ddot(i32 3, double* %0, i32 1, double* %1, i32 1)
  store double %call, double* %x, align 8
  ; m[0] = 11.0
  %2 = load double*, double** %m.addr, align 8
  %arrayidx = getelementptr inbounds double, double* %2, i64 0
  store double 1.100000e+01, double* %arrayidx, align 8
  ; m[1] = 12.0
  %3 = load double*, double** %m.addr, align 8
  %arrayidx1 = getelementptr inbounds double, double* %3, i64 1
  store double 1.200000e+01, double* %arrayidx1, align 8
  ; m[2] = 13.0
  %4 = load double*, double** %m.addr, align 8
  %arrayidx2 = getelementptr inbounds double, double* %4, i64 2
  store double 1.300000e+01, double* %arrayidx2, align 8
  ; y = x * x; return y
  %5 = load double, double* %x, align 8
  %6 = load double, double* %x, align 8
  %mul = fmul double %5, %6
  store double %mul, double* %y, align 8
  %7 = load double, double* %y, align 8
  ret double %7
}

declare dso_local double @cblas_ddot(i32, double*, i32, double*, i32)

; Function Attrs: noinline nounwind optnone uwtable
; IR for main(): fills m and n from constants, zero-fills m1 and n1, then
; calls __enzyme_autodiff on g.
define dso_local i32 @main() {
entry:
  %retval = alloca i32, align 4
  %m = alloca [3 x double], align 16
  %m1 = alloca [3 x double], align 16
  %n = alloca [3 x double], align 16
  %n1 = alloca [3 x double], align 16
  %val = alloca double, align 8
  store i32 0, i32* %retval, align 4
  ; m = {1, 2, 3}: 24 bytes copied from @__const.main.m
  %0 = bitcast [3 x double]* %m to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.main.m to i8*), i64 24, i1 false)
  ; m1 = {0, 0, 0}
  %1 = bitcast [3 x double]* %m1 to i8*
  call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 24, i1 false)
  ; n = {4, 5, 6}: 24 bytes copied from @__const.main.n
  %2 = bitcast [3 x double]* %n to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %2, i8* align 16 bitcast ([3 x double]* @__const.main.n to i8*), i64 24, i1 false)
  ; n1 = {0, 0, 0}
  %3 = bitcast [3 x double]* %n1 to i8*
  call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 24, i1 false)
  ; Decay the four arrays to double* and call __enzyme_autodiff(g, m, m1, n, n1).
  %arraydecay = getelementptr inbounds [3 x double], [3 x double]* %m, i32 0, i32 0
  %arraydecay1 = getelementptr inbounds [3 x double], [3 x double]* %m1, i32 0, i32 0
  %arraydecay2 = getelementptr inbounds [3 x double], [3 x double]* %n, i32 0, i32 0
  %arraydecay3 = getelementptr inbounds [3 x double], [3 x double]* %n1, i32 0, i32 0
  %call = call double @__enzyme_autodiff(i8* bitcast (double (double*, double*)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3)
  store double %call, double* %val, align 8
  ; main returns 1, as written in the C source.
  ret i32 1
}

; Function Attrs: argmemonly nounwind
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1)

; Function Attrs: argmemonly nounwind
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1)

declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*)

The definition of cblas_ddot isn't available in the LLVM IR, so Enzyme does not know how to differentiate it.

We invoke the BCPass like - clang filename.ll -Xclang -load -Xclang /path/to/BCPass-<vers>.so -mllvm -bcpath=/path/to/dir/with/bcfiles -S -emit-llvm -o -

The definition of cblas_ddot is now available in the LLVM IR -

; ModuleID = '/home/reikdas/Enzyme/enzyme/test/BCLoader/bcloader-ddot.ll'
source_filename = "/home/reikdas/Enzyme/enzyme/test/BCLoader/bcloader-ddot.ll"
target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

@__const.main.m = private unnamed_addr constant [3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], align 16
@__const.main.n = private unnamed_addr constant [3 x double] [double 4.000000e+00, double 5.000000e+00, double 6.000000e+00], align 16

; g() as emitted after BCPass: the body still calls @cblas_ddot; the pass adds
; the callee's definition to the module rather than changing this function.
define dso_local double @g(double* %m, double* %n) {
entry:
  %m.addr = alloca double*, align 8
  %n.addr = alloca double*, align 8
  %x = alloca double, align 8
  %y = alloca double, align 8
  store double* %m, double** %m.addr, align 8
  store double* %n, double** %n.addr, align 8
  ; x = cblas_ddot(3, m, 1, n, 1)
  %0 = load double*, double** %m.addr, align 8
  %1 = load double*, double** %n.addr, align 8
  %call = call double @cblas_ddot(i32 3, double* %0, i32 1, double* %1, i32 1)
  store double %call, double* %x, align 8
  ; m[0] = 11.0
  %2 = load double*, double** %m.addr, align 8
  %arrayidx = getelementptr inbounds double, double* %2, i64 0
  store double 1.100000e+01, double* %arrayidx, align 8
  ; m[1] = 12.0
  %3 = load double*, double** %m.addr, align 8
  %arrayidx1 = getelementptr inbounds double, double* %3, i64 1
  store double 1.200000e+01, double* %arrayidx1, align 8
  ; m[2] = 13.0
  %4 = load double*, double** %m.addr, align 8
  %arrayidx2 = getelementptr inbounds double, double* %4, i64 2
  store double 1.300000e+01, double* %arrayidx2, align 8
  ; y = x * x; return y
  %5 = load double, double* %x, align 8
  %6 = load double, double* %x, align 8
  %mul = fmul double %5, %6
  store double %mul, double* %y, align 8
  %7 = load double, double* %y, align 8
  ret double %7
}

; main() as emitted after BCPass: fills m and n from constants, zero-fills m1
; and n1, then calls __enzyme_autodiff on g.
define dso_local i32 @main() {
entry:
  %retval = alloca i32, align 4
  %m = alloca [3 x double], align 16
  %m1 = alloca [3 x double], align 16
  %n = alloca [3 x double], align 16
  %n1 = alloca [3 x double], align 16
  %val = alloca double, align 8
  store i32 0, i32* %retval, align 4
  ; m = {1, 2, 3}: 24 bytes copied from @__const.main.m
  %0 = bitcast [3 x double]* %m to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.main.m to i8*), i64 24, i1 false)
  ; m1 = {0, 0, 0}
  %1 = bitcast [3 x double]* %m1 to i8*
  call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 24, i1 false)
  ; n = {4, 5, 6}: 24 bytes copied from @__const.main.n
  %2 = bitcast [3 x double]* %n to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %2, i8* align 16 bitcast ([3 x double]* @__const.main.n to i8*), i64 24, i1 false)
  ; n1 = {0, 0, 0}
  %3 = bitcast [3 x double]* %n1 to i8*
  call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 24, i1 false)
  ; Decay the four arrays to double* and call __enzyme_autodiff(g, m, m1, n, n1).
  %arraydecay = getelementptr inbounds [3 x double], [3 x double]* %m, i32 0, i32 0
  %arraydecay1 = getelementptr inbounds [3 x double], [3 x double]* %m1, i32 0, i32 0
  %arraydecay2 = getelementptr inbounds [3 x double], [3 x double]* %n, i32 0, i32 0
  %arraydecay3 = getelementptr inbounds [3 x double], [3 x double]* %n1, i32 0, i32 0
  %call = call double @__enzyme_autodiff(i8* bitcast (double (double*, double*)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3)
  store double %call, double* %val, align 8
  ; main returns 1, as written in the C source.
  ret i32 1
}

; Function Attrs: argmemonly nounwind
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) #0

; Function Attrs: argmemonly nounwind
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) #0

declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*)

; !!!DEFINITION OF CBLAS_DDOT!!!
; Function Attrs: noinline nounwind optnone uwtable
; Definition of cblas_ddot supplied from bitcode: a plain loop that accumulates
; X[i] * Y[i] for i in [0, N) into %sum and returns it.
; NOTE(review): %__incX.addr and %__incY.addr are stored below but never
; loaded, so this body always indexes both arrays with stride 1.
define dso_local double @cblas_ddot(i32 %__N, double* %__X, i32 %__incX, double* %__Y, i32 %__incY) #1 {
entry:
  %__N.addr = alloca i32, align 4
  %__X.addr = alloca double*, align 8
  %__incX.addr = alloca i32, align 4
  %__Y.addr = alloca double*, align 8
  %__incY.addr = alloca i32, align 4
  %sum = alloca double, align 8
  %i = alloca i32, align 4
  store i32 %__N, i32* %__N.addr, align 4
  store double* %__X, double** %__X.addr, align 8
  store i32 %__incX, i32* %__incX.addr, align 4
  store double* %__Y, double** %__Y.addr, align 8
  store i32 %__incY, i32* %__incY.addr, align 4
  ; sum = 0.0; i = 0
  store double 0.000000e+00, double* %sum, align 8
  store i32 0, i32* %i, align 4
  br label %for.cond

; Loop test: continue while i < N (signed compare).
for.cond:                                         ; preds = %for.inc, %entry
  %0 = load i32, i32* %i, align 4
  %1 = load i32, i32* %__N.addr, align 4
  %cmp = icmp slt i32 %0, %1
  br i1 %cmp, label %for.body, label %for.end

; Loop body: sum = sum + X[i] * Y[i]
for.body:                                         ; preds = %for.cond
  %2 = load double, double* %sum, align 8
  %3 = load double*, double** %__X.addr, align 8
  %4 = load i32, i32* %i, align 4
  %idxprom = sext i32 %4 to i64
  %arrayidx = getelementptr inbounds double, double* %3, i64 %idxprom
  %5 = load double, double* %arrayidx, align 8
  %6 = load double*, double** %__Y.addr, align 8
  %7 = load i32, i32* %i, align 4
  %idxprom1 = sext i32 %7 to i64
  %arrayidx2 = getelementptr inbounds double, double* %6, i64 %idxprom1
  %8 = load double, double* %arrayidx2, align 8
  %mul = fmul double %5, %8
  %add = fadd double %2, %mul
  store double %add, double* %sum, align 8
  br label %for.inc

; Loop increment: i = i + 1
for.inc:                                          ; preds = %for.body
  %9 = load i32, i32* %i, align 4
  %inc = add nsw i32 %9, 1
  store i32 %inc, i32* %i, align 4
  br label %for.cond

; Loop exit: return the accumulated sum.
for.end:                                          ; preds = %for.cond
  %10 = load double, double* %sum, align 8
  ret double %10
}

attributes #0 = { argmemonly nounwind }
attributes #1 = { noinline nounwind optnone uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-frame-pointer-elim"="true" "no-frame-pointer-elim-non-leaf" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="false" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+fxsr,+mmx,+sse,+sse2,+x87" "unsafe-fp-math"="false" "use-soft-float"="false" }

!llvm.ident = !{!0}
!llvm.module.flags = !{!1}

!0 = !{!"clang version 8.0.1 "}
!1 = !{i32 1, !"wchar_size", i32 4}

Now, Enzyme can calculate the derivative of g since the definition of cblas_ddot is available in the LLVM IR source code.

It can currently handle only a limited number of functions, but it can easily be extended - to handle a new function, the name of the function must be added to the list of functions available in the source code of the pass and the definition of the function must be available as a bitcode file.

Merged Pull Request: EnzymeAD/Enzyme#220

Implementation of more functions WIP: https://github.com/wsmoses/Enzyme/tree/reikdas/more-blas-bc

Enzyme AdjointGenerator

Enzyme's AdjointGenerator is taught ("hardcoded") that a particular function's derivative can be calculated using calls to (preferably) other functions from the same library.

We can currently handle two (important) BLAS functions -

  1. cblas_ddot and cblas_sdot

    cblas_ddot and cblas_sdot are the same operation. ddot is for doubles while sdot is for floats. For the rest of the description, we shall just talk about cblas_ddot, but the same applies for cblas_sdot.

    cblas_ddot computes the dot product of two vectors i.e. it takes two arrays X and Y as input. When we calculate the derivative, we need to calculate the derivative of the function wrt X and Y.

    We can prove (TODO) that the derivative of a function that has a call to cblas_ddot can be calculated with two calls to cblas_daxpy - once for each input array. Since we are already using the CBLAS library, we can freely use functions from that library without incurring significant overhead.

    Simple example -

    For the C source code:

    #include <cblas.h>
    
    extern double __enzyme_autodiff(double*, double*, double*);
    
    /* Function differentiated by Enzyme in this example: returns the square
     * of the dot product of the caller-supplied array m with the local
     * constant array n. Only m comes from the caller. */
    double g(double *restrict m) {
      double n[3] = {4, 5, 6};
      /* Dot product over 3 elements, stride 1 for both arrays. */
      double x = cblas_ddot(3, m, 1, n, 1);
      /* m is modified after cblas_ddot has already read it. */
      m[0] = 10;
      double y = x*x;
      return y;
    }
    
    /* Driver: requests the derivative of g through __enzyme_autodiff.
     * m1 is a zero-initialised array passed right after m (presumably the
     * buffer that receives the gradient -- confirm against the Enzyme
     * calling convention). */
    int main() {
      double m[3] = {1, 2, 3};
      double m1[3] = {0.};
      /* g is passed cast to double* to match the extern prototype above. */
      double z = __enzyme_autodiff((double*)g, m, m1);
    }
    

    Important things to note about the above code -

    1. Of the two input arrays to ddot, only one (m) is an argument of the differentiated function g; the other (n) is a local constant array. This means that n is inactive, and we do not need to calculate the derivative of cblas_ddot wrt that array.
    2. The array that is "active" is modified after the call to cblas_ddot. This means that if we generate the code to calculate the derivative after that particular instruction, we would be using the modified array instead of the original input array to calculate the derivative. To handle this, we must cache the input values of the array and read them back while generating the code to calculate the derivative.

    From the C code, we get this LLVM IR -

    target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
    target triple = "x86_64-unknown-linux-gnu"
    
    @__const.g.n = private unnamed_addr constant [3 x double] [double 4.000000e+00, double 5.000000e+00, double 6.000000e+00], align 16
    @__const.main.m = private unnamed_addr constant [3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], align 16
    
    ; IR for g(m): n is a local array filled from @__const.g.n; only m comes
    ; from the caller.
    define dso_local double @g(double* noalias %m) {
    entry:
      %m.addr = alloca double*, align 8
      %n = alloca [3 x double], align 16
      %x = alloca double, align 8
      %y = alloca double, align 8
      store double* %m, double** %m.addr, align 8
      ; n = {4, 5, 6}: 24 bytes copied from @__const.g.n
      %0 = bitcast [3 x double]* %n to i8*
      call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.g.n to i8*), i64 24, i1 false)
      ; x = cblas_ddot(3, m, 1, n, 1)
      %1 = load double*, double** %m.addr, align 8
      %arraydecay = getelementptr inbounds [3 x double], [3 x double]* %n, i32 0, i32 0
      %call = call double @cblas_ddot(i32 3, double* %1, i32 1, double* %arraydecay, i32 1)
      store double %call, double* %x, align 8
      ; m[0] = 10 -- m is modified after cblas_ddot has read it
      %2 = load double*, double** %m.addr, align 8
      %arrayidx = getelementptr inbounds double, double* %2, i64 0
      store double 1.000000e+01, double* %arrayidx, align 8
      ; y = x * x; return y
      %3 = load double, double* %x, align 8
      %4 = load double, double* %x, align 8
      %mul = fmul double %3, %4
      store double %mul, double* %y, align 8
      %5 = load double, double* %y, align 8
      ret double %5
    }
    
    declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1)
    
    declare dso_local double @cblas_ddot(i32, double*, i32, double*, i32)
    
    ; IR for main(): fills m from a constant, zero-fills m1, then calls
    ; __enzyme_autodiff(g, m, m1).
    define dso_local i32 @main() {
    entry:
      %m = alloca [3 x double], align 16
      %m1 = alloca [3 x double], align 16
      %z = alloca double, align 8
      ; m = {1, 2, 3}: 24 bytes copied from @__const.main.m
      %0 = bitcast [3 x double]* %m to i8*
      call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.main.m to i8*), i64 24, i1 false)
      ; m1 = {0.}: all 24 bytes zeroed
      %1 = bitcast [3 x double]* %m1 to i8*
      call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 24, i1 false)
      ; z = __enzyme_autodiff(g, m, m1); g is passed bitcast to double*
      %arraydecay = getelementptr inbounds [3 x double], [3 x double]* %m, i32 0, i32 0
      %arraydecay1 = getelementptr inbounds [3 x double], [3 x double]* %m1, i32 0, i32 0
      %call = call double @__enzyme_autodiff(double* bitcast (double (double*)* @g to double*), double* %arraydecay, double* %arraydecay1)
      store double %call, double* %z, align 8
      ret i32 0
    }
    
    declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1)
    
    declare dso_local double @__enzyme_autodiff(double*, double*, double*)
    

    As per the problem we are trying to tackle, the definition of cblas_ddot isn't available.

    We run the Enzyme pass over it to calculate the derivative like - opt < filename.ll -load=/path/to/LLVMEnzyme-<version>.so -enzyme -mem2reg -instsimplify -simplifycfg -S

    and generate LLVM IR to calculate the derivative of g -

    ; Gradient function for g generated by Enzyme. By its naming, %"m'" is the
    ; derivative buffer paired with %m and %differeturn is the incoming
    ; derivative of g's return value (presumably -- confirm against Enzyme docs).
    define internal void @diffeg(double* noalias %m, double* %"m'", double %differeturn) {
    entry:
      ; Same statements as g: n = {4, 5, 6}, x = cblas_ddot(3, m, 1, n, 1), m[0] = 10.
      %n = alloca [3 x double], align 16
      %0 = bitcast [3 x double]* %n to i8*
      call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([3 x double]* @__const.g.n to i8*), i64 24, i1 false)
      %arraydecay = getelementptr inbounds [3 x double], [3 x double]* %n, i32 0, i32 0
      %call = call double @cblas_ddot(i32 3, double* nocapture readonly %m, i32 1, double* nocapture readonly %arraydecay, i32 1)
      store double 1.000000e+01, double* %m, align 8
      ; Derivative of y = x * x with respect to x: each operand of the multiply
      ; contributes differeturn * x, so %1 = 2 * x * differeturn.
      %m0diffecall = fmul fast double %differeturn, %call
      %m1diffecall = fmul fast double %differeturn, %call
      %1 = fadd fast double %m0diffecall, %m1diffecall
      ; m'[0] is zeroed before the accumulation below (presumably the adjoint of
      ; the m[0] = 10 store in g -- TODO confirm).
      store double 0.000000e+00, double* %"m'", align 8
      ; Derivative of the ddot call with respect to m: cblas_daxpy(3, %1, n, 1, m', 1).
      ; BLAS axpy computes y := a*x + y, i.e. m' += %1 * n.
      call void @cblas_daxpy(i32 3, double %1, double* %arraydecay, i32 1, double* %"m'", i32 1)
      ret void
    }
    
    declare void @cblas_daxpy(i32, double, double*, i32, double*, i32)
    

    Here, we just needed one call to cblas_daxpy since there was just one active input array.

    Merged Pull Request: EnzymeAD/Enzyme#226

  2. cblas_dgemm and cblas_sgemm

    Similar to the previous set of BLAS functions we talked about, cblas_dgemm and cblas_sgemm are the same operation. dgemm is for doubles while sgemm is for floats. For the rest of the description, we shall just talk about cblas_sgemm, but the same applies for cblas_dgemm.

    cblas_sgemm performs one of the matrix-matrix operations

    C := alpha*op( A )*op( B ) + beta*C,

    where op( X ) is one of

    op( X ) = X or op( X ) = X**T,

    alpha and beta are scalars, and A, B and C are matrices, with op( A ) an m by k matrix, op( B ) a k by n matrix and C an m by n matrix.

    We can prove (TODO) that the derivative of a function that has a call to cblas_sgemm can be calculated using three calls to cblas_sgemm and a call to cblas_sscal.

    Simple example -

    For the C source code:

    #include <cblas.h>
    
    extern float __enzyme_autodiff(void *, float *, float *, float *, float*, float, float);
    
    /* Function differentiated by Enzyme in this example.
     * Performs C := alpha * op(A) * op(B) + beta * C with column-major storage
     * and both A and B transposed (M = 4, N = 3, K = 2; lda = 2, ldb = 3,
     * ldc = 4). B is a local constant matrix; A, C, alpha and beta come from
     * the caller. */
    void g(float *restrict A, float *C, float alpha, float beta) {
        float B[] = {1011, 1021, 1031,
                     1012, 1022, 1032};
        cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, 4, 3, 2, alpha, A, 2, B, 3, beta, C, 4);
    }
    
    /* Driver: requests the derivative of g through __enzyme_autodiff.
     * A1 and C1 are passed right after A and C respectively (presumably their
     * derivative buffers -- confirm against the Enzyme calling convention);
     * A1 starts as all zeros and C1 as all ones. */
    int main() {
        float A[] = {0.11, 0.21,
                    0.12, 0.22,
                     0.13, 0.23,
                     0.14, 0.24};
        float C[] = {0.00, 0.00, 0.00, 0.00,
                     0.00, 0.00, 0.00, 0.00,
                     0.00, 0.00, 0.00, 0.00};
        float A1[] = {0, 0, 0, 0, 0, 0, 0, 0};
        float C1[] = {1, 1, 1, 1,
                      1, 1, 1, 1,
                      1, 1, 1, 1};
       /* alpha = 2.0, beta = 3.0 */
       __enzyme_autodiff((void*)g, A, A1, C, C1, 2.0, 3.0);
    }
    

    Similar to the previous example for cblas_ddot, only one of the input arrays (A) is active.

    From the above C code, we get this LLVM IR -

    target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128"
    target triple = "x86_64-unknown-linux-gnu"
    
    @__const.g.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.021000e+03, float 1.031000e+03, float 1.012000e+03, float 1.022000e+03, float 1.032000e+03], align 16
    @__const.main.A = private unnamed_addr constant [8 x float] [float 0x3FBC28F5C0000000, float 0x3FCAE147A0000000, float 0x3FBEB851E0000000, float 0x3FCC28F5C0000000, float 0x3FC0A3D700000000, float 0x3FCD70A3E0000000, float 0x3FC1EB8520000000, float 0x3FCEB851E0000000], align 16
    @__const.main.C1 = private unnamed_addr constant [12 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00], align 16
    
    ; IR for g(A, C, alpha, beta): B is a local array filled from @__const.g.B;
    ; the function consists of a single cblas_sgemm call.
    define dso_local void @g(float* noalias %A, float* %C, float %alpha, float %beta) {
    entry:
      %A.addr = alloca float*, align 8
      %C.addr = alloca float*, align 8
      %alpha.addr = alloca float, align 4
      %beta.addr = alloca float, align 4
      %B = alloca [6 x float], align 16
      store float* %A, float** %A.addr, align 8
      store float* %C, float** %C.addr, align 8
      store float %alpha, float* %alpha.addr, align 4
      store float %beta, float* %beta.addr, align 4
      ; B = {1011, 1021, 1031, 1012, 1022, 1032}: 24 bytes (6 floats) copied
      %0 = bitcast [6 x float]* %B to i8*
      call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.g.B to i8*), i64 24, i1 false)
      ; cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, 4, 3, 2, alpha, A, 2, B, 3, beta, C, 4)
      ; In this call the enum arguments appear as 102 (CblasColMajor) and 112 (CblasTrans).
      %1 = load float, float* %alpha.addr, align 4
      %2 = load float*, float** %A.addr, align 8
      %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0
      %3 = load float, float* %beta.addr, align 4
      %4 = load float*, float** %C.addr, align 8
      call void @cblas_sgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, float %1, float* %2, i32 2, float* %arraydecay, i32 3, float %3, float* %4, i32 4)
      ret void
    }
    
    declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1)
    
    declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32)
    
    ; IR for main(): initialises A, C, A1 and C1, then calls
    ; __enzyme_autodiff(g, A, A1, C, C1, 2.0, 3.0).
    define dso_local i32 @main() {
    entry:
      %A = alloca [8 x float], align 16
      %C = alloca [12 x float], align 16
      %A1 = alloca [8 x float], align 16
      %C1 = alloca [12 x float], align 16
      ; A: 32 bytes (8 floats) copied from @__const.main.A
      %0 = bitcast [8 x float]* %A to i8*
      call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x float]* @__const.main.A to i8*), i64 32, i1 false)
      ; C: 48 bytes (12 floats) zeroed
      %1 = bitcast [12 x float]* %C to i8*
      call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 48, i1 false)
      ; A1: 32 bytes (8 floats) zeroed
      %2 = bitcast [8 x float]* %A1 to i8*
      call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 32, i1 false)
      ; C1: 48 bytes (12 floats, all 1.0) copied from @__const.main.C1
      %3 = bitcast [12 x float]* %C1 to i8*
      call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %3, i8* align 16 bitcast ([12 x float]* @__const.main.C1 to i8*), i64 48, i1 false)
      ; Decay the arrays to float* and make the call; its result is not stored.
      %arraydecay = getelementptr inbounds [8 x float], [8 x float]* %A, i32 0, i32 0
      %arraydecay1 = getelementptr inbounds [8 x float], [8 x float]* %A1, i32 0, i32 0
      %arraydecay2 = getelementptr inbounds [12 x float], [12 x float]* %C, i32 0, i32 0
      %arraydecay3 = getelementptr inbounds [12 x float], [12 x float]* %C1, i32 0, i32 0
      %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float 2.000000e+00, float 3.000000e+00)
      ret i32 0
    }
    
    declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1)
    
    declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float, float)
    

    As per the problem we are trying to tackle, the definition of cblas_sgemm isn't available.

    We run the Enzyme pass over it to calculate the derivative like - opt < filename.ll -load=/path/to/LLVMEnzyme-<version>.so -enzyme -mem2reg -instsimplify -simplifycfg -S

    and generate LLVM IR to calculate the derivative of g -

    ; Gradient function for g generated by Enzyme (from the WIP pull request).
    ; By their naming, %"A'" and %"C'" are the derivative buffers paired with
    ; %A and %C (presumably -- confirm against Enzyme docs).
    define internal { float, float } @diffeg(float* noalias %A, float* %"A'", float* %C, float* %"C'", float %alpha, float %beta) {
    entry:
      ; B: local constant matrix, 24 bytes (6 floats) copied from @__const.g.B
      %B = alloca [6 x float], align 16
      %0 = bitcast [6 x float]* %B to i8*
      call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.g.B to i8*), i64 24, i1 false)
      %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0
      ; The original sgemm call from g, with the same arguments.
      call void @cblas_sgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %arraydecay, i32 3, float %beta, float* %C, i32 4)
      ; Derivative with respect to A: an sgemm that reads C' and B and writes A'.
      ; NOTE(review): both transpose flags are 111 here, versus 112 (CblasTrans)
      ; in the call above -- presumably CblasNoTrans; confirm against cblas.h.
      ; NOTE(review): the beta argument 0x36A0000000000000 is 2^-149 (the float
      ; whose bit pattern is 0x00000001), not 0.0 or 1.0 -- looks unintended;
      ; confirm against the WIP pull request.
      call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 2, i32 3, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %arraydecay, i32 3, float 0x36A0000000000000, float* %"A'", i32 4)
      ; cblas_sscal(12, beta, C', 1): all 12 elements of C' are scaled by beta.
      call void @cblas_sscal(i32 12, float %beta, float* %"C'", i32 1)
      ; Both fields of the returned struct are zero (presumably the derivative
      ; slots for the scalar arguments alpha and beta -- TODO confirm).
      ret { float, float } zeroinitializer
    }
    

    Since we had an inactive input array, we just needed two calls to cblas_sgemm instead of the usually required three calls.

    WIP Pull Request: EnzymeAD/Enzyme#308

Broader idea

We have successfully demonstrated two ways in which Enzyme can calculate the derivatives of functions from commonly used numerical computing libraries whose definitions may not be explicitly available in the LLVM IR source code. But there is no clear answer as to which might be the better approach.

Although both approaches will always give the accurate result, performance might vary depending on the context of the function call.

Future work

  • Implement remaining BLAS functions - both with the BCPass as well as in Enzyme's AdjointGenerator. This is the full list of BLAS functions that we want to support in the future - http://www.netlib.org/blas/#_blas_routines
  • Once there are working implementations of "enough" BLAS functions, compare the performance of the two different approaches.

Relevant Links

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment