Skip to content

Instantly share code, notes, and snippets.

@masahi
Last active April 11, 2020 04:15
Show Gist options
  • Save masahi/ea860e829e48d53e804b7b8544953be5 to your computer and use it in GitHub Desktop.
Save masahi/ea860e829e48d53e804b7b8544953be5 to your computer and use it in GitHub Desktop.
/* Two-direction LSTM-style recurrence with layer normalization: one loop walks %input forward
   in time, a second loop walks it in reverse. This looks like a single bidirectional
   layer-norm LSTM layer -- confirm against the model this IR was generated from.
   Parameters used by the forward-time loop (%while_loop); the reversed-time loop
   (%while_loop1) uses the parallel set shown in parentheses:
     %v45 (%v119): Tensor[(16, 3)] weight for nn.dense on the per-step input slice
     %v51, %v52 (%v125, %v126): nn.layer_norm parameters for that input projection
     %v57 (%v131): Tensor[(16, 4)] weight for nn.dense on state.0
     %v63, %v64 (%v137, %v138): nn.layer_norm parameters for that state.0 projection
     %v85, %v86 (%v159, %v160): nn.layer_norm parameters for the new cell value
   %states: one (state.0, state.1) tuple per direction. Index 0 seeds the forward loop and
     index 1 seeds the reversed loop. Inside each step, state.0 is the previous step output
     and state.1 is the cell value.
   %input: Tensor[(5, 2, 3)]. Each loop takes one (2, 3) slice along axis 0 per iteration
     and stops after 5 iterations; the bound 5 is hard-coded in both loops.
   Returns (outputs of both directions combined by tensor_array_concat_last,
            [final state of the forward loop, final state of the reversed loop]). */
