Skip to content

Instantly share code, notes, and snippets.

@shoaibkamil
Created October 13, 2017 20:47
Show Gist options
  • Save shoaibkamil/a23a6c9804157467552cfac32c8aa087 to your computer and use it in GitHub Desktop.
Save shoaibkamil/a23a6c9804157467552cfac32c8aa087 to your computer and use it in GitHub Desktop.
Metal compute-kernel code generated by Halide
#include <metal_stdlib>
using namespace metal;
namespace {

// Reinterpret a raw 32-bit pattern as a float (bit cast, no numeric conversion).
constexpr float float_from_bits(unsigned int bits) { return as_type<float>(bits); }

// Special float32 constants, each spelled as its IEEE-754 bit pattern.
constexpr float inf_f32()     { return float_from_bits(0x7f800000); }  // +infinity
constexpr float neg_inf_f32() { return float_from_bits(0xff800000); }  // -infinity
constexpr float nan_f32()     { return float_from_bits(0x7fc00000); }  // quiet NaN

// Reciprocal of x, implemented as a plain 1.0f / x division rather than a
// dedicated reciprocal intrinsic.
float fast_inverse_f32(float x) {
    return 1.0f / x;
}

// Map the *_f32 helper names used by generated code onto Metal
// standard-library functions.

// Roots, exponentials, logarithms and powers
#define sqrt_f32 sqrt
#define fast_inverse_sqrt_f32 rsqrt
#define exp_f32 exp
#define log_f32 log
#define pow_f32 pow

// Absolute value and rounding
#define abs_f32 fabs
#define floor_f32 floor
#define ceil_f32 ceil
#define round_f32 round
#define trunc_f32 trunc

// Trigonometric
#define sin_f32 sin
#define cos_f32 cos
#define tan_f32 tan
#define asin_f32 asin
#define acos_f32 acos
#define atan_f32 atan
#define atan2_f32 atan2

// Hyperbolic
#define sinh_f32 sinh
#define cosh_f32 cosh
#define tanh_f32 tanh
#define asinh_f32 asinh
#define acosh_f32 acosh
#define atanh_f32 atanh

// Threadgroup-memory barrier written as a comma expression that evaluates to 0,
// presumably so generated code can use it in expression position.
#define halide_gpu_thread_barrier() (threadgroup_barrier(mem_flags::mem_threadgroup), 0)

}  // namespace
// Address-space qualifier macros. The four per-buffer macros below supply the
// qualifier for the kernel parameter of the same name (_dst, _mask, _output,
// _src) and all expand to device. __address_space___shared expands to
// threadgroup, but the kernel spells `threadgroup` directly for its __shared
// parameter.
#define __address_space___shared threadgroup
// Address spaces for kernel_output_s0_y_y___block_id_y
#define __address_space__dst device
#define __address_space__mask device
#define __address_space__output device
#define __address_space__src device
// Scalar parameters of kernel_output_s0_y_y___block_id_y, read through the
// _scalar_args pointer bound at [[ buffer(0) ]]. The field order fixes the byte
// layout of that buffer, so it must match the host-side encoder (presumably
// Halide's Metal runtime -- confirm before reordering or editing fields).
// For each buffer, the kernel converts a (x, y) coordinate to a flat index by
// subtracting its *_min_0 / *_min_1 origin, using *_stride_1 as the row pitch.
// _output_extent_0 / _output_extent_1 give the output region size used to
// place the (shifted-inward) 8-wide / 8-tall tiles.
struct kernel_output_s0_y_y___block_id_y_args {
int _dst_min_0;
int _dst_min_1;
int _dst_stride_1;
int _mask_min_0;
int _mask_min_1;
int _mask_stride_1;
int _output_extent_0;
int _output_extent_1;
int _output_min_0;
int _output_min_1;
int _output_stride_1;
int _src_min_0;
int _src_min_1;
int _src_stride_1;
};
namespace {

// Value written to one colour byte of one output pixel.
//   s: source byte, d: destination byte, m: mask byte for the pixel.
// Returns m when s > 128 and d < t, where t = (s - 128) + ((s - 128) >> 6),
// computed in the same uchar/short widths as the generated code; otherwise
// returns d.
uchar output_s0_pick_channel(uchar s, uchar d, uchar m)
{
    const uchar bias = (uchar)(128);
    const uchar excess = s - bias;   // wraps when s < 128; result ignored then
    const short wide = short(excess);
    const short boost = wide >> 6;
    const uchar threshold = uchar(short(wide + boost));
    const bool above_bias = bias < s;
    const bool below_threshold = d < threshold;
    return (above_bias && below_threshold) ? m : d;
}

}  // namespace

// Compute kernel for the output stage. Each thread writes bytes 0..2 of one
// output pixel at (x, y); byte 3 of each pixel is not written by this kernel:
//   output[c] = (src[c] > 128 && dst[c] < t(src[c])) ? mask : dst[c]
// src, dst and output are addressed with 4 bytes per pixel; mask with 1 byte
// per pixel.
// Tiling: threadgroup (block_x, block_y) covers 8 columns and 8 rows offset
// from the output min; assumes 8x8 threads per threadgroup -- TODO confirm
// against the host dispatch. Every threadgroup with block_x >= extent_0 / 8
// processes the last 8 columns, and the row start is clamped to
// min_1 + extent_1 - 8, so edge tiles are shifted inward. This needs extents
// >= 8 (presumably checked on the host -- confirm).
kernel void kernel_output_s0_y_y___block_id_y(
uint3 tgroup_index [[ threadgroup_position_in_grid ]],
uint3 tid_in_tgroup [[ thread_position_in_threadgroup ]],
const device kernel_output_s0_y_y___block_id_y_args *_scalar_args [[ buffer(0) ]],
__address_space__dst const uchar *_dst [[ buffer(1) ]],
__address_space__mask const uchar *_mask [[ buffer(2) ]],
__address_space__output uchar *_output [[ buffer(3) ]],
__address_space__src const uchar *_src [[ buffer(4) ]],
threadgroup int16_t* __shared [[ threadgroup(0) ]])
{
    // Snapshot every scalar argument before any buffer is written.
    const int src_min_x    = _scalar_args->_src_min_0;
    const int src_min_y    = _scalar_args->_src_min_1;
    const int src_stride   = _scalar_args->_src_stride_1;
    const int dst_min_x    = _scalar_args->_dst_min_0;
    const int dst_min_y    = _scalar_args->_dst_min_1;
    const int dst_stride   = _scalar_args->_dst_stride_1;
    const int mask_min_x   = _scalar_args->_mask_min_0;
    const int mask_min_y   = _scalar_args->_mask_min_1;
    const int mask_stride  = _scalar_args->_mask_stride_1;
    const int out_min_x    = _scalar_args->_output_min_0;
    const int out_min_y    = _scalar_args->_output_min_1;
    const int out_extent_x = _scalar_args->_output_extent_0;
    const int out_extent_y = _scalar_args->_output_extent_1;
    const int out_stride   = _scalar_args->_output_stride_1;

    const int block_x = int(tgroup_index.x);
    const int block_y = int(tgroup_index.y);
    const int lane_x  = int(tid_in_tgroup.x);
    const int lane_y  = int(tid_in_tgroup.y);

    // Row: 8-row tile, start clamped so the tile ends at the bottom edge.
    const int tile_y = min(block_y * 8 + out_min_y, out_min_y + out_extent_y - 8);
    const int y = tile_y + lane_y;

    // Column: full 8-column tiles first; any remaining threadgroups handle the
    // last 8 columns, ending exactly at the right edge.
    const int full_tiles_x = max(out_extent_x >> 3, 0);
    const int tile_x = (block_x < full_tiles_x)
        ? block_x * 8 + out_min_x
        : out_min_x + out_extent_x - 8;
    const int x = tile_x + lane_x;

    // Flat index of the pixel's first byte in each buffer, relative to that
    // buffer's origin.
    const int src_px  = x * 4 + y * src_stride - (src_min_x * 4 + src_min_y * src_stride);
    const int dst_px  = x * 4 + y * dst_stride - (dst_min_x * 4 + dst_min_y * dst_stride);
    const int out_px  = x * 4 + y * out_stride - (out_min_x * 4 + out_min_y * out_stride);
    const int mask_px = x + y * mask_stride - (mask_min_x + mask_min_y * mask_stride);

    // Load, compute and store one channel at a time (same order as the
    // unrolled generated code, which matters if buffers alias).
    for (int c = 0; c < 3; ++c) {
        const uchar s = _src[src_px + c];
        const uchar d = _dst[dst_px + c];
        const uchar m = _mask[mask_px];
        _output[out_px + c] = output_s0_pick_channel(s, d, m);
    }
}
// Remove the per-buffer address-space macros defined before the kernel so they
// do not apply past this point. Note: __address_space___shared (also defined
// before the kernel) is not #undef'd here.
#undef __address_space__dst
#undef __address_space__mask
#undef __address_space__output
#undef __address_space__src
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment