marcrasi@
Last updated: 5/21/19
When c has class type with a method named f, Swift dispatches c.f() at runtime by looking for the concrete implementation of the method inside a "vtable" referenced by c.
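A minimal sketch of that dispatch in plain Swift. The class names Base and Derived are made up for illustration:

```swift
class Base {
    func f() -> String { return "Base.f" }
}

class Derived: Base {
    override func f() -> String { return "Derived.f" }
}

// The static type of `c` is Base, but the call is resolved against the
// dynamic type of the object, so the Derived implementation runs.
let c: Base = Derived()
let dispatched = c.f()
print(dispatched)
```

The caller only knows the static type; which body runs is decided by the object it is handed.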
When Swift is differentiating code, it needs to look up "associated functions" (JVP/VJP) for function applications. For example:
@differentiable(wrt: (x))
func example(_ c: AClassType, _ x: Float) -> Float {
return c.f(x) // The differentiation pass needs to look up the "associated function" for c.f.
}
Since the dispatch for class methods happens at runtime, the differentiation pass can't statically determine what associated function to call.
Add the associated functions to the vtable, so that the differentiation pass can generate code that looks up the appropriate function in the vtable at runtime!
The rest of this doc describes in detail how exactly to achieve this.
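To make the idea concrete, here is a hand-written model in plain Swift, with no @differentiable involved. It assumes the VJP is just another overridable method sitting next to f, so it gets its own vtable slot. The names Scale, Square, vjpF, and gradientOfExample are hypothetical; in the real design the compiler would generate the associated functions rather than the user.

```swift
class Scale {
    // Original function: f(x) = 2x.
    func f(_ x: Float) -> Float { return 2 * x }

    // Hand-written stand-in for the VJP associated function:
    // returns the value and a pullback.
    func vjpF(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
        return (f(x), { v in 2 * v })
    }
}

class Square: Scale {
    // Overrides both the original and its associated function: f(x) = x * x.
    override func f(_ x: Float) -> Float { return x * x }

    override func vjpF(_ x: Float) -> (value: Float, pullback: (Float) -> Float) {
        return (f(x), { v in 2 * x * v })
    }
}

// Models what differentiated code could do for `c.f(x)`:
// look up the VJP dynamically instead of statically.
func gradientOfExample(_ c: Scale, at x: Float) -> Float {
    let (_, pullback) = c.vjpF(x)
    return pullback(1)
}

let gradScale = gradientOfExample(Scale(), at: 3)
let gradSquare = gradientOfExample(Square(), at: 3)
```

The same call site yields a different derivative depending on the dynamic type of c, which is the behavior the vtable entries are meant to provide.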
Skip to the last section if you want the tldr of the solution. The rest is just background.
Warning: The SILWitnessTable is a red herring, because it's only used for resilient witness tables, and I'm not going to explain how resilient witness tables work. But the SILWitnessTable is the easiest piece of data to see and understand and manipulate, so I'm going to start off with it for pedagogical purposes.
A SILWitnessTable is an array of witness entries. A witness entry is a pair of requirement and SIL-level entity satisfying the requirement. e.g. a witness entry for a method is a SILDeclRef and a pointer to a SILFunction.
A SILDeclRef is a reference to a SIL-level decl generated by an AST-level decl. It's implemented as a pointer to the original AST-level decl, plus some metadata determining which SIL decl it actually refers to, for cases when there are multiple SIL decls generated by a single AST-level decl.
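The two data structures can be modeled roughly like this. This is a toy sketch: the real SILDeclRef and SILWitnessTable are C++ classes inside the compiler, and every name here (ToyDeclRef, ToyWitnessEntry, lookUp) is made up. Strings stand in for pointers, and the kind enum assumes the "metadata" only needs to distinguish the original function from its JVP and VJP.

```swift
struct ToyDeclRef: Hashable {
    enum Kind: Hashable { case original, jvp, vjp }
    let astDecl: String  // stands in for the pointer to the AST-level decl
    let kind: Kind       // stands in for the metadata selecting one SIL decl
}

struct ToyWitnessEntry {
    let requirement: ToyDeclRef
    let witness: String  // stands in for the pointer to a SILFunction
}

// One AST-level decl (P.req) yields several SIL-level requirements.
let toyTable = [
    ToyWitnessEntry(requirement: ToyDeclRef(astDecl: "P.req", kind: .original),
                    witness: "witness for S.req"),
    ToyWitnessEntry(requirement: ToyDeclRef(astDecl: "P.req", kind: .jvp),
                    witness: "witness for S.req JVP"),
    ToyWitnessEntry(requirement: ToyDeclRef(astDecl: "P.req", kind: .vjp),
                    witness: "witness for S.req VJP"),
]

// Finds the entity satisfying a requirement, if the table has one.
func lookUp(_ ref: ToyDeclRef, in table: [ToyWitnessEntry]) -> String? {
    return table.first(where: { $0.requirement == ref })?.witness
}

let vjpWitness = lookUp(ToyDeclRef(astDecl: "P.req", kind: .vjp), in: toyTable)
```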
Example "witness.swift":
protocol P {
@differentiable(wrt: (x))
func req(x: Float) -> Float
}
struct S : P {
@differentiable(wrt: (x))
func req(x: Float) -> Float {
return x + 1
}
}
"swiftc -emit-sil witness.swift"
sil_witness_table hidden S: P module witness {
// The entry for the original function.
// The stuff before the last colon is the SILDeclRef.
// The stuff after the last colon is the name of the SILFunction satisfying
// the requirement.
method #P.req!1: <Self where Self : P> (Self) -> (Float) -> Float : @$s7witness1SVAA1PA2aDP3req1xS2f_tFTW
// The entry for the JVP.
method #P.req!1.jvp.1.SU: <Self where Self : P> (Self) -> (Float) -> Float : @AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_jvp_SU
// The entry for the VJP.
method #P.req!1.vjp.1.SU: <Self where Self : P> (Self) -> (Float) -> Float : @AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_vjp_SU
}
Note that the witness table does not reference the SILFunction lowered from the user's definition. The witness table references "protocol witnesses" that wrap the user's SILFunctions. This is because callers expect witness methods to have a certain convention, and this convention usually doesn't match the convention of the user's SILFunctions.
e.g. in the example, the witness method convention for req has Self indirect, but S.req takes S directly:
// S.req(x:)
sil hidden [differentiable source 0 wrt 0 jvp @AD__$s7witness1SV3req1xS2f_tF__jvp_src_0_wrt_0 vjp @AD__$s7witness1SV3req1xS2f_tF__vjp_src_0_wrt_0] @$s7witness1SV3req1xS2f_tF : $@convention(method) (Float, S) -> Float {
// %0 // users: %5, %2
// %1 // user: %3
bb0(%0 : $Float, %1 : $S):
debug_value %0 : $Float, let, name "x", argno 1 // id: %2
debug_value %1 : $S, let, name "self", argno 2 // id: %3
%4 = float_literal $Builtin.FPIEEE32, 0x3F800000 // 1 // user: %6
%5 = struct_extract %0 : $Float, #Float._value // user: %6
%6 = builtin "fadd_FPIEEE32"(%5 : $Builtin.FPIEEE32, %4 : $Builtin.FPIEEE32) : $Builtin.FPIEEE32 // user: %7
%7 = struct $Float (%6 : $Builtin.FPIEEE32) // user: %8
return %7 : $Float // id: %8
} // end sil function '$s7witness1SV3req1xS2f_tF'
// protocol witness for P.req(x:) in conformance S
sil private [transparent] [thunk] @$s7witness1SVAA1PA2aDP3req1xS2f_tFTW : $@convention(witness_method: P) (Float, @in_guaranteed S) -> Float {
// %0 // user: %4
// %1 // user: %2
bb0(%0 : $Float, %1 : $*S):
%2 = load %1 : $*S // user: %4
// function_ref S.req(x:)
%3 = function_ref @$s7witness1SV3req1xS2f_tF : $@convention(method) (Float, S) -> Float // user: %4
%4 = apply %3(%0, %2) : $@convention(method) (Float, S) -> Float // user: %5
return %4 : $Float // id: %5
} // end sil function '$s7witness1SVAA1PA2aDP3req1xS2f_tFTW'
Okay, now throw the useless SILWitnessTable away (unless you're doing resilient witness tables), and look at WitnessTableLayout instead.
The WitnessTableLayout is an ordered array of witness table requirements. e.g. for each method requirement, the WitnessTableLayout has a SILDeclRef.
The order of the requirements comes from their order in the AST.
When IRGen wants to gen a reference to a witness requirement, it queries the WitnessTableLayout to determine the index of the requirement in the table.
The IR-level witness table is an ordered list of pointers to protocol witnesses.
The order is the same as the order in WitnessTableLayout.
"swiftc -emit-ir witness.swift"
@"$s7witness1SVAA1PAAWP" = hidden constant [4 x i8*] [
i8* bitcast (%swift.protocol_conformance_descriptor* @"$s7witness1SVAA1PAAMc" to i8*),
i8* bitcast (float (float, %T7witness1SV*, %swift.type*, i8**)* @"$s7witness1SVAA1PA2aDP3req1xS2f_tFTW" to i8*),
i8* bitcast ({ float, i8*, %swift.refcounted* } (float, %T7witness1SV*, %swift.type*, i8**)* @"AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_jvp_SU" to i8*),
i8* bitcast ({ float, i8*, %swift.refcounted* } (float, %T7witness1SV*, %swift.type*, i8**)* @"AD__$s7witness1SVAA1PA2aDP3req1xS2f_tFTW_vjp_SU" to i8*)
], align 8
The first entry is, judging by its type, a pointer to the protocol conformance descriptor. The other three are clearly the pointers to the functions that satisfy the requirements.
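The relationship between the layout and the IR-level table can be sketched as follows. This is a toy model with hypothetical names (toyLayout, toyIRTable, tableIndex) and strings standing in for pointers. It assumes the method entries follow one leading non-method slot, as in the IR above. Whether the real WitnessTableLayout accounts for that slot in its own indices is not something this sketch claims.

```swift
// Requirement names in AST order.
let toyLayout = ["P.req", "P.req.jvp", "P.req.vjp"]

// Slot 0 models the leading non-method entry seen in the IR;
// the method witnesses follow in layout order.
let leadingSlots = 1
let toyIRTable = [
    "conformance descriptor",
    "witness for S.req",
    "witness for S.req JVP",
    "witness for S.req VJP",
]

// Maps a requirement to its slot in the IR-level table.
func tableIndex(of requirement: String, layout: [String]) -> Int? {
    guard let i = layout.firstIndex(of: requirement) else { return nil }
    return leadingSlots + i
}

let vjpIndex = tableIndex(of: "P.req.vjp", layout: toyLayout)
let vjpEntry = vjpIndex.map { toyIRTable[$0] }
```

The lookup works only because the table was filled in the same order as the layout, which is why the ordering matters.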
There are some SILWitnessVisitors that generate The Data described in the previous section.
- SILGenWitnessTable (actually, its subclasses, I think): Generates the "Protocol Witnesses" and "SILWitnessTable"
- WitnessTableLayout: Generates itself
- WitnessTableBuilder: Generates the IR-Level Witness Table
Each one of these implements an addMethod(SILDeclRef requirement) method that adds the appropriate things for the given requirement. This method gets called on all the requirements, in the correct order.
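A rough Swift model of this visitor arrangement. The real visitors are C++ classes; the names ToyWitnessVisitor, ToyLayoutVisitor, and ToyTableBuilder are hypothetical, and requirements are plain strings here.

```swift
protocol ToyWitnessVisitor {
    mutating func addMethod(_ requirement: String)
}

extension ToyWitnessVisitor {
    // Shared driver: every visitor sees the requirements in the same order.
    mutating func visit(_ requirements: [String]) {
        for requirement in requirements { addMethod(requirement) }
    }
}

// Records the index of each requirement.
struct ToyLayoutVisitor: ToyWitnessVisitor {
    var indices: [String: Int] = [:]
    mutating func addMethod(_ requirement: String) {
        let next = indices.count
        indices[requirement] = next
    }
}

// Emits one table entry per requirement.
struct ToyTableBuilder: ToyWitnessVisitor {
    let implementations: [String: String]
    var entries: [String] = []
    mutating func addMethod(_ requirement: String) {
        entries.append(implementations[requirement] ?? "<missing>")
    }
}

let requirementsInOrder = ["P.req", "P.req.jvp", "P.req.vjp"]

var layoutVisitor = ToyLayoutVisitor()
layoutVisitor.visit(requirementsInOrder)

var builder = ToyTableBuilder(implementations: [
    "P.req": "witness for S.req",
    "P.req.jvp": "witness for S.req JVP",
    "P.req.vjp": "witness for S.req VJP",
])
builder.visit(requirementsInOrder)

// An index computed by one visitor is valid in the table built by the other.
let jvpSlot = layoutVisitor.indices["P.req.jvp"]!
let jvpEntry = builder.entries[jvpSlot]
```

Driving every consumer from one shared traversal is what keeps the layout and the emitted table in agreement.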
- Modified SILDeclRef so that it could represent autodiff associated functions.
- Taught the differentiation pass to differentiate witness_method instructions by replacing them with witness_method instructions containing the SILDeclRef for the appropriate associated function.
- Modified SILWitnessVisitor to visit some additional SILDeclRefs whenever it comes across a @differentiable decl.
- Updated all the subclasses of SILWitnessVisitor to deal with the new requirements.
- WitnessTableLayout: Needed trivial or no changes.
- WitnessTableBuilder: Needed trivial or no changes.
- SILGenWitnessTable: This was the difficult one.
- When generating "Protocol Witnesses", SILGenWitnessTable needs to know the types of the requirement and of the concrete function that satisfies the requirement. So I went through and updated all the places where it calculates types to be aware of the logic necessary to calculate types of "autodiff associated functions".
- When generating "Protocol Witnesses", SILGenWitnessTable needs a reference to the concrete underlying function. This concrete function might not exist yet, if the AutoDiff pass is responsible for generating it. So I added an "autodiff_function" instruction that references the possibly-not-yet-generated concrete underlying function. The AutoDiff pass later resolves this reference and replaces it with a reference to a concrete function.
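The last point, a thunk that refers to a function which does not exist yet, can be pictured with a late-bound slot. This is only an analogy in plain Swift and not how the autodiff_function instruction is implemented. ToyFunctionSlot and the two "stages" are hypothetical.

```swift
// A reference to a function body that may be filled in later.
final class ToyFunctionSlot {
    private var body: ((Float) -> Float)?

    var isResolved: Bool { return body != nil }

    func resolve(_ function: @escaping (Float) -> Float) {
        body = function
    }

    func call(_ x: Float) -> Float {
        guard let body = body else {
            fatalError("slot used before it was resolved")
        }
        return body(x)
    }
}

// "SILGen" stage: the thunk is built while the derivative does not exist yet.
let slot = ToyFunctionSlot()
let thunk: (Float) -> Float = { x in slot.call(x) }
let resolvedBeforePass = slot.isResolved

// "Differentiation pass" stage: the derivative of x + 1 is generated
// and the placeholder is resolved to it.
slot.resolve({ _ in 1 })

let derivative = thunk(5)
```

The thunk can be emitted and referenced from a table early, as long as the slot is resolved before anything calls it.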
Like witness tables, vtables exist in a lot of places. Unlike witness tables, I have not thoroughly investigated all of them. I will describe a few interesting ones.
This is the record containing the actual VTable that the program uses at runtime when dispatching class method calls. It is documented here. As of 5/21/19, the documentation looks mostly up-to-date to me.
The mangled name for this ends in "CMf". ("CMf" is actually the full type metadata, rather than the type metadata. The difference is that the pointer to the type metadata is the pointer to the full type metadata offset by +2 pointer-sized words, i.e. past the first two entries in the example below.)
Example "example.swift":
class Superclass {
func hello() {}
}
class Subclass : Superclass {
override func hello() {}
}
"bin/swiftc -emit-ir example.swift":
@"$s7example8SubclassCMf" = /* type elided */
<{
void (%T7example8SubclassC*)* @"$s7example8SubclassCfD",
i8** @"$sBoWV",
// The Kind field.
i64 0,
/* very long entry elided */,
%swift.opaque* null,
%swift.opaque* null,
i64 1,
i32 2,
i32 0,
i32 16,
i16 7,
i16 0,
i32 112,
i32 16,
/* very long entry elided */,
i8* null,
// These last two entries are the VTable: hello() and init().
void (%T7example8SubclassC*)* @"$s7example8SubclassC5helloyyF",
%T7example8SubclassC* (%swift.type*)* @"$s7example8SubclassCACycfC"
}>
I have no idea what this one is for. It appears to contain a bunch of data about the nominal type, plus a VTable, plus some "override metadata". I don't think this is used to dispatch methods at runtime. Maybe it's used for reflection?
Documented here. As of 5/21/19, it says that it's very out of date, and I can confirm that the documentation looks completely different from what I'm seeing in the IR.
"bin/swiftc -emit-sil example.swift":
sil_vtable Superclass {
#Superclass.hello!1: (Superclass) -> () -> () : @$s7example10SuperclassC5helloyyF // Superclass.hello()
#Superclass.init!allocator.1: (Superclass.Type) -> () -> Superclass : @$s7example10SuperclassCACycfC // Superclass.__allocating_init()
#Superclass.deinit!deallocator.1: @$s7example10SuperclassCfD // Superclass.__deallocating_deinit
}
sil_vtable Subclass {
#Superclass.hello!1: (Superclass) -> () -> () : @$s7example8SubclassC5helloyyF [override] // Subclass.hello()
#Superclass.init!allocator.1: (Superclass.Type) -> () -> Superclass : @$s7example8SubclassCACycfC [override] // Subclass.__allocating_init()
#Subclass.deinit!deallocator.1: @$s7example8SubclassCfD // Subclass.__deallocating_deinit
}
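The two tables above can be modeled like this: the subclass table starts as a copy of the superclass table, and an override replaces the implementation in the same slot while the slot stays keyed by the superclass decl. This is a toy sketch; ToyVTable and its methods are hypothetical, and strings stand in for SILDeclRefs and SILFunctions.

```swift
struct ToyVTable {
    var slots: [(method: String, implementation: String)] = []

    mutating func add(_ method: String, _ implementation: String) {
        slots.append((method: method, implementation: implementation))
    }

    // Replaces the implementation in an existing slot; the slot index is unchanged.
    mutating func override(_ method: String, with implementation: String) {
        guard let i = slots.firstIndex(where: { $0.method == method }) else { return }
        slots[i].implementation = implementation
    }

    func slotIndex(of method: String) -> Int? {
        return slots.firstIndex(where: { $0.method == method })
    }

    func implementation(of method: String) -> String? {
        return slots.first(where: { $0.method == method })?.implementation
    }
}

var superclassTable = ToyVTable()
superclassTable.add("Superclass.hello", "Superclass.hello()")
superclassTable.add("Superclass.init", "Superclass.__allocating_init()")

// The subclass inherits the layout, then overrides entries in place.
var subclassTable = superclassTable
subclassTable.override("Superclass.hello", with: "Subclass.hello()")
subclassTable.override("Superclass.init", with: "Subclass.__allocating_init()")

let helloSlotInSuper = superclassTable.slotIndex(of: "Superclass.hello")
let helloSlotInSub = subclassTable.slotIndex(of: "Superclass.hello")
let helloInSub = subclassTable.implementation(of: "Superclass.hello")
```

A caller that only knows the superclass slot still reaches the subclass implementation, because the slot index is shared.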
Here's a nice comment documenting VTable thunks. tldr: Only some vtable entries need thunks. Other vtable entries point at the original function.
Just like with witness tables, there is a visitor. Subclasses implement addMethod(SILDeclRef)
to accumulate methods into their data.
There is also an addMethodOverride
method. It appears to be used to accumulate some override data for the "nominal type descriptor". I don't know what this data is used for.
Here's a working prototype PR: apple/swift#24975
It does basically the same thing that we did with witness tables (but reusing some of the witness table work):
- Teach the differentiation pass to differentiate class_method instructions by replacing them with instructions containing the SILDeclRef for the appropriate associated function.
- Modify the SILVTableVisitor to visit the associated functions.
- Update the SILVTableVisitor subclasses to deal with any consequences.
While writing the prototype, I discovered one consequence that we need to deal with. VTable entries are not always thunked. When an autodiff associated function VTable entry is not thunked, SILGen looks for the definition of the actual function. But this function does not exist until the Differentiation pass, which happens after SILGen. So things explode. Here are some ways we could deal with this:
- (terrible solution, which I used for the prototype because it was easiest) Require that the user specify the associated functions for class methods, and use those in the VTable.
- Generate empty associated functions during SILGen, and have the Differentiation pass fill those in.
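- Always thunk autodiff associated functions in VTables. Make the thunk body contain an autodiff_function instruction that the Differentiation pass later fills in. Note: this is what is done for protocol witness methods.
- Always generate VJP and JVP thunks in SILGen of the @differentiable attribute so that the AD pass won’t need to care about them. There may be some modifications necessary in TBDGen. (TF-524)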
instruction that the Differentiation pass later fills in. Note: this is what is done for protocol witness methods. - Always generates VJP and JVP thunks in SILGen of @differentiable attribute so that AD pass won’t need to care about them. There may be some modifications necessary in TBDGen. (TF-524)