@marcrasi
Created May 30, 2019 20:01
Differentiating class methods

Differentiating Class Methods

marcrasi@

Last updated: 5/21/19

Problem

When c has a class type with a method named f, Swift dispatches c.f() at runtime by looking up the concrete implementation of f in a "vtable" referenced by c.

When Swift is differentiating code, it needs to look up "associated functions" (jvp/vjp) for function applications. e.g.

@differentiable(wrt: (x))
func example(_ c: AClassType, _ x: Float) -> Float {
  return c.f(x)  // The differentiation pass needs to look up the "associated function" for c.f.
}

Since the dispatch for class methods happens at runtime, the differentiation pass can't statically determine what associated function to call.

Solution

Add the associated functions to the vtable, so that the differentiation pass can generate code that looks up the appropriate function in the vtable at runtime!
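
The idea can be sketched in plain Swift, without the differentiation pass: if the derivative is itself an overridable method, it gets its own vtable slot and is dispatched at runtime exactly like f. This is only an illustration. The names Base, Scaled, vjpF, and gradientOfExample are hypothetical and the derivatives are hand-written, whereas in the actual design the compiler would generate these entries and the lookups.

```swift
// Hypothetical model: the VJP is an ordinary overridable method, so it
// occupies a vtable slot next to the original method.
class Base {
  func f(_ x: Float) -> Float { return x + 1 }

  // Returns the value of f(x) and a pullback mapping dy to dx.
  func vjpF(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (value: f(x), pullback: { dy in dy })
  }
}

class Scaled: Base {
  override func f(_ x: Float) -> Float { return 3 * x }

  override func vjpF(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    return (value: f(x), pullback: { dy in 3 * dy })
  }
}

// Stand-in for the code the differentiation pass would emit for `example`:
// it does not know the dynamic type of `c`, so it dispatches on `vjpF`.
func gradientOfExample(_ c: Base, at x: Float) -> Float {
  return c.vjpF(x).pullback(1)
}

print(gradientOfExample(Base(), at: 2))    // 1.0
print(gradientOfExample(Scaled(), at: 2))  // 3.0
```

Both calls go through a value of static type Base, yet each one reaches the derivative that matches the dynamic type.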

The rest of this doc describes in detail how exactly to achieve this.

Skip to the last section if you want the tldr of the solution. The rest is just background.

Background -- Witness Tables: The Data

SILWitnessTable

Warning: The SILWitnessTable is a red herring, because it's only used for resilient witness tables, and I'm not going to explain how resilient witness tables work. But the SILWitnessTable is the easiest piece of data to see and understand and manipulate, so I'm going to start off with it for pedagogical purposes.

A SILWitnessTable is an array of witness entries. A witness entry is a pair of a requirement and the SIL-level entity satisfying that requirement. e.g. a witness entry for a method is a SILDeclRef and a pointer to a SILFunction.

A SILDeclRef is a reference to a SIL-level decl generated from an AST-level decl. It's implemented as a pointer to the original AST-level decl, plus some metadata determining which SIL decl it actually refers to, for cases when a single AST-level decl generates multiple SIL decls.

Example "witness.swift":

protocol P {
  @differentiable(wrt: (x))
  func req(x: Float) -> Float
}

struct S : P {
  @differentiable(wrt: (x))
  func req(x: Float) -> Float {
    return x + 1
  }
}

"swiftc -emit-sil witness.swift"

sil_witness_table hidden S: P module witness {
  // The entry for the original function.
  // The stuff before the last colon is the SILDeclRef.
  // The stuff after the last colon is the name of the SILFunction satisfying
  // the requirement.
  method #P.req!1: <Self where Self : P> (Self) -> (Float) -> Float : @$s7witness1SVAA1PA2aDP3req1xS2f_tFTW

  // The entry for the JVP.
  method #P.req!1.jvp.1.SU: <Self where Self : P> (Self) -> (Float) -> Float : @AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_jvp_SU

  // The entry for the VJP.
  method #P.req!1.vjp.1.SU: <Self where Self : P> (Self) -> (Float) -> Float : @AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_vjp_SU
}

Protocol Witnesses

Note that the witness table does not reference the SILFunction lowered from the user's definition. The witness table references "protocol witnesses" that wrap the user's SILFunctions. This is because callers expect witness methods to have a certain convention, and this convention usually doesn't match the convention of the user's SILFunctions.

e.g. in the example, the witness method convention for req takes Self indirectly (@in_guaranteed S), but S.req takes S directly:

// S.req(x:)
sil hidden [differentiable source 0 wrt 0 jvp @AD__$s7witness1SV3req1xS2f_tF__jvp_src_0_wrt_0 vjp @AD__$s7witness1SV3req1xS2f_tF__vjp_src_0_wrt_0] @$s7witness1SV3req1xS2f_tF : $@convention(method) (Float, S) -> Float {
// %0                                             // users: %5, %2
// %1                                             // user: %3
bb0(%0 : $Float, %1 : $S):
  debug_value %0 : $Float, let, name "x", argno 1 // id: %2
  debug_value %1 : $S, let, name "self", argno 2  // id: %3
  %4 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1 // user: %6
  %5 = struct_extract %0 : $Float, #Float._value  // user: %6
  %6 = builtin "fadd_FPIEEE32"(%5 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %7
  %7 = struct $Float (%6 : $Builtin.FPIEEE32)     // user: %8
  return %7 : $Float                              // id: %8
} // end sil function '$s7witness1SV3req1xS2f_tF'

// protocol witness for P.req(x:) in conformance S
sil private [transparent] [thunk] @$s7witness1SVAA1PA2aDP3req1xS2f_tFTW : $@convention(witness_method: P) (Float, @in_guaranteed S) -> Float {
// %0                                             // user: %4
// %1                                             // user: %2
bb0(%0 : $Float, %1 : $*S):
  %2 = load %1 : $*S                              // user: %4
  // function_ref S.req(x:)
  %3 = function_ref @$s7witness1SV3req1xS2f_tF : $@convention(method) (Float, S) -> Float // user: %4
  %4 = apply %3(%0, %2) : $@convention(method) (Float, S) -> Float // user: %5
  return %4 : $Float                              // id: %5
} // end sil function '$s7witness1SVAA1PA2aDP3req1xS2f_tFTW'

WitnessTableLayout

Okay, now throw the useless SILWitnessTable away (unless you're doing resilient witness tables), and look at WitnessTableLayout instead.

The WitnessTableLayout is an ordered array of witness table requirements. e.g. for each method requirement, the WitnessTableLayout has a SILDeclRef.

The order of the requirements comes from their order in the AST.

When IRGen wants to gen a reference to a witness requirement, it queries the WitnessTableLayout to determine the index of the requirement in the table.
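
As a rough illustration of that lookup, here is a toy model in Swift. The types DeclRef, AssociatedKind, and Layout are hypothetical stand-ins for the roles described above, not the compiler's actual C++ classes, and the model ignores any non-method entries that the emitted table may also contain.

```swift
// Toy model only: these types mirror the roles described above, not the
// compiler's actual classes.
enum AssociatedKind: Hashable {
  case original
  case jvp
  case vjp
}

// Stand-in for SILDeclRef: the AST-level decl plus metadata selecting which
// SIL-level decl is meant.
struct DeclRef: Hashable {
  var decl: String
  var kind: AssociatedKind
}

// Stand-in for WitnessTableLayout: requirements kept in visiting order.
struct Layout {
  var entries: [DeclRef] = []

  mutating func addMethod(_ requirement: DeclRef) {
    entries.append(requirement)
  }

  // Index of the requirement within the layout, or nil if it is absent.
  func index(of requirement: DeclRef) -> Int? {
    return entries.firstIndex(of: requirement)
  }
}

var layout = Layout()
layout.addMethod(DeclRef(decl: "P.req", kind: .original))
layout.addMethod(DeclRef(decl: "P.req", kind: .jvp))
layout.addMethod(DeclRef(decl: "P.req", kind: .vjp))

if let i = layout.index(of: DeclRef(decl: "P.req", kind: .vjp)) {
  print(i)  // 2
}
```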

IR-Level Witness Table

The IR-level witness table is an ordered list of pointers to protocol witnesses.

The order is the same as the order in WitnessTableLayout.

"swiftc -emit-ir witness.swift"

@"$s7witness1SVAA1PAAWP" = hidden constant [4 x i8*] [

  i8* bitcast (%swift.protocol_conformance_descriptor* @"$s7witness1SVAA1PAAMc" to i8*),

  i8* bitcast (float (float, %T7witness1SV*, %swift.type*, i8**)* @"$s7witness1SVAA1PA2aDP3req1xS2f_tFTW" to i8*),

  i8* bitcast ({ float, i8*, %swift.refcounted* } (float, %T7witness1SV*, %swift.type*, i8**)* @"AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_jvp_SU" to i8*),

  i8* bitcast ({ float, i8*, %swift.refcounted* } (float, %T7witness1SV*, %swift.type*, i8**)* @"AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_vjp_SU" to i8*)

], align 8

The first entry is the protocol conformance descriptor (note its type, %swift.protocol_conformance_descriptor*); the other three are clearly the pointers to the functions that satisfy the requirements.

Background -- Witness Tables: Generating The Data

There are some SILWitnessVisitors that generate The Data described in the previous section.

Each one of these implements an addMethod(SILDeclRef requirement) method that adds the appropriate things for the given requirement. This method gets called on all the requirements, in the correct order.

Background -- Witness Tables: How Differentiable Requirements Were Added

  1. Modified SILDeclRef so that it could represent autodiff associated functions.
  2. Taught the differentiation pass to differentiate witness_method instructions by replacing them with witness_method instructions containing the SILDeclRef for the appropriate associated function.
  3. Modified SILWitnessVisitor to visit some additional SILDeclRefs whenever it comes across a @differentiable decl.
  4. Updated all the subclasses of SILWitnessVisitor to deal with the new requirements.
    1. WitnessTableLayout: Needed trivial or no changes.
    2. WitnessTableBuilder: Needed trivial or no changes.
    3. SILGenWitnessTable: This was the difficult one.
      1. When generating "Protocol Witnesses", SILGenWitnessTable needs to know the types of the requirement and of the concrete function that satisfies the requirement. So I went through and updated all the places where it calculates types to be aware of the logic necessary to calculate types of "autodiff associated functions".
      2. When generating "Protocol Witnesses", SILGenWitnessTable needs a reference to the concrete underlying function. This concrete function might not exist yet, if the AutoDiff pass is responsible for generating it. So I added an "autodiff_function" instruction that references the possibly-not-yet-generated concrete underlying function. The AutoDiff pass later resolves this reference and replaces it with a reference to a concrete function.
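
Step 3 above can be sketched as a toy visitor in Swift. Everything here (Requirement, RequirementRef, WitnessVisitor, LayoutBuilder) is a hypothetical model for illustration and does not reflect the compiler's actual interfaces. It only shows how a shared driver can make every visitor subclass see two extra entries for each @differentiable requirement.

```swift
// Toy model only; names are illustrative, not the compiler's API.
enum FunctionKind { case original, jvp, vjp }

// A protocol requirement as written in source.
struct Requirement {
  var name: String
  var isDifferentiable: Bool
}

// Stand-in for a SILDeclRef that can also name an associated function.
struct RequirementRef: Equatable {
  var name: String
  var kind: FunctionKind
}

// Each table builder implements addMethod; the shared driver decides what
// gets visited and in which order.
protocol WitnessVisitor {
  mutating func addMethod(_ ref: RequirementRef)
}

extension WitnessVisitor {
  mutating func visit(_ requirements: [Requirement]) {
    for req in requirements {
      addMethod(RequirementRef(name: req.name, kind: .original))
      // A @differentiable decl contributes two extra entries.
      if req.isDifferentiable {
        addMethod(RequirementRef(name: req.name, kind: .jvp))
        addMethod(RequirementRef(name: req.name, kind: .vjp))
      }
    }
  }
}

struct LayoutBuilder: WitnessVisitor {
  var entries: [RequirementRef] = []
  mutating func addMethod(_ ref: RequirementRef) { entries.append(ref) }
}

var builder = LayoutBuilder()
builder.visit([
  Requirement(name: "req", isDifferentiable: true),
  Requirement(name: "other", isDifferentiable: false),
])
print(builder.entries.count)  // 4
```

Because the expansion lives in the shared driver, a builder that only records entries needs no changes, which matches the "trivial or no changes" note for the layout and builder subclasses.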

VTables: The Data

Like witness tables, vtables exist in a lot of places. Unlike witness tables, I have not thoroughly investigated all of them. I will describe a few interesting ones.

Type Metadata

This is the record containing the actual VTable that the program uses at runtime when dispatching class method calls. It is documented here. As of 5/21/19, the documentation looks mostly up-to-date to me.

The mangled name for this ends in "CMf". (Strictly, "CMf" names the full type metadata rather than the type metadata. The difference is that the pointer to the type metadata is the pointer to the full type metadata offset by +2, i.e. past the first two fields of the record.)

Example "example.swift":

class Superclass {
  func hello() {}
}

class Subclass : Superclass {
  override func hello() {}
}

"bin/swiftc -emit-ir example.swift":

@"$s7example8SubclassCMf" = /* type elided */
<{
  void (%T7example8SubclassC*)* @"$s7example8SubclassCfD",
  i8** @"$sBoWV",

  // The Kind field.
  i64 0,

  /* very long entry elided */,

  %swift.opaque* null,
  %swift.opaque* null,
  i64 1,
  i32 2,
  i32 0,
  i32 16,
  i16 7,
  i16 0,
  i32 112,
  i32 16,

  /* very long entry elided */,

  i8* null,

  // These last two entries are the VTable!! There is the hello() and the init().
  void (%T7example8SubclassC*)* @"$s7example8SubclassC5helloyyF",
  %T7example8SubclassC* (%swift.type*)* @"$s7example8SubclassCACycfC"
}>

Nominal Type Descriptor

I have no idea what this one is for. It appears to contain a bunch of data about the nominal type, plus a VTable, plus some "override metadata". I don't think this is used to dispatch methods at runtime. Maybe it's used for reflection?

Documented here. As of 5/21/19, it says that it's very out of date, and I can confirm that the documentation looks completely different from what I'm seeing in the IR.

SILVTable

"bin/swiftc -emit-sil example.swift":

sil_vtable Superclass {
  #Superclass.hello!1: (Superclass) -> () -> () : @$s7example10SuperclassC5helloyyF     // Superclass.hello()
  #Superclass.init!allocator.1: (Superclass.Type) -> () -> Superclass : @$s7example10SuperclassCACycfC  // Superclass.__allocating_init()
  #Superclass.deinit!deallocator.1: @$s7example10SuperclassCfD  // Superclass.__deallocating_deinit
}

sil_vtable Subclass {
  #Superclass.hello!1: (Superclass) -> () -> () : @$s7example8SubclassC5helloyyF [override]     // Subclass.hello()
  #Superclass.init!allocator.1: (Superclass.Type) -> () -> Superclass : @$s7example8SubclassCACycfC [override]  // Subclass.__allocating_init()
  #Subclass.deinit!deallocator.1: @$s7example8SubclassCfD       // Subclass.__deallocating_deinit
}

VTable Thunks

Here's a nice comment documenting VTable thunks. tldr: Only some vtable entries need thunks. Other vtable entries point at the original function.

VTables: Generating The Data

Just like with witness tables, there is a visitor. Subclasses implement addMethod(SILDeclRef) to accumulate methods into their data.

There is also an addMethodOverride method. It appears to be used to accumulate some override data for the "nominal type descriptor". I don't know what this data is used for.

VTables: How We Can Add AutoDiff Associated Functions

Here's a working prototype PR: apple/swift#24975

It does basically the same thing that we did with witness tables (but reusing some of the witness table work):

  1. Teach the differentiation pass to differentiate class_method instructions by replacing them with instructions containing the SILDeclRef for the appropriate associated function.
  2. Modify the SILVTableVisitor to visit the associated functions.
  3. Update the SILVTableVisitor subclasses to deal with any consequences.
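
A toy model of what the resulting vtables might look like, written in Swift. The types Slot and ToyVTable and the implementation names are hypothetical and this is not how SILVTable is implemented. The point is only that the associated functions get their own slots, and that a subclass override replaces the implementation while keeping the slot, so a lookup for the VJP works the same way for any subclass.

```swift
// Toy model only; not the compiler's SILVTable.
enum Kind { case original, jvp, vjp }

struct Slot: Equatable {
  var method: String
  var kind: Kind
}

struct ToyVTable {
  var slots: [Slot] = []
  var implementations: [String] = []

  // Appends a new slot at the end of the table.
  mutating func addMethod(_ slot: Slot, implementation: String) {
    slots.append(slot)
    implementations.append(implementation)
  }

  // Replaces the implementation of an inherited slot, keeping its index.
  mutating func overrideMethod(_ slot: Slot, implementation: String) {
    guard let i = slots.firstIndex(of: slot) else { return }
    implementations[i] = implementation
  }

  func lookup(_ slot: Slot) -> String? {
    guard let i = slots.firstIndex(of: slot) else { return nil }
    return implementations[i]
  }
}

// The visitor adds the original method and both associated functions.
var superTable = ToyVTable()
for kind in [Kind.original, .jvp, .vjp] {
  superTable.addMethod(Slot(method: "f", kind: kind),
                       implementation: "Superclass.f.\(kind)")
}

// The subclass starts from the inherited layout and overrides each slot.
var subTable = superTable
for kind in [Kind.original, .jvp, .vjp] {
  subTable.overrideMethod(Slot(method: "f", kind: kind),
                          implementation: "Subclass.f.\(kind)")
}

print(subTable.lookup(Slot(method: "f", kind: .vjp))!)  // Subclass.f.vjp
```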

While writing the prototype, I discovered one consequence that we need to deal with. VTable entries are not always thunked. When an autodiff associated function VTable entry is not thunked, SILGen looks for the definition of the actual function. But this function does not exist until the Differentiation pass, which happens after SILGen. So things explode. Here are some ways we could deal with this:

  • (terrible solution, which I used for the prototype because it was easiest) Require that the user specify the associated functions for class methods, and use those in the VTable.
  • Generate empty associated functions during SILGen, and have the Differentiation pass fill those in.
  • Always thunk autodiff associated functions in VTables. Make the thunk body contain an autodiff_function instruction that the Differentiation pass later fills in. Note: this is what is done for protocol witness methods.
  • Always generate JVP and VJP thunks in SILGen for the @differentiable attribute, so that the AD pass doesn’t need to care about them. There may be some modifications necessary in TBDGen. (TF-524)
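
To make the "always thunk" option more concrete, here is a toy model in Swift of a table entry that exists before its target does. DeferredThunk, resolve, and the "f.vjp" key are hypothetical names for illustration. The real mechanism would be a SIL thunk containing an autodiff_function instruction, not a runtime object.

```swift
// Toy model only: a vtable slot that is created before its target exists.
typealias VJP = (Float) -> (value: Float, pullback: (Float) -> Float)

// Plays the role of the thunk body: it refers to a target that a later
// pass fills in.
final class DeferredThunk {
  private var target: VJP?

  // Called by the stand-in for the differentiation pass.
  func resolve(_ function: @escaping VJP) {
    target = function
  }

  func call(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
    guard let target = target else {
      fatalError("thunk called before the differentiation pass resolved it")
    }
    return target(x)
  }
}

// "SILGen": the vtable entry is emitted immediately and points at the thunk.
let vjpThunk = DeferredThunk()
let vtable: [String: DeferredThunk] = ["f.vjp": vjpThunk]

// "Differentiation pass": runs later and fills in the body for f(x) = x * x.
vjpThunk.resolve { (x: Float) in
  (value: x * x, pullback: { (dy: Float) in 2 * x * dy })
}

let result = vtable["f.vjp"]!.call(3)
print(result.value)        // 9.0
print(result.pullback(1))  // 6.0
```

The table never has to change after it is built. Only the thunk's target is filled in, which is why this option avoids the ordering problem between SILGen and the Differentiation pass.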