fn (%v45: Tensor[(16, 3), float32], %v51: Tensor[(16), float32], %v52: Tensor[(16), float32], %v57: Tensor[(16, 4), float32], %v63: Tensor[(16), float32], %v64: Tensor[(16), float32], %v85: Tensor[(4), float32], %v86: Tensor[(4), float32], %states: List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])], %input: Tensor[(5, 2, 3), float32], %v119: Tensor[(16, 3), float32], %v125: Tensor[(16), float32], %v126: Tensor[(16), float32], %v131: Tensor[(16, 4), float32], %v137: Tensor[(16), float32], %v138: Tensor[(16), float32], %v159: Tensor[(4), float32], %v160: Tensor[(4), float32]) -> (Tensor[(?, 2, ?), float32], List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])]) {
/* Collects the stacked output tensor of each direction (forward first, then reversed). */
%0 = Nil /* ty=List[Tensor[(?, 2, ?), float32]] */;
/* Empty per-step output list passed into the forward loop. */
%1 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
/* Initial (state.0, state.1) for the forward direction. */
%2 = @nth(%states, 0 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
/* Forward-time loop, written as a let-bound function that calls itself.
   Carried values: (i, outputs, state, input).
   While i < 5, it processes time step i, then calls itself with i + 1.
   Otherwise it returns the carried tuple unchanged.
   %38 is that function value; it is invoked at %39. */
%38 = (
let %while_loop: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) = fn (%i.4: int32, %outputs.9: List[Tensor[(2, 4), float32]], %state.7: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %input.1: Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) {
/* Loop condition: the trip count 5 is hard-coded and matches axis 0 of %input. */
%3 = less(%i.4, 5 /* ty=int32 */) /* ty=bool */;
if (%3) {
%4 = add(%i.4, 1 /* ty=int32 */) /* ty=int32 */;
/* Current step input: slice i along axis 0 of the input. */
%5 = take(%input.1, %i.4, axis=0) /* ty=Tensor[(2, 3), float32] */;
/* NOTE(review): the two transposes with axes=[1, 0] cancel, so %7 equals %v45.
   The pair is redundant and could be folded so that nn.dense uses %v45 directly. */
%6 = transpose(%v45, axes=[1, 0]) /* ty=Tensor[(3, 16), float32] */;
%7 = transpose(%6, axes=[1, 0]) /* ty=Tensor[(16, 3), float32] */;
/* Input projection to 16 gate pre-activations, then layer_norm with %v51/%v52. */
%8 = nn.dense(%5, %7, units=None) /* ty=Tensor[(2, 16), float32] */;
%9 = nn.layer_norm(%8, %v51, %v52) /* ty=Tensor[(2, 16), float32] */;
/* Projection of state.0 (previous step output). The transpose pair again cancels to %v57.
   The result is then layer-normalized with %v63/%v64. */
%10 = %state.7.0;
%11 = transpose(%v57, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%12 = transpose(%11, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%13 = nn.dense(%10, %12, units=None) /* ty=Tensor[(2, 16), float32] */;
%14 = nn.layer_norm(%13, %v63, %v64) /* ty=Tensor[(2, 16), float32] */;
/* Combined gate pre-activations; sliced below into four (2, 4) column chunks. */
%15 = add(%9, %14) /* ty=Tensor[(2, 16), float32] */;
/* Columns 12:16 -> sigmoid. This later scales tanh(new cell) to form the output (output gate). */
%16 = strided_slice(%15, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%17 = sigmoid(%16) /* ty=Tensor[(2, 4), float32] */;
/* Columns 4:8 -> sigmoid, which scales the previous cell value state.1 (forget gate). */
%18 = strided_slice(%15, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%19 = sigmoid(%18) /* ty=Tensor[(2, 4), float32] */;
%20 = %state.7.1;
%21 = multiply(%19, %20) /* ty=Tensor[(2, 4), float32] */;
/* Columns 0:4 -> sigmoid (input gate), multiplied by tanh of columns 8:12 (candidate cell). */
%22 = strided_slice(%15, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%23 = sigmoid(%22) /* ty=Tensor[(2, 4), float32] */;
%24 = strided_slice(%15, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%25 = tanh(%24) /* ty=Tensor[(2, 4), float32] */;
%26 = multiply(%23, %25) /* ty=Tensor[(2, 4), float32] */;
/* New cell = forget term + input term, then layer_norm with %v85/%v86.
   The normalized value %28, not the raw sum %27, is what gets carried as state.1. */
%27 = add(%21, %26) /* ty=Tensor[(2, 4), float32] */;
%28 = nn.layer_norm(%27, %v85, %v86) /* ty=Tensor[(2, 4), float32] */;
/* Step output = output gate times tanh(normalized cell). */
%29 = tanh(%28) /* ty=Tensor[(2, 4), float32] */;
%30 = multiply(%17, %29) /* ty=Tensor[(2, 4), float32] */;
/* %31 is the new (state.0, state.1); %32 pairs the step output with that state. */
%31 = (%30, %28);
%32 = (%30, %31);
/* Append the step output to the END of outputs (outputs ++ [h]).
   The list therefore ends up in time order 0..4. */
%33 = %32.0;
%34 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%35 = Cons(%33, %34) /* ty=List[Tensor[(2, 4), float32]] */;
%36 = @concat(%outputs.9, %35) /* ty=List[Tensor[(2, 4), float32]] */;
%37 = %32.1;
/* Next iteration: recursive call with i + 1, the extended outputs and the new state. */
%while_loop(%4, %36, %37, %input.1) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */
} else {
/* Loop exit (i >= 5): return the carried values unchanged. */
(%i.4, %outputs.9, %state.7, %input.1)
}
};
%while_loop
);
/* Run the forward loop from i = 0, with an empty output list and the direction-0 state. */
%39 = %38(0 /* ty=int32 */, %1, %2, %input) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */;
/* Per-step outputs: five (2, 4) tensors in time order, one appended per iteration.
   They are wrapped as tensor-array values and then stacked. */
%40 = %39.1;
%41 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %40) /* ty=List[static_tensor_float32_2_4_t[]] */;
%42 = @tensor_array_stack_float32_2_4(%41) /* ty=static_tensor_float32_?_2_4_t[] */;
/* NOTE(review): %42 is annotated static_tensor_float32_?_2_4_t, but %43 is annotated
   Tensor[(?, 2, ?)], so the trailing dimension 4 is not preserved in the annotation.
   Confirm whether this loss of static shape is intended. */
%43 = @tensor_get_data_float32_2_4(%42) /* ty=Tensor[(?, 2, ?), float32] */;
/* Final (state.0, state.1) of the forward loop; returned later via %45.1. */
%44 = %39.2;
%45 = (%43, %44);
/* Append the forward-direction output tensor to the per-direction list %0.
   NOTE(review): %49 is annotated (?, 2, 4), while its operands are annotated (?, 2, ?),
   an inconsistency in the printed types. */
%46 = %45.0;
%47 = Nil /* ty=List[Tensor[(?, 2, ?), float32]] */;
%48 = Cons(%46, %47) /* ty=List[Tensor[(?, 2, ?), float32]] */;
%49 = @concat(%0, %48) /* ty=List[Tensor[(?, 2, 4), float32]] */;
/* Empty per-step output list and direction-1 initial state for the reversed-time loop. */
%50 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%51 = @nth(%states, 1 /* ty=int32 */) /* ty=(Tensor[(2, 4), float32], Tensor[(2, 4), float32]) */;
/* Reversed-time loop. It has the same step computation as %while_loop, but uses the second
   parameter set (%v119..%v160). It reads input slices from the last index to the first and
   prepends each step output, so the final list is still in original time order.
   Carried values: (i, outputs, state, input). %89 is the function value; it is invoked at %90. */
%89 = (
let %while_loop1: fn (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) = fn (%i.1: int32, %outputs.6: List[Tensor[(2, 4), float32]], %state.6: (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), %input.11: Tensor[(5, 2, 3), float32]) -> (int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) {
/* Loop condition: the trip count 5 is hard-coded and matches axis 0 of %input. */
%52 = less(%i.1, 5 /* ty=int32 */) /* ty=bool */;
if (%52) {
%53 = add(%i.1, 1 /* ty=int32 */) /* ty=int32 */;
/* Reversed time index: 5 - i - 1, i.e. 4, 3, 2, 1, 0 as i goes 0..4. */
%54 = subtract(5 /* ty=int32 */, %i.1) /* ty=int32 */;
%55 = subtract(%54, 1 /* ty=int32 */) /* ty=int32 */;
%56 = take(%input.11, %55, axis=0) /* ty=Tensor[(2, 3), float32] */;
/* NOTE(review): the two transposes with axes=[1, 0] cancel, so %58 equals %v119 (redundant). */
%57 = transpose(%v119, axes=[1, 0]) /* ty=Tensor[(3, 16), float32] */;
%58 = transpose(%57, axes=[1, 0]) /* ty=Tensor[(16, 3), float32] */;
/* Input projection to 16 gate pre-activations, then layer_norm with %v125/%v126. */
%59 = nn.dense(%56, %58, units=None) /* ty=Tensor[(2, 16), float32] */;
%60 = nn.layer_norm(%59, %v125, %v126) /* ty=Tensor[(2, 16), float32] */;
/* Projection of state.0 (previous step output). The transpose pair cancels to %v131.
   The result is then layer-normalized with %v137/%v138. */
%61 = %state.6.0;
%62 = transpose(%v131, axes=[1, 0]) /* ty=Tensor[(4, 16), float32] */;
%63 = transpose(%62, axes=[1, 0]) /* ty=Tensor[(16, 4), float32] */;
%64 = nn.dense(%61, %63, units=None) /* ty=Tensor[(2, 16), float32] */;
%65 = nn.layer_norm(%64, %v137, %v138) /* ty=Tensor[(2, 16), float32] */;
/* Combined gate pre-activations; sliced below into four (2, 4) column chunks. */
%66 = add(%60, %65) /* ty=Tensor[(2, 16), float32] */;
/* Columns 12:16 -> sigmoid (output gate). */
%67 = strided_slice(%66, begin=[0, 12], end=[2, 16], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%68 = sigmoid(%67) /* ty=Tensor[(2, 4), float32] */;
/* Columns 4:8 -> sigmoid, which scales the previous cell value state.1 (forget gate). */
%69 = strided_slice(%66, begin=[0, 4], end=[2, 8], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%70 = sigmoid(%69) /* ty=Tensor[(2, 4), float32] */;
%71 = %state.6.1;
%72 = multiply(%70, %71) /* ty=Tensor[(2, 4), float32] */;
/* Columns 0:4 -> sigmoid (input gate), multiplied by tanh of columns 8:12 (candidate cell). */
%73 = strided_slice(%66, begin=[0, 0], end=[2, 4], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%74 = sigmoid(%73) /* ty=Tensor[(2, 4), float32] */;
%75 = strided_slice(%66, begin=[0, 8], end=[2, 12], strides=[1, 1]) /* ty=Tensor[(2, 4), float32] */;
%76 = tanh(%75) /* ty=Tensor[(2, 4), float32] */;
%77 = multiply(%74, %76) /* ty=Tensor[(2, 4), float32] */;
/* New cell = forget term + input term, then layer_norm with %v159/%v160.
   The normalized value %79 is what gets carried as state.1. */
%78 = add(%72, %77) /* ty=Tensor[(2, 4), float32] */;
%79 = nn.layer_norm(%78, %v159, %v160) /* ty=Tensor[(2, 4), float32] */;
/* Step output = output gate times tanh(normalized cell). */
%80 = tanh(%79) /* ty=Tensor[(2, 4), float32] */;
%81 = multiply(%68, %80) /* ty=Tensor[(2, 4), float32] */;
/* %82 is the new (state.0, state.1); %83 pairs the step output with that state. */
%82 = (%81, %79);
%83 = (%81, %82);
/* PREPEND the step output ([h] ++ outputs). Because steps run from time 4 down to 0,
   the finished list is ordered by original time index 0..4. */
%84 = %83.0;
%85 = Nil /* ty=List[Tensor[(2, 4), float32]] */;
%86 = Cons(%84, %85) /* ty=List[Tensor[(2, 4), float32]] */;
%87 = @concat(%86, %outputs.6) /* ty=List[Tensor[(2, 4), float32]] */;
%88 = %83.1;
/* Next iteration: recursive call with i + 1, the extended outputs and the new state. */
%while_loop1(%53, %87, %88, %input.11) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */
} else {
/* Loop exit (i >= 5): return the carried values unchanged. */
(%i.1, %outputs.6, %state.6, %input.11)
}
};
%while_loop1
);
/* Run the reversed-time loop from i = 0, with an empty output list and the direction-1 state. */
%90 = %89(0 /* ty=int32 */, %50, %51, %input) /* ty=(int32, List[Tensor[(2, 4), float32]], (Tensor[(2, 4), float32], Tensor[(2, 4), float32]), Tensor[(5, 2, 3), float32]) */;
/* Wrap the five (2, 4) step outputs as tensor-array values, then stack them. */
%91 = %90.1;
%92 = @map(tensor_constructor_float32_2_4(Tensor[(2, 4), float32]), %91) /* ty=List[static_tensor_float32_2_4_t[]] */;
%93 = @tensor_array_stack_float32_2_4(%92) /* ty=static_tensor_float32_?_2_4_t[] */;
/* NOTE(review): as with %43, the (?, 2, ?) annotation drops the static trailing dimension 4. */
%94 = @tensor_get_data_float32_2_4(%93) /* ty=Tensor[(?, 2, ?), float32] */;
/* Final (state.0, state.1) of the reversed loop; returned later via %96.1. */
%95 = %90.2;
%96 = (%94, %95);
/* Append the reversed-direction output, giving the list [forward output, reversed output].
   NOTE(review): %99 is annotated (?, 2, 4) although %97 and %98 are annotated (?, 2, ?). */
%97 = %96.0;
%98 = Nil /* ty=List[Tensor[(?, 2, ?), float32]] */;
%99 = Cons(%97, %98) /* ty=List[Tensor[(?, 2, 4), float32]] */;
%100 = @concat(%49, %99) /* ty=List[Tensor[(?, 2, ?), float32]] */;
/* Combine the two direction outputs [forward, reversed] into one tensor.
   NOTE(review): tensor_array_concat_last is not defined in this view. Its name suggests
   concatenation along the last axis, which would make the result (5, 2, 8), but the
   annotation only says (?, 2, ?) -- confirm. */
%101 = @map(tensor_constructor_float32_?_2_?(Tensor[(?, 2, ?), float32]), %100) /* ty=List[static_tensor_float32_?_2_?_t[]] */;
%102 = @tensor_array_concat_last_float32_?_2_?(%101) /* ty=static_tensor_float32_?_2_?_t[] */;
%103 = @tensor_get_data_float32_?_2_?(%102) /* ty=Tensor[(?, 2, ?), float32] */;
/* Build the final-state list [forward final state, reversed final state].
   Concatenating onto the empty %104 is a no-op, so %108 is just [%105]. */
%104 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%105 = %45.1;
%106 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%107 = Cons(%105, %106) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%108 = @concat(%104, %107) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%109 = %96.1;
%110 = Nil /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%111 = Cons(%109, %110) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
%112 = @concat(%108, %111) /* ty=List[(Tensor[(2, 4), float32], Tensor[(2, 4), float32])] */;
/* Result: (combined outputs of both directions, per-direction final states). */
(%103, %112)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment