Last active
March 11, 2024 06:31
-
-
Save kato-megumi/4c85f695e425ea50db72e38e63bef9de to your computer and use it in GitHub Desktop.
Convert a PyTorch .pth checkpoint into a Magpie HLSL effect.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import math
import os
import re

import torch

parser = argparse.ArgumentParser()
parser.add_argument("pth", help="Path to the input pth file")
parser.add_argument("out", help="Path to the output folder")
parser.add_argument("name", help="Name of hlsl file")
args = parser.parse_args()

# map_location="cpu": only the weight values are read, so a checkpoint saved
# from a GPU must still load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file -- only convert trusted checkpoints.
params = torch.load(args.pth, map_location="cpu")["params"]
out_hlsl = os.path.join(args.out, f"{args.name}.hlsl")

# Every conv_mid layer has a ".weight" and a ".bias" entry; +1 for conv_head,
# which gets its own pass and texture set (conv2d_1_*).
num_conv = len([k for k in params if "conv_mid" in k]) // 2 + 1
# Output channels of conv_head (= feature width of every intermediate layer).
num_feat = len(params["conv_head.bias"])
# conv_tail consumes positive and negative halves of the last `block_stack`
# feature maps, i.e. 2 * num_feat * block_stack input channels.
block_stack = params["conv_tail.weight"].shape[1] // (2 * num_feat)
if num_feat % 4:
    raise SystemExit(f"num_feat={num_feat} must be a multiple of 4 (4 channels per texture)")
if not 1 <= block_stack <= num_conv:
    # The last pass reads conv2d_{num_conv+1-block_stack} .. conv2d_{num_conv}.
    raise SystemExit(f"block_stack={block_stack} must be between 1 and num_conv={num_conv}")
# Feature maps are packed 4 channels per RGBA16F texture.
num_text = num_feat // 4
# Declaration of one intermediate feature-map texture conv2d_{a}_{b}
# (layer a, channel group b), 4 channels per texture at input resolution.
tex_define = """
//!TEXTURE
//!WIDTH INPUT_WIDTH
//!HEIGHT INPUT_HEIGHT
//!FORMAT R16G16B16A16_FLOAT
Texture2D conv2d_{a}_{b};
"""
# Middle conv pass: 64 threads per 8x8 block, each thread handles the single
# pixel gxy. {calculation} is filled with get_pixel / cal_weight / cal_bias.
# Literal braces are doubled because the template goes through str.format.
pass_define = '''
//!PASS {pass_no}
//!DESC Conv-{pass_no}
//!IN {in_tex}
//!OUT {out_tex}
//!BLOCK_SIZE 8
//!NUM_THREADS 64
void Pass{pass_no}(uint2 blockStart, uint3 threadId) {{
uint2 gxy = Rmp8x8(threadId.x) + blockStart;
uint2 inputSize = GetInputSize();
if (gxy.x >= inputSize.x || gxy.y >= inputSize.y) {{
return;
}}
float2 inputPt = GetInputPt();
float2 pos = (gxy + 0.5f) * inputPt;
{calculation}
}}
'''
# First pass: 64 threads per 16x16 block, each thread gathers a 4x4 RGB
# neighbourhood into src[][] and computes a 2x2 group of output pixels
# (destPos), running {calculation} (cal_1st snippets) once per pixel.
pass_1st = '''
//!PASS 1
//!DESC First Pass
//!IN INPUT
//!OUT {out_tex}
//!BLOCK_SIZE 16
//!NUM_THREADS 64
void Pass1(uint2 blockStart, uint3 threadId) {{
uint2 gxy = (Rmp8x8(threadId.x) << 1) + blockStart;
uint2 inputSize = GetInputSize();
if (gxy.x >= inputSize.x || gxy.y >= inputSize.y) {{
return;
}}
float2 inputPt = GetInputPt();
uint i, j;
float3 src[4][4];
[unroll]
for (i = 0; i <= 2; i += 2) {{
[unroll]
for (j = 0; j <= 2; j += 2) {{
float2 tpos = (gxy + uint2(i, j)) * inputPt;
const float4 sr = INPUT.GatherRed(sam, tpos);
const float4 sg = INPUT.GatherGreen(sam, tpos);
const float4 sb = INPUT.GatherBlue(sam, tpos);
// w z
// x y
src[i][j] = float3(sr.w, sg.w, sb.w);
src[i][j + 1] = float3(sr.x, sg.x, sb.x);
src[i + 1][j] = float3(sr.z, sg.z, sb.z);
src[i + 1][j + 1] = float3(sr.y, sg.y, sb.y);
}}
}}
[unroll]
for (i = 1; i <= 2; ++i) {{
[unroll]
for (j = 1; j <= 2; ++j) {{
uint2 destPos = gxy + uint2(i - 1, j - 1);
if (i != 1 || j != 1) {{
if (destPos.x >= inputSize.x || destPos.y >= inputSize.y) {{
continue;
}}
}}
{calculation}
}}
}}
}}
'''
# Sample the 3x3 neighbourhood a{b}..i{b} of texture conv2d_{a}_{b}
# (x outer, y inner), then split each sample into its negative part n*{b}
# and its positive part (overwriting a{b}..i{b}) -- a CReLU activation.
get_pixel = '''
float4 a{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(-inputPt.x, -inputPt.y), 0);
float4 b{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(-inputPt.x, 0), 0);
float4 c{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(-inputPt.x, inputPt.y), 0);
float4 d{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(0, -inputPt.y), 0);
float4 e{b} = conv2d_{a}_{b}.SampleLevel(sam, pos, 0);
float4 f{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(0, inputPt.y), 0);
float4 g{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(inputPt.x, -inputPt.y), 0);
float4 h{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(inputPt.x, 0), 0);
float4 i{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(inputPt.x, inputPt.y), 0);
float4 na{b} = max(-a{b}, 0);
float4 nb{b} = max(-b{b}, 0);
float4 nc{b} = max(-c{b}, 0);
float4 nd{b} = max(-d{b}, 0);
float4 ne{b} = max(-e{b}, 0);
float4 nf{b} = max(-f{b}, 0);
float4 ng{b} = max(-g{b}, 0);
float4 nh{b} = max(-h{b}, 0);
float4 ni{b} = max(-i{b}, 0);
a{b} = max(a{b}, 0);
b{b} = max(b{b}, 0);
c{b} = max(c{b}, 0);
d{b} = max(d{b}, 0);
e{b} = max(e{b}, 0);
f{b} = max(f{b}, 0);
g{b} = max(g{b}, 0);
h{b} = max(h{b}, 0);
i{b} = max(i{b}, 0);
'''
def cal_weight(j, b, define=False, first=False, negative=False):
    """Return HLSL that accumulates one 3x3 convolution over input group ``b``.

    Emits nine ``mul(<tap>, float4x4(...))`` statements, one per neighbour
    sample ``a{b}`` .. ``i{b}`` declared by ``get_pixel``, in that order.
    Each matrix holds 16 ``0.0000000000`` placeholders that are later replaced
    by real weights, so the statement order must match ``convert``.

    Args:
        j: Suffix of the accumulator (``target{j}``); "" for middle passes.
        b: Index of the input sample group.
        define: Prefix the first statement with ``float4 `` (declares the accumulator).
        first: Use ``=`` instead of ``+=`` in the first statement.
        negative: Use the negative CReLU halves ``na{b}`` .. ``ni{b}``.

    Returns:
        The snippet, starting and ending with a newline.
    """
    zeros = ", ".join(["0.0000000000"] * 16)
    prefix = "n" if negative else ""
    lines = []
    for idx, tap in enumerate("abcdefghi"):
        if idx == 0:
            lhs = f'{"float4 " if define else ""}target{j} {"" if first else "+"}='
        else:
            lhs = f"target{j} +="
        lines.append(f"{lhs} mul({prefix}{tap}{b}, float4x4({zeros}));")
    return "\n" + "\n".join(lines) + "\n"
# Bias of one middle-pass output group (4 placeholders), then store the
# group into texture conv2d_{a}_{b}.
cal_bias = '''
target += float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000);
conv2d_{a}_{b}[gxy] = target;
'''
# First-pass conv for output group {b}: nine RGB taps (x outer, y inner), each
# a float3x4 of 12 placeholders (3 input channels x 4 outputs), plus a
# 4-placeholder bias -> 112 placeholders per group. {define} declares target
# for the first group only.
cal_1st = '''
{define}target = mul(src[i - 1][j - 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i - 1][j], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i - 1][j + 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i][j - 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i][j], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i][j + 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i + 1][j - 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i + 1][j], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i + 1][j + 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000);
conv2d_1_{b}[destPos] = target;
'''
# Last pass: {calculation} leaves R/G/B in target1/2/3; the components
# .x/.y/.z/.w are written to the top-left/top-right/bottom-left/bottom-right
# output pixels (depth-to-space), each plus a bilinear (sam1 = LINEAR) sample
# of INPUT as the residual.
last_pass = '''
//!PASS {pass_no}
//!DESC Conv-{pass_no} Depth-to-Space
//!IN INPUT, {in_tex}
//!OUT OUTPUT
//!BLOCK_SIZE 16
//!NUM_THREADS 64
void Pass{pass_no}(uint2 blockStart, uint3 threadId) {{
uint2 gxy = (Rmp8x8(threadId.x) << 1) + blockStart;
const uint2 outputSize = GetOutputSize();
if (gxy.x >= outputSize.x || gxy.y >= outputSize.y) {{
return;
}}
float2 inputPt = GetInputPt();
float2 pos = ((gxy >> 1) + 0.5f) * inputPt;
{calculation}
float2 outputPt = GetOutputPt();
pos -= 0.5f * outputPt;
OUTPUT[gxy] = float4(float3(target1.x, target2.x, target3.x) + INPUT.SampleLevel(sam1, pos, 0).rgb, 1);
++gxy.x;
pos.x += outputPt.x;
OUTPUT[gxy] = float4(float3(target1.y, target2.y, target3.y) + INPUT.SampleLevel(sam1, pos, 0).rgb, 1);
++gxy.y;
pos.y += outputPt.y;
OUTPUT[gxy] = float4(float3(target1.w, target2.w, target3.w) + INPUT.SampleLevel(sam1, pos, 0).rgb, 1);
--gxy.x;
pos.x -= outputPt.x;
OUTPUT[gxy] = float4(float3(target1.z, target2.z, target3.z) + INPUT.SampleLevel(sam1, pos, 0).rgb, 1);
}}
'''
def lastpass_weight(j, b, define=False, first=False, negative=False):
    """Return one HLSL statement adding input group ``b``'s centre sample to ``target{j}``.

    The last conv only uses the centre tap ``g{b}`` (or its negative CReLU half
    ``ng{b}``), multiplied by a float4x4 of 16 ``0.0000000000`` placeholders.

    Args:
        j: Accumulator suffix (1/2/3 for R/G/B).
        b: Index of the input sample group.
        define: Declare the accumulator (``float4 `` prefix).
        first: Assign (``=``) instead of accumulate (``+=``).
        negative: Use ``ng{b}`` instead of ``g{b}``.

    Returns:
        The statement, preceded by a newline (no trailing newline).
    """
    zeros = ", ".join(["0.0000000000"] * 16)
    decl = "float4 " if define else ""
    op = "=" if first else "+="
    sample = f"{'n' if negative else ''}g{b}"
    return f"\n{decl}target{j} {op} mul({sample}, float4x4({zeros}));"
###############################################################
# Assemble the HLSL effect: header, texture/sampler declarations and one pass
# per conv layer. Every weight/bias appears as a 0.0000000000 placeholder that
# is replaced, in emission order, by the values collected with convert().
hlsl = '''//!MAGPIE EFFECT
//!VERSION 4
//!SORT_NAME anime4k
//!TEXTURE
Texture2D INPUT;
//!TEXTURE
//!WIDTH INPUT_WIDTH * 2
//!HEIGHT INPUT_HEIGHT * 2
Texture2D OUTPUT;
'''
# Feature maps written by conv_head (layer 1) and each conv_mid layer.
for layer in range(1, num_conv + 1):
    for group in range(num_text):
        hlsl += tex_define.format(a=layer, b=group)
hlsl += '''
//!SAMPLER
//!FILTER POINT
SamplerState sam;
//!SAMPLER
//!FILTER LINEAR
SamplerState sam1;
'''
### First pass (conv_head): RGB input -> conv2d_1_*
out_tex = ", ".join(f"conv2d_1_{b}" for b in range(num_text))
calculation = "".join(
    cal_1st.format(b=b, define="float4 " if b == 0 else "") for b in range(num_text)
)
hlsl += pass_1st.format(out_tex=out_tex, calculation=calculation)
### Middle passes (conv_mid.*): conv2d_{n-1}_* -> conv2d_{n}_*
for pass_no in range(2, num_conv + 1):
    in_tex = ", ".join(f"conv2d_{pass_no - 1}_{b}" for b in range(num_text))
    out_tex = ", ".join(f"conv2d_{pass_no}_{b}" for b in range(num_text))
    calculation = "".join(get_pixel.format(a=pass_no - 1, b=b) for b in range(num_text))
    for i in range(num_text):
        # Output group i: positive halves of every input group, then negative
        # halves. NOTE(review): assumes the model feeds cat([relu(x), relu(-x)])
        # in that order -- confirm against the training code.
        for b in range(num_text):
            calculation += cal_weight(j="", b=b, define=(i == 0 and b == 0), first=(b == 0))
        for b in range(num_text):
            calculation += cal_weight(j="", b=b, negative=True)
        calculation += cal_bias.format(a=pass_no, b=i)
    hlsl += pass_define.format(pass_no=pass_no, in_tex=in_tex, out_tex=out_tex, calculation=calculation)
#### Last pass (conv_tail) + depth-to-space
# Set explicitly: the loop variable above is never bound when num_conv == 1.
pass_no = num_conv + 1
# conv_tail reads the outputs of the last `block_stack` layers.
last_in_texture = [
    f"conv2d_{pass_no - block_stack + i}_{b}"
    for i in range(block_stack)
    for b in range(num_text)
]
in_tex = ", ".join(last_in_texture)
num_inputs = len(last_in_texture)
# Only the centre tap (pos) is sampled, then split into negative / positive halves.
calculation = "".join(
    f"\n\tfloat4 g{i} = {tex}.SampleLevel(sam, pos, 0);" for i, tex in enumerate(last_in_texture)
) + "\n"
calculation += "".join(f"\n\tfloat4 ng{i} = max(-g{i}, 0);" for i in range(num_inputs)) + "\n"
calculation += "".join(f"\n\tg{i} = max(g{i}, 0);" for i in range(num_inputs)) + "\n"
for i in range(3):  # target1/2/3 -> R/G/B, each float4 = one 2x2 output block
    for j in range(block_stack):
        for k in range(num_text):
            b = j * num_text + k
            calculation += lastpass_weight(j=i + 1, b=b, define=(b == 0), first=(b == 0))
        for k in range(num_text):
            b = j * num_text + k
            calculation += lastpass_weight(j=i + 1, b=b, negative=True)
    calculation += f'''
target{i+1} += float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000);
'''
hlsl += last_pass.format(pass_no=pass_no, in_tex=in_tex, calculation=calculation)
def convert(weight, bias, data, doswap=False):
    """Append one conv layer's parameters to ``data`` in HLSL placeholder order.

    For each group of 4 output channels: for each group of (up to) 4 input
    channels, every kernel tap (x outer, y inner) as in x out coefficients
    (input-major, matching the row layout of ``mul(float4, float4x4)``);
    then the group's bias values.

    Args:
        weight: Conv weight tensor of shape (out_chan, in_chan, kH, kW).
        bias: Bias tensor of shape (out_chan,).
        data: List that receives the values as Python floats (mutated in place).
        doswap: Reorder output channels within each group as (0, 2, 1, 3).

    Raises:
        ValueError: if a channel count above 4 is not a multiple of 4.
    """
    out_chan, in_chan, k_h, k_w = weight.shape
    for label, chan in (("out_chan", out_chan), ("in_chan", in_chan)):
        if chan > 4 and chan % 4:
            raise ValueError(f"{label}={chan} must be <= 4 or a multiple of 4")
    order = [0, 2, 1, 3] if doswap else [0, 1, 2, 3]
    n_out = min(4, out_chan)
    n_in = min(4, in_chan)
    # One bulk conversion instead of a tensor indexing call per element.
    w = weight.detach().tolist()
    bias_values = bias.detach().tolist()
    for to in range(math.ceil(out_chan / 4)):
        outs = [to * 4 + order[o] for o in range(n_out)]
        for ti in range(math.ceil(in_chan / 4)):
            for x in range(k_w):
                for y in range(k_h):
                    for i in range(n_in):
                        data.extend(float(w[o][ti * 4 + i][y][x]) for o in outs)
        data.extend(float(bias_values[o]) for o in outs)
# Collect every layer's parameters in checkpoint order (conv_head, conv_mid.*,
# conv_tail) -- the same order in which the passes above emit placeholders.
# NOTE(review): relies on the state dict preserving module-definition order.
data = []
for layer in [k[: -len(".weight")] for k in params if k.endswith(".weight")]:
    # An earlier variant passed doswap=("tail" in layer); it is intentionally disabled.
    convert(params[layer + ".weight"], params[layer + ".bias"], data)

# Every number with two or more decimals in the template is a placeholder.
pattern = re.compile(r'-?\d+(\.\d{2,})(e-?\d+)?')
num_placeholders = len(pattern.findall(hlsl))
if num_placeholders != len(data):
    # Template and checkpoint disagree about the architecture: refuse to write
    # a shader whose weights would be silently misaligned.
    raise SystemExit(f"Fail: {num_placeholders} placeholders but {len(data)} parameters")
data_iter = iter(data)
hlsl = pattern.sub(lambda _match: str(next(data_iter)), hlsl)
os.makedirs(args.out, exist_ok=True)
with open(out_hlsl, "w", encoding="utf-8") as f:
    f.write(hlsl)
print("done")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import math
import re

import torch

# Paths default to the previously hard-coded locations, so running the script
# without arguments behaves as before.
parser = argparse.ArgumentParser()
parser.add_argument("pth", nargs="?", default="R:/animejanai_sharp_SUC.pth",
                    help="Path to the input pth file")
parser.add_argument("out", nargs="?", default="R:/animejanai_sharp_SUC.hlsl",
                    help="Path to the output hlsl file")
args = parser.parse_args()

# Load on the CPU: only weight values are read, so CUDA is not required.
# NOTE(review): torch.load unpickles the file -- only convert trusted checkpoints.
model = torch.load(args.pth, map_location="cpu")
out_hlsl = args.out

# Derive the architecture from the checkpoint instead of hard-coding it
# (the original values were num_feat = 24, num_conv = 8).
_params = model["params"]
# Feature width = output channels of the first conv (body.0).
num_feat = len(_params["body.0.bias"])
if num_feat % 4:
    raise SystemExit(f"num_feat={num_feat} must be a multiple of 4 (4 channels per texture)")
num_text = num_feat // 4
# Conv layers (the only ones with a bias) = first + num_conv middle + last.
num_conv = sum(1 for k in _params if k.endswith(".bias")) - 2
if num_conv < 0:
    raise SystemExit("checkpoint must contain at least a first and a last conv layer")
# Declaration of one intermediate feature-map texture conv2d_{a}_{b}
# (layer a, channel group b), 4 channels per texture at input resolution.
tex_define = """
//!TEXTURE
//!WIDTH INPUT_WIDTH
//!HEIGHT INPUT_HEIGHT
//!FORMAT R16G16B16A16_FLOAT
Texture2D conv2d_{a}_{b};
"""
# Middle conv pass: 64 threads per 8x8 block, each thread handles the single
# pixel gxy. {calculation} is filled with get_pixel / cal_weight /
# cal_bias_prelu. Literal braces are doubled because of str.format.
pass_define = '''
//!PASS {pass_no}
//!DESC Conv-{pass_no}
//!IN {in_tex}
//!OUT {out_tex}
//!BLOCK_SIZE 8
//!NUM_THREADS 64
void Pass{pass_no}(uint2 blockStart, uint3 threadId) {{
uint2 gxy = Rmp8x8(threadId.x) + blockStart;
uint2 inputSize = GetInputSize();
if (gxy.x >= inputSize.x || gxy.y >= inputSize.y) {{
return;
}}
float2 inputPt = GetInputPt();
float2 pos = (gxy + 0.5f) * inputPt;
{calculation}
}}
'''
# First pass: 64 threads per 16x16 block, each thread gathers a 4x4 RGB
# neighbourhood into src[][] and computes a 2x2 group of output pixels
# (destPos), running {calculation} (cal_1st snippets) once per pixel.
pass_1st = '''
//!PASS 1
//!DESC First Pass
//!IN INPUT
//!OUT {out_tex}
//!BLOCK_SIZE 16
//!NUM_THREADS 64
void Pass1(uint2 blockStart, uint3 threadId) {{
uint2 gxy = (Rmp8x8(threadId.x) << 1) + blockStart;
uint2 inputSize = GetInputSize();
if (gxy.x >= inputSize.x || gxy.y >= inputSize.y) {{
return;
}}
float2 inputPt = GetInputPt();
uint i, j;
float3 src[4][4];
[unroll]
for (i = 0; i <= 2; i += 2) {{
[unroll]
for (j = 0; j <= 2; j += 2) {{
float2 tpos = (gxy + uint2(i, j)) * inputPt;
const float4 sr = INPUT.GatherRed(sam, tpos);
const float4 sg = INPUT.GatherGreen(sam, tpos);
const float4 sb = INPUT.GatherBlue(sam, tpos);
// w z
// x y
src[i][j] = float3(sr.w, sg.w, sb.w);
src[i][j + 1] = float3(sr.x, sg.x, sb.x);
src[i + 1][j] = float3(sr.z, sg.z, sb.z);
src[i + 1][j + 1] = float3(sr.y, sg.y, sb.y);
}}
}}
[unroll]
for (i = 1; i <= 2; ++i) {{
[unroll]
for (j = 1; j <= 2; ++j) {{
uint2 destPos = gxy + uint2(i - 1, j - 1);
if (i != 1 || j != 1) {{
if (destPos.x >= inputSize.x || destPos.y >= inputSize.y) {{
continue;
}}
}}
{calculation}
}}
}}
}}
'''
# Sample the 3x3 neighbourhood a{b}..i{b} of texture conv2d_{a}_{b}
# (x outer, y inner). The activation was already applied when the texture
# was written (see cal_bias_prelu / cal_1st).
get_pixel = '''
float4 a{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(-inputPt.x, -inputPt.y), 0);
float4 b{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(-inputPt.x, 0), 0);
float4 c{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(-inputPt.x, inputPt.y), 0);
float4 d{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(0, -inputPt.y), 0);
float4 e{b} = conv2d_{a}_{b}.SampleLevel(sam, pos, 0);
float4 f{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(0, inputPt.y), 0);
float4 g{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(inputPt.x, -inputPt.y), 0);
float4 h{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(inputPt.x, 0), 0);
float4 i{b} = conv2d_{a}_{b}.SampleLevel(sam, pos + float2(inputPt.x, inputPt.y), 0);
'''
def cal_weight(j, b, define=False, first=False):
    """Return HLSL that accumulates one 3x3 convolution over input group ``b``.

    Emits nine ``mul(<tap>, float4x4(...))`` statements, one per neighbour
    sample ``a{b}`` .. ``i{b}`` declared by ``get_pixel``, in that order.
    Each matrix holds 16 ``0.0000000000`` placeholders that are later replaced
    by real weights, so the statement order must match ``convert``.

    Args:
        j: Suffix of the accumulator (``target{j}``); "" for middle passes.
        b: Index of the input sample group.
        define: Prefix the first statement with ``float4 `` (declares the accumulator).
        first: Use ``=`` instead of ``+=`` in the first statement.

    Returns:
        The snippet, starting and ending with a newline.
    """
    zeros = ", ".join(["0.0000000000"] * 16)
    lines = []
    for idx, tap in enumerate("abcdefghi"):
        if idx == 0:
            lhs = f'{"float4 " if define else ""}target{j} {"" if first else "+"}='
        else:
            lhs = f"target{j} +="
        lines.append(f"{lhs} mul({tap}{b}, float4x4({zeros}));")
    return "\n" + "\n".join(lines) + "\n"
# Bias (4 placeholders) and per-channel PReLU slope (4 placeholders) of one
# middle-pass output group, then store the group into conv2d_{a}_{b}.
cal_bias_prelu = '''
target += float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000);
target = max(target, 0) + float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000) * min(target, 0);
conv2d_{a}_{b}[gxy] = target;
'''
# First-pass conv for output group {b}: nine RGB taps (x outer, y inner), each
# a float3x4 of 12 placeholders, plus bias and PReLU slopes (4 each)
# -> 116 placeholders per group. {define} declares target for the first group.
cal_1st = '''
{define}target = mul(src[i - 1][j - 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i - 1][j], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i - 1][j + 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i][j - 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i][j], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i][j + 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i + 1][j - 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i + 1][j], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += mul(src[i + 1][j + 1], float3x4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000));
target += float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000);
target = max(target, 0) + float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000) * min(target, 0);
conv2d_1_{b}[destPos] = target;
'''
# Last pass: {calculation} leaves R/G/B in target1/2/3; the components
# .x/.y/.z/.w are written to the top-left/top-right/bottom-left/bottom-right
# output pixels (depth-to-space), each plus a POINT-sampled (sam) INPUT pixel
# as the residual.
last_pass = '''
//!PASS {pass_no}
//!DESC Conv-{pass_no} Depth-to-Space
//!IN INPUT, {in_tex}
//!OUT OUTPUT
//!BLOCK_SIZE 16
//!NUM_THREADS 64
void Pass{pass_no}(uint2 blockStart, uint3 threadId) {{
uint2 gxy = (Rmp8x8(threadId.x) << 1) + blockStart;
const uint2 outputSize = GetOutputSize();
if (gxy.x >= outputSize.x || gxy.y >= outputSize.y) {{
return;
}}
float2 inputPt = GetInputPt();
float2 pos = ((gxy >> 1) + 0.5f) * inputPt;
{calculation}
float2 outputPt = GetOutputPt();
pos -= 0.5f * outputPt;
OUTPUT[gxy] = float4(float3(target1.x, target2.x, target3.x) + INPUT.SampleLevel(sam, pos, 0).rgb, 1);
++gxy.x;
pos.x += outputPt.x;
OUTPUT[gxy] = float4(float3(target1.y, target2.y, target3.y) + INPUT.SampleLevel(sam, pos, 0).rgb, 1);
++gxy.y;
pos.y += outputPt.y;
OUTPUT[gxy] = float4(float3(target1.w, target2.w, target3.w) + INPUT.SampleLevel(sam, pos, 0).rgb, 1);
--gxy.x;
pos.x -= outputPt.x;
OUTPUT[gxy] = float4(float3(target1.z, target2.z, target3.z) + INPUT.SampleLevel(sam, pos, 0).rgb, 1);
}}
'''
###############################################################
# Assemble the HLSL effect: header, texture/sampler declarations and one pass
# per conv layer. Every weight/bias/slope appears as a 0.0000000000
# placeholder that is replaced, in emission order, by convert()'s output.
hlsl = '''//!MAGPIE EFFECT
//!VERSION 4
//!SORT_NAME compact
//!TEXTURE
Texture2D INPUT;
//!TEXTURE
//!WIDTH INPUT_WIDTH * 2
//!HEIGHT INPUT_HEIGHT * 2
Texture2D OUTPUT;
'''
# Feature maps written by the first conv (layer 1) and the num_conv middle convs.
for layer in range(1, num_conv + 2):
    for group in range(num_text):
        hlsl += tex_define.format(a=layer, b=group)
hlsl += '''
//!SAMPLER
//!FILTER POINT
SamplerState sam;
'''
### First pass: RGB input -> conv2d_1_*
out_tex = ", ".join(f"conv2d_1_{b}" for b in range(num_text))
calculation = "".join(
    cal_1st.format(b=b, define="float4 " if b == 0 else "") for b in range(num_text)
)
hlsl += pass_1st.format(out_tex=out_tex, calculation=calculation)
### Middle passes: conv2d_{n-1}_* -> conv2d_{n}_*
for pass_no in range(2, num_conv + 2):
    in_tex = ", ".join(f"conv2d_{pass_no - 1}_{b}" for b in range(num_text))
    out_tex = ", ".join(f"conv2d_{pass_no}_{b}" for b in range(num_text))
    calculation = "".join(get_pixel.format(a=pass_no - 1, b=b) for b in range(num_text))
    for i in range(num_text):  # one RGBA output group at a time
        for b in range(num_text):
            calculation += cal_weight(j="", b=b, define=(i == 0 and b == 0), first=(b == 0))
        calculation += cal_bias_prelu.format(a=pass_no, b=i)
    hlsl += pass_define.format(pass_no=pass_no, in_tex=in_tex, out_tex=out_tex, calculation=calculation)
#### Last pass (3x3 conv to 12 channels) + depth-to-space
# Set explicitly: the loop variable above is never bound when num_conv == 0.
pass_no = num_conv + 2
in_tex = ", ".join(f"conv2d_{pass_no - 1}_{b}" for b in range(num_text))
calculation = "".join(get_pixel.format(a=pass_no - 1, b=b) for b in range(num_text))
for i in range(3):  # target1/2/3 -> R/G/B, each float4 = one 2x2 output block
    for b in range(num_text):
        calculation += cal_weight(j=i + 1, b=b, define=(b == 0), first=(b == 0))
    calculation += f'''
target{i+1} += float4(0.0000000000, 0.0000000000, 0.0000000000, 0.0000000000);
'''
hlsl += last_pass.format(pass_no=pass_no, in_tex=in_tex, calculation=calculation)
def convert(weight, bias, data, prelu=None):
    """Append one conv layer's parameters to ``data`` in HLSL placeholder order.

    For each group of 4 output channels: for each group of (up to) 4 input
    channels, every kernel tap (x outer, y inner) as in x out coefficients
    (input-major, matching the row layout of ``mul(float4, float4x4)``);
    then the group's bias values; then, if ``prelu`` is given, its slopes.

    Args:
        weight: Conv weight tensor of shape (out_chan, in_chan, kH, kW).
        bias: Bias tensor of shape (out_chan,).
        data: List that receives the values as Python floats (mutated in place).
        prelu: Optional PReLU weight, either one slope per output channel or a
            single shared slope (broadcast to every channel).

    Raises:
        ValueError: if a channel count above 4 is not a multiple of 4.
    """
    out_chan, in_chan, k_h, k_w = weight.shape
    for label, chan in (("out_chan", out_chan), ("in_chan", in_chan)):
        if chan > 4 and chan % 4:
            raise ValueError(f"{label}={chan} must be <= 4 or a multiple of 4")
    n_out = min(4, out_chan)
    n_in = min(4, in_chan)
    # One bulk conversion instead of a tensor indexing call per element.
    w = weight.detach().tolist()
    bias_values = bias.detach().tolist()
    slopes = None
    if prelu is not None:
        slopes = prelu.detach().reshape(-1).tolist()
        if len(slopes) == 1:  # nn.PReLU(num_parameters=1) shares one slope
            slopes = slopes * out_chan
    for to in range(math.ceil(out_chan / 4)):
        outs = range(to * 4, to * 4 + n_out)
        for ti in range(math.ceil(in_chan / 4)):
            for x in range(k_w):
                for y in range(k_h):
                    for i in range(n_in):
                        data.extend(float(w[o][ti * 4 + i][y][x]) for o in outs)
        data.extend(float(bias_values[o]) for o in outs)
        if slopes is not None:
            data.extend(float(slopes[o]) for o in outs)
# Weights live under "params"; use model['params_ema'] instead to export the
# EMA weights when the checkpoint has them.
model = model['params']
# Only conv layers carry a bias (the PReLU layers have just a weight).
num_layers = sum(1 for k in model if k.endswith(".bias"))
data = []
for i in range(num_layers):
    # Convs sit at even body indices; every conv but the last is followed by
    # a PReLU at the next odd index.
    prelu = None if i == num_layers - 1 else model[f"body.{2 * i + 1}.weight"]
    convert(model[f"body.{i * 2}.weight"], model[f"body.{i * 2}.bias"], data, prelu=prelu)

# Every number with two or more decimals in the template is a placeholder.
pattern = re.compile(r'-?\d+(\.\d{2,})(e-?\d+)?')
num_placeholders = len(pattern.findall(hlsl))
if num_placeholders != len(data):
    # Template and checkpoint disagree about the architecture: refuse to write
    # a shader whose weights would be silently misaligned.
    raise SystemExit(f"Fail: {num_placeholders} placeholders but {len(data)} parameters")
data_iter = iter(data)
new_text = pattern.sub(lambda _match: str(next(data_iter)), hlsl)
with open(out_hlsl, "w", encoding="utf-8") as f:
    f.write(new_text)
print("done")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment