diff --git a/src/generate_unconditional_samples.py b/src/generate_unconditional_samples.py
index d9e2131..366c411 100755
--- a/src/generate_unconditional_samples.py
+++ b/src/generate_unconditional_samples.py
@@ -9,7 +9,7 @@ import tensorflow as tf
 import model, sample, encoder
 
 def sample_model(
-    model_name='117M',
+    model_name='1558M',
     seed=None,
     nsamples=0,
     batch_size=1,
@@ -69,7 +69,7 @@ def sample_model(
             out = sess.run(output)
             for i in range(batch_size):
                 generated += batch_size
-                text = enc.decode(out[i])
+                text = enc.decode(out[i]).encode('utf-8')
                 print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                 print(text)
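The two hunks above switch the default sampling model to the 1.5B-parameter 1558M checkpoint and UTF-8-encode each decoded sample before printing. The encode call matters under Python 2, where printing a unicode string to a piped or redirected stdout falls back to the ASCII codec; a minimal sketch of the failure mode, with a hypothetical sample string:

# -*- coding: utf-8 -*-
# Hypothetical sample text; GPT-2 output routinely contains non-ASCII characters.
text = u'caf\u00e9 r\u00e9sum\u00e9'

# Under Python 2 with stdout piped to a file, sys.stdout.encoding is None, so
# print(text) implicitly encodes via ASCII and raises UnicodeEncodeError.
# Encoding explicitly sidesteps the implicit codec:
print(text.encode('utf-8'))

Under Python 3 the same call prints a bytes literal (b'...'), which is uglier but still avoids UnicodeEncodeError on consoles with a narrow codec such as cp1252 on Windows.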
diff --git a/src/interactive_conditional_samples.py b/src/interactive_conditional_samples.py
index c1650bb..193f5c8 100755
--- a/src/interactive_conditional_samples.py
+++ b/src/interactive_conditional_samples.py
@@ -9,7 +9,7 @@ import tensorflow as tf
 import model, sample, encoder
 
 def interact_model(
-    model_name='117M',
+    model_name='1558M',
     seed=None,
     nsamples=1,
     batch_size=1,
@@ -64,7 +64,7 @@ def interact_model(
         )
 
         saver = tf.train.Saver()
-        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
+        ckpt = tf.train.latest_checkpoint('checkpoint/run1')
         saver.restore(sess, ckpt)
 
         while True:
@@ -80,7 +80,7 @@ def interact_model(
                 })[:, len(context_tokens):]
                 for i in range(batch_size):
                     generated += 1
-                    text = enc.decode(out[i])
+                    text = enc.decode(out[i]).encode('utf-8')
                     print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                     print(text)
             print("=" * 80)
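The interactive script gets the same two fixes plus a checkpoint-path change: weights are now restored from checkpoint/run1, the directory where train.py writes fine-tuned checkpoints, instead of the pretrained models/<model_name> directory. A sketch of a more forgiving lookup; the fallback helper is an illustration, not part of the gist:

import os
import tensorflow as tf

def find_checkpoint(model_name='1558M', run_dir='checkpoint/run1'):
    # Prefer the fine-tuned checkpoint written by train.py; fall back to
    # the pretrained weights when no fine-tuning run exists yet.
    ckpt = tf.train.latest_checkpoint(run_dir)  # None if no checkpoint file
    if ckpt is None:
        ckpt = tf.train.latest_checkpoint(os.path.join('models', model_name))
    return ckpt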
diff --git a/src/memory_saving_gradients.py b/src/memory_saving_gradients.py
index 659691f..9b46e89 100644
--- a/src/memory_saving_gradients.py
+++ b/src/memory_saving_gradients.py
@@ -108,6 +108,7 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
             ts_all = [t for t in ts_all if 'dropout' not in t.name]
             # DV: FP16_FIX - need to add 'Cast' layer here to make it work for FP16
             ts_all = [t for t in ts_all if 'Cast' not in t.name]
+            ts_all = [t for t in ts_all if 'SparseSoftmaxCrossEntropyWithLogits' not in t.name]
 
             # filter out all tensors that are inputs of the backward graph
             with util.capture_ops() as bwd_ops:
@@ -120,11 +121,12 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
 
             # try two slightly different ways of getting bottlenecks tensors
             # to checkpoint
-            for ts in [ts_filtered, ts_all]:
+            # for ts in [ts_filtered, ts_all]:
+            for ts in [ts_filtered]:
 
                 # get all bottlenecks in the graph
                 bottleneck_ts = []
-                for t in ts:
+                for i, t in enumerate(ts):
                     b = set(ge.get_backward_walk_ops(t.op, inclusive=True, within_ops=fwd_ops))
                     f = set(ge.get_forward_walk_ops(t.op, inclusive=False, within_ops=fwd_ops))
                     # check that there are not shortcuts
@@ -133,7 +135,12 @@ def gradients(ys, xs, grad_ys=None, checkpoints='collection', **kwargs):
                     if not set(b_inp).intersection(f_inp) and len(b_inp)+len(f_inp) >= len(ts_all):
                         bottleneck_ts.append(t)  # we have a bottleneck!
                     else:
-                        debug_print("Rejected bottleneck candidate and ops %s", [t] + list(set(ts_all) - set(b_inp) - set(f_inp)))
+                        debug_print("%s/%s: Rejected bottleneck candidate and ops %s; %s found of %s",
+                                    i, len(ts),
+                                    [t] + list(set(ts_all) - set(b_inp) - set(f_inp)),
+                                    len(bottleneck_ts),
+                                    np.sqrt(len(ts_filtered)) * (i / len(ts))
+                        )
 
                 # success? or try again without filtering?
                 if len(bottleneck_ts) >= np.sqrt(len(ts_filtered)): # yes, enough bottlenecks found!
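Three changes land in memory_saving_gradients.py. First, tensors whose names mention SparseSoftmaxCrossEntropyWithLogits are excluded from the checkpoint candidates, extending the existing dropout and Cast name filters. Second, the bottleneck search no longer retries with the unfiltered ts_all list, which avoids a second pass over the full graph. Third, the rejection message now reports progress: candidate index out of len(ts), bottlenecks found so far, and the pro-rated np.sqrt(len(ts_filtered)) target, useful because the per-candidate graph walks make the search quadratic and slow on a 48-layer model. The name filter reduces to this pattern (a sketch assuming a list ts_all of tf.Tensor objects):

# Substring-based exclusion, as in the hunk above: drop any tensor whose
# name mentions one of these ops from the checkpoint candidates.
EXCLUDED = ('dropout', 'Cast', 'SparseSoftmaxCrossEntropyWithLogits')
ts_all = [t for t in ts_all if not any(s in t.name for s in EXCLUDED)]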
diff --git a/src/model.py b/src/model.py
index 4e942d8..71092bc 100644
--- a/src/model.py
+++ b/src/model.py
@@ -124,10 +124,10 @@ def block(x, scope, *, past, hparams):
     with tf.variable_scope(scope):
         nx = x.shape[-1].value
         a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
-        x = x + a
+        x = x1 = x + a
         m = mlp(norm(x, 'ln_2'), 'mlp', nx*4, hparams=hparams)
         x = x + m
-        return x, present
+        return x, present, x1
 
 def past_shape(*, hparams, batch_size=None, sequence=None):
     return [batch_size, hparams.n_layer, 2, hparams.n_head, sequence, hparams.n_embd // hparams.n_head]
@@ -161,9 +161,9 @@ def model(hparams, X, past=None, scope='model', reuse=tf.AUTO_REUSE):
         pasts = tf.unstack(past, axis=1) if past is not None else [None] * hparams.n_layer
         assert len(pasts) == hparams.n_layer
         for layer, past in enumerate(pasts):
-            h, present = block(h, 'h%d' % layer, past=past, hparams=hparams)
-            if layer == 10:
-                tf.add_to_collection('checkpoints', h)
+            h, present, x1 = block(h, 'h%d' % layer, past=past, hparams=hparams)
+            if layer < 48:
+                tf.add_to_collection('checkpoints', x1)
             presents.append(present)
         results['present'] = tf.stack(presents, axis=1)
         h = norm(h, 'ln_f')
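In model.py, block now additionally returns x1, the residual stream immediately after the attention sublayer, and model adds it to the 'checkpoints' collection for every layer; layer < 48 covers the full depth of 1558M, replacing the old single checkpoint at layer 10. With one checkpoint per layer, gradient checkpointing can drop intermediate activations after the forward pass and recompute each layer on the backward pass, so peak activation memory scales with one layer rather than all 48. The collection pattern as a minimal sketch; block_fn is a hypothetical stand-in for model.block:

import tensorflow as tf

N_LAYER = 48  # depth of GPT-2 1558M

def transformer(h, block_fn):
    # Mark one tensor per layer as a rematerialization boundary. Everything
    # between two collected tensors is recomputed on the backward pass when
    # gradients are taken with checkpoints='collection'.
    for layer in range(N_LAYER):
        h, x1 = block_fn(h, 'h%d' % layer)
        tf.add_to_collection('checkpoints', x1)
    return h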
diff --git a/train.py b/train.py
index 57e4ef9..0cb705e 100755
--- a/train.py
+++ b/train.py
@@ -118,9 +118,9 @@ def main():
         train_vars = [v for v in all_vars if '/h' in v.name] if args.only_train_transformer_layers else all_vars
 
         if args.optimizer == 'adam':
-            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
+            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate / args.batch_size)
         elif args.optimizer == 'sgd':
-            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate)
+            opt = tf.train.GradientDescentOptimizer(learning_rate=args.learning_rate / args.batch_size)
         else:
             exit('Bad optimizer:', args.optimizer)
@@ -136,7 +136,7 @@ def main():
             summary_loss = tf.summary.scalar('loss', opt_apply)
         else:
             if args.memory_saving_gradients:
-                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
+                opt_grads = memory_saving_gradients.gradients(loss, train_vars, checkpoints='collection')
             else:
                 opt_grads = tf.gradients(loss, train_vars)
             opt_grads = list(zip(opt_grads, train_vars))
@@ -219,7 +219,7 @@ def main():
                     tf_sample,
                     feed_dict={context: args.batch_size * [context_tokens]})
                 for i in range(min(args.sample_num - index, args.batch_size)):
-                    text = enc.decode(out[i])
+                    text = enc.decode(out[i]).encode('utf-8')
                     text = '======== SAMPLE {} ========\n{}\n'.format(
                         index + 1, text)
                     all_text.append(text)
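Finally, train.py divides the learning rate by the batch size, so raising --batch_size shrinks each update proportionally instead of amplifying it, and it passes checkpoints='collection' explicitly so the per-layer tensors collected in model.py drive the rematerialization. The scaling rule in isolation, as a sketch with hypothetical flag values; whether 1/batch_size is the right factor depends on how the loss is reduced over the batch:

import tensorflow as tf

learning_rate = 1e-4  # hypothetical values standing in for args.learning_rate
batch_size = 4        # and args.batch_size

# Inverse-linear scaling: the per-update step shrinks as the batch grows,
# e.g. 1e-4 / 4 = 2.5e-5 per update.
opt = tf.train.AdamOptimizer(learning_rate=learning_rate / batch_size)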