Skip to content

Instantly share code, notes, and snippets.

@n-yoda
Created December 8, 2018 15:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save n-yoda/4e86c43531121232dfddfd70b5e56a25 to your computer and use it in GitHub Desktop.
Save n-yoda/4e86c43531121232dfddfd70b5e56a25 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import tensorflow as tf
class UVPQ:
    """Per-pixel network inputs for a ``w`` x ``h`` image.

    ``self.uvpq`` is a float32 tensor of shape (h*w, 4). Row ``r`` belongs to
    the pixel at column ``r % w`` and row ``r // w`` and holds
    ``[u, v, p, q]``:

    * ``(u, v)`` is the pixel centre normalised into (0, 1).
    * ``p`` and ``q`` are the scalars passed to the constructor (Python
      floats or scalar tensors), repeated on every row.
    """

    def __init__(self, w, h, p, q):
        grid_x, grid_y = tf.meshgrid(tf.range(w, dtype=tf.int32),
                                     tf.range(h, dtype=tf.int32))
        half = tf.constant(0.5, tf.float32)
        # Sample at pixel centres: (index + 0.5) / size.
        u = (tf.cast(grid_x, tf.float32) + half) / tf.cast(w, tf.float32)
        v = (tf.cast(grid_y, tf.float32) + half) / tf.cast(h, tf.float32)
        columns = tf.stack([tf.reshape(u, [-1]), tf.reshape(v, [-1])], axis=1)
        # Append p, then q, as constant-valued trailing columns.
        for extra in (p, q):
            columns = tf.pad(columns, [[0, 0], [0, 1]], constant_values=extra)
        self.uvpq = columns
class Train:
    """One training image paired with its per-pixel UVPQ inputs.

    Attributes:
        image: float32 image decoded with 4 channels and scaled into [0, 1].
        linear: ``image`` flattened to (h*w, 4), one row per pixel.
        h, w: scalar int32 tensors holding the image height and width.
        uvpq: (h*w, 4) network input rows for this image (see ``UVPQ``),
            with every row carrying the given ``p`` and ``q``.
    """

    def __init__(self, path, p, q):
        encoded = tf.read_file(path)
        raw = tf.image.decode_image(encoded, 4)
        self.image = tf.cast(raw, tf.float32) / 255.0
        self.linear = tf.reshape(self.image, [-1, 4])
        dims = tf.shape(raw)
        self.h, self.w = dims[0], dims[1]
        self._uvpq = UVPQ(self.w, self.h, p, q)
        self.uvpq = self._uvpq.uvpq
# Command-line options. Without --render the script trains on the four
# images below; with --render the model input is a W x H grid at (p, q).
parser = argparse.ArgumentParser(description='NN texture')
parser.add_argument(
    '--render', action='store_true',
    help='feed a W x H grid at (p, q) to the model instead of training')
parser.add_argument(
    '-p', type=float, default=0.0,
    help='p input used when rendering; training images use -1 or 1')
parser.add_argument(
    '-q', type=float, default=0.0,
    help='q input used when rendering; training images use -1 or 1')
parser.add_argument(
    '-W', type=int, default=64,
    help='width in pixels of the render grid')
parser.add_argument(
    '-H', type=int, default=64,
    help='height in pixels of the render grid')
args = parser.parse_args()
# For rendering: p and q default to the -p/-q command-line values; as
# placeholders with defaults they can also be overridden via feed_dict.
p, q = (tf.placeholder_with_default(tf.constant(value, tf.float32), [])
        for value in (args.p, args.q))
uvpq = UVPQ(args.W, args.H, p, q)
# For learning: each image is pinned to one corner (+/-1, +/-1) of the
# (p, q) plane. Pixels from all images are stacked in this list's order.
trains = [
    Train(path, corner_p, corner_q)
    for path, corner_p, corner_q in (
        ('blue.png', 1.0, 1.0),
        ('higan.png', 1.0, -1.0),
        ('rose.png', -1.0, -1.0),
        ('hasu.png', -1.0, 1.0),
    )
]
trains_uvpq = tf.concat([sample.uvpq for sample in trains], axis=0)
trains_linear = tf.concat([sample.linear for sample in trains], axis=0)
# Model
# Two parallel branches over the 4-column (u, v, p, q) input. Each branch is
# a tanh layer followed by a linear 4x4 layer. The branch outputs are summed
# and squashed into [0, 1], one value per RGBA channel.
# NOTE: keep the variable creation order (w0, b0, w1, b1, w2, b2, w3, b3).
# TF1 auto-names tf.Variable from creation order, and ./ckpt checkpoints are
# matched by those names.
model_input = uvpq.uvpq if args.render else trains_uvpq
w0 = tf.Variable(tf.random_normal([4, 4]))
b0 = tf.Variable(tf.random_normal([4]))
out0 = tf.tanh(tf.matmul(model_input, w0) + b0)
w1 = tf.Variable(tf.random_normal([4, 4]))
b1 = tf.Variable(tf.random_normal([4]))
out1 = tf.tanh(tf.matmul(model_input, w1) + b1)
w2 = tf.Variable(tf.random_normal([4, 4]))
b2 = tf.Variable(tf.random_normal([4]))
out2 = tf.matmul(out0, w2) + b2
w3 = tf.Variable(tf.random_normal([4, 4]))
b3 = tf.Variable(tf.random_normal([4]))
out3 = tf.matmul(out1, w3) + b3
out_linear = tf.tanh(out2 + out3) * 0.5 + 0.5
# Round to the nearest 8-bit level. A bare cast truncates, which would bias
# every channel down by up to one step.
out_linear8 = tf.cast(tf.round(out_linear * 255), tf.uint8)
if not args.render:
    # Objective: per-pixel MSE against the training images, plus a small
    # penalty on the L2 norm of all trainable weights and biases.
    trainable = [tf.reshape(x, [-1]) for x in tf.trainable_variables()]
    norm = tf.norm(tf.concat(trainable, axis=0))
    error = tf.losses.mean_squared_error(trains_linear, out_linear)
    optimizer = tf.train.AdamOptimizer()
    minimize = optimizer.minimize(error + norm * 0.001)

# The Saver is created after the optimizer, so in training mode Adam's slot
# variables are part of the checkpoint too.
saver = tf.train.Saver(tf.global_variables())
sess = tf.Session()
try:
    saver.restore(sess, './ckpt/test')
except (ValueError, tf.errors.NotFoundError):
    # No usable checkpoint yet: start from fresh random weights. Other
    # failures (e.g. shape mismatches) propagate instead of silently
    # overwriting an existing checkpoint. Saver.save does not create
    # missing directories, so make ./ckpt first.
    sess.run(tf.global_variables_initializer())
    tf.gfile.MakeDirs('./ckpt')
    saver.save(sess, './ckpt/test')
    print('created')
else:
    print('restored')

if args.render:
    # NOTE(review): out_image is built but never evaluated or written, so
    # --render currently produces no image file. Confirm the intended output.
    out_image = tf.reshape(out_linear, [args.H, args.W, -1])
else:
    for step in range(10000):
        _, e, n = sess.run([minimize, error, norm])
        if step % 100 == 0:
            print('error:', e, 'norm:', n)
    saver.save(sess, './ckpt/test')
    print('saved')
    # Write the network's reconstruction of each training image. Rows of
    # out_linear8 follow the concatenation order of `trains`, so each image
    # owns the next w*h rows.
    start = 0
    for i, train in enumerate(trains):
        w, h = sess.run([train.w, train.h])
        out_slice = tf.slice(out_linear8, [start, 0], [w * h, -1])
        out_slice_reshaped = tf.reshape(out_slice, [h, w, -1])
        start += w * h
        sess.run(
            tf.write_file(
                'predict' + str(i) + '.png',
                tf.image.encode_png(out_slice_reshaped)))
def print_mat(tensor, name):
    """Print a 2-D tensor as an HLSL ``halfRxC`` matrix declaration.

    The tensor is evaluated in the module-level ``sess``. Its values are
    emitted in row-major order; for the model's 4x4 weights this produces
    ``half4x4 <name> = half4x4(v00, v01, ..., v33);``.

    Args:
        tensor: 2-D tensor to evaluate.
        name: Identifier for the declared shader variable.

    Returns:
        The declaration string that was printed.
    """
    val = sess.run(tensor)
    rows, cols = val.shape
    type_name = 'half' + str(rows) + 'x' + str(cols)
    values = ', '.join(str(x) for row in val for x in row)
    result = type_name + ' ' + name + ' = ' + type_name + '(' + values + ');'
    print(result)
    return result
def print_vec(tensor, name):
    """Print a 1-D tensor as an HLSL ``halfN`` vector declaration.

    The tensor is evaluated in the module-level ``sess``. For the model's
    4-element biases this produces ``half4 <name> = half4(v0, v1, v2, v3);``.

    Args:
        tensor: 1-D tensor to evaluate.
        name: Identifier for the declared shader variable.

    Returns:
        The declaration string that was printed.
    """
    val = sess.run(tensor)
    type_name = 'half' + str(len(val))
    values = ', '.join(str(x) for x in val)
    result = type_name + ' ' + name + ' = ' + type_name + '(' + values + ');'
    print(result)
    return result
# Emit every layer's parameters as HLSL declarations: w0, b0, w1, b1, ...
for index, (weight, bias) in enumerate([(w0, b0), (w1, b1), (w2, b2), (w3, b3)]):
    print_mat(weight, 'w' + str(index))
    print_vec(bias, 'b' + str(index))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment