# Gist by @fjxmlzn, created October 11, 2020.
# https://gist.github.com/fjxmlzn/fc61538ae69bf3633334a00401d5b3a6
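"""Generate synthetic samples from a trained DoppelGANger model.

This summary is inferred from the code below: the script rebuilds the
model, restores a checkpoint from ../../checkpoint, and writes generated
train/test splits to ./generated_samples.
"""
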
import sys
sys.path.append("..")

import os

import numpy as np
import tensorflow as tf

from gan import output
from gan.doppelganger import DoppelGANger
from gan.load_data import load_data
from gan.network import DoppelGANgerGenerator, Discriminator, \
    RNNInitialStateType, AttrDiscriminator
from gan.util import add_gen_flag, normalize_per_sample, \
    renormalize_per_sample

# The dataset's Output descriptors were pickled from a top-level "output"
# module, so alias gan.output under that name for unpickling to succeed.
sys.modules["output"] = output

if __name__ == "__main__":
    _config = {
        "batch_size": 100,
        "vis_freq": 200,
        "vis_num_sample": 5,
        # discriminator/generator updates per training iteration
        "d_rounds": 1,
        "g_rounds": 1,
        # PacGAN packing degree; 1 disables packing
        "num_packing": 1,
        "noise": True,
        "feed_back": False,
        "g_lr": 0.001,
        "d_lr": 0.001,
        # WGAN-GP gradient penalty coefficient
        "d_gp_coe": 10.0,
        "gen_feature_num_layers": 1,
        "gen_feature_num_units": 100,
        "gen_attribute_num_layers": 3,
        "gen_attribute_num_units": 100,
        "disc_num_layers": 5,
        "disc_num_units": 200,
        "initial_state": "random",
        "attr_d_lr": 0.001,
        "attr_d_gp_coe": 10.0,
        "g_attr_d_coe": 1.0,
        "attr_disc_num_layers": 5,
        "attr_disc_num_units": 200,
        "generate_num_train_sample": 50000,
        "generate_num_test_sample": 50000,
        "dataset": "web",
        "epoch": 400,
        # timesteps emitted per RNN pass ("batch generation" in the paper)
        "sample_len": 10,
        "extra_checkpoint_freq": 5,
        "epoch_checkpoint_freq": 1,
        # train an auxiliary discriminator on the attributes
        "aux_disc": True,
        # per-sample auto-normalization of the features
        "self_norm": True
    }
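
    # Note: the attr_d_* settings only take effect when "aux_disc" is True;
    # see the conditional arguments passed to DoppelGANger below.
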
    _work_dir = "."

    (data_feature, data_attribute,
     data_gen_flag,
     data_feature_outputs, data_attribute_outputs) = \
        load_data(os.path.join("..", "data", _config["dataset"]))
    print("data_feature shape:", data_feature.shape)
    print("data_attribute shape:", data_attribute.shape)
    print("data_gen_flag shape:", data_gen_flag.shape)
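    # Expected layout of the loaded arrays, as documented in the
    # DoppelGANger repo (the sanity check below is an addition, not part
    # of the original gist):
    #   data_feature:   (num_sample, max_sequence_len, feature_dim)
    #   data_attribute: (num_sample, attribute_dim)
    #   data_gen_flag:  (num_sample, max_sequence_len), 1.0 while active
    assert (data_feature.shape[0] == data_attribute.shape[0]
            == data_gen_flag.shape[0])
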
    num_real_attribute = len(data_attribute_outputs)

    if _config["self_norm"]:
        (data_feature, data_attribute, data_attribute_outputs,
         real_attribute_mask) = \
            normalize_per_sample(
                data_feature, data_attribute, data_feature_outputs,
                data_attribute_outputs)
    else:
        real_attribute_mask = [True] * len(data_attribute_outputs)
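    # normalize_per_sample implements the paper's auto-normalization: each
    # sample's features are rescaled to a fixed range, and a per-sample
    # scale/offset (derived from each sample's min and max) is appended to
    # data_attribute as extra attributes. real_attribute_mask marks which
    # attribute outputs are original (True) versus added by the
    # normalization (False).
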
    sample_len = _config["sample_len"]

    data_feature, data_feature_outputs = add_gen_flag(
        data_feature,
        data_gen_flag,
        data_feature_outputs,
        sample_len)
    print("data_feature shape after add_gen_flag:", data_feature.shape)
    print("number of feature outputs:", len(data_feature_outputs))
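    # add_gen_flag appends the generation flags to the features so the
    # generator can learn when each sequence ends; in the reference
    # implementation this is a 2-dimensional one-hot output, so feature_dim
    # grows by 2 in the shape printed above.
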
    if _config["initial_state"] == "variable":
        initial_state = RNNInitialStateType.VARIABLE
    elif _config["initial_state"] == "random":
        initial_state = RNNInitialStateType.RANDOM
    elif _config["initial_state"] == "zero":
        initial_state = RNNInitialStateType.ZERO
    else:
        raise NotImplementedError(
            "initial_state must be 'variable', 'random', or 'zero'")
    generator = DoppelGANgerGenerator(
        feed_back=_config["feed_back"],
        noise=_config["noise"],
        feature_outputs=data_feature_outputs,
        attribute_outputs=data_attribute_outputs,
        real_attribute_mask=real_attribute_mask,
        sample_len=sample_len,
        feature_num_layers=_config["gen_feature_num_layers"],
        feature_num_units=_config["gen_feature_num_units"],
        attribute_num_layers=_config["gen_attribute_num_layers"],
        attribute_num_units=_config["gen_attribute_num_units"],
        initial_state=initial_state)
    discriminator = Discriminator(
        num_layers=_config["disc_num_layers"],
        num_units=_config["disc_num_units"])
    if _config["aux_disc"]:
        attr_discriminator = AttrDiscriminator(
            num_layers=_config["attr_disc_num_layers"],
            num_units=_config["attr_disc_num_units"])
    else:
        attr_discriminator = None
    checkpoint_dir = os.path.join(_work_dir, "checkpoint")
    sample_dir = os.path.join(_work_dir, "sample")
    time_path = os.path.join(_work_dir, "time.txt")

    run_config = tf.ConfigProto()
    with tf.Session(config=run_config) as sess:
        gan = DoppelGANger(
            sess=sess,
            checkpoint_dir=checkpoint_dir,
            sample_dir=sample_dir,
            time_path=time_path,
            epoch=_config["epoch"],
            batch_size=_config["batch_size"],
            data_feature=data_feature,
            data_attribute=data_attribute,
            real_attribute_mask=real_attribute_mask,
            data_gen_flag=data_gen_flag,
            sample_len=sample_len,
            data_feature_outputs=data_feature_outputs,
            data_attribute_outputs=data_attribute_outputs,
            vis_freq=_config["vis_freq"],
            vis_num_sample=_config["vis_num_sample"],
            generator=generator,
            discriminator=discriminator,
            attr_discriminator=attr_discriminator,
            d_gp_coe=_config["d_gp_coe"],
            attr_d_gp_coe=(_config["attr_d_gp_coe"]
                           if _config["aux_disc"] else 0.0),
            g_attr_d_coe=(_config["g_attr_d_coe"]
                          if _config["aux_disc"] else 0.0),
            d_rounds=_config["d_rounds"],
            g_rounds=_config["g_rounds"],
            g_lr=_config["g_lr"],
            d_lr=_config["d_lr"],
            attr_d_lr=(_config["attr_d_lr"]
                       if _config["aux_disc"] else 0.0),
            extra_checkpoint_freq=_config["extra_checkpoint_freq"],
            epoch_checkpoint_freq=_config["epoch_checkpoint_freq"],
            num_packing=_config["num_packing"])
        gan.build()
        print("Finished building")

        total_generate_num_sample = \
            (_config["generate_num_train_sample"] +
             _config["generate_num_test_sample"])

        if data_feature.shape[1] % sample_len != 0:
            raise Exception("data_feature.shape[1] must be a multiple "
                            "of sample_len")
        length = data_feature.shape[1] // sample_len
        # Draw the latent inputs for generation: one noise batch for the
        # real attributes, one for the additional (auto-normalization)
        # attributes, per-step noise for the feature RNN, and the free-run
        # initial feature input.
        real_attribute_input_noise = gan.gen_attribute_input_noise(
            total_generate_num_sample)
        addi_attribute_input_noise = gan.gen_attribute_input_noise(
            total_generate_num_sample)
        feature_input_noise = gan.gen_feature_input_noise(
            total_generate_num_sample, length)
        input_data = gan.gen_feature_input_data_free(
            total_generate_num_sample)
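        # length is the number of RNN passes; each pass emits sample_len
        # timesteps, so generated sequences span length * sample_len steps.
        # For illustration: with 550-step sequences and sample_len = 10,
        # length would be 55. The actual value comes from
        # data_feature.shape[1] above.
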
        mid_checkpoint_dir = "../../checkpoint"
        if not os.path.exists(mid_checkpoint_dir):
            print("Checkpoint directory {} not found".format(
                mid_checkpoint_dir))
            sys.exit(1)

        save_path = os.path.join(
            _work_dir,
            "generated_samples")
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        train_path_ori = os.path.join(
            save_path, "generated_data_train_ori.npz")
        test_path_ori = os.path.join(
            save_path, "generated_data_test_ori.npz")
        train_path = os.path.join(
            save_path, "generated_data_train.npz")
        test_path = os.path.join(
            save_path, "generated_data_test.npz")
        if os.path.exists(test_path):
            print("Output file {} already exists".format(test_path))
            sys.exit(1)
        gan.load(mid_checkpoint_dir)
        print("Finished loading")

        # Pass the given_attribute parameter to sample_from if you want to
        # generate data conditioned on given attributes.
        features, attributes, gen_flags, lengths = gan.sample_from(
            real_attribute_input_noise, addi_attribute_input_noise,
            feature_input_noise, input_data)
        print("features shape:", features.shape)
        print("attributes shape:", attributes.shape)
        print("gen_flags shape:", gen_flags.shape)
        print("lengths shape:", lengths.shape)
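        # Expected output shapes, assuming the repo's sample_from contract
        # (this note is an addition to the original gist):
        #   features:   (total_generate_num_sample,
        #                length * sample_len, feature_dim)
        #   attributes: (total_generate_num_sample, attribute_dim)
        #   gen_flags:  (total_generate_num_sample, length * sample_len)
        #   lengths:    (total_generate_num_sample,)
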
        split = _config["generate_num_train_sample"]

        if _config["self_norm"]:
            # Save the raw (still auto-normalized) samples first, then
            # undo the per-sample normalization using the generated
            # normalization attributes; renormalize_per_sample keeps only
            # the first num_real_attribute attributes.
            np.savez(
                train_path_ori,
                data_feature=features[0: split],
                data_attribute=attributes[0: split],
                data_gen_flag=gen_flags[0: split])
            np.savez(
                test_path_ori,
                data_feature=features[split:],
                data_attribute=attributes[split:],
                data_gen_flag=gen_flags[split:])

            features, attributes = renormalize_per_sample(
                features, attributes, data_feature_outputs,
                data_attribute_outputs, gen_flags,
                num_real_attribute=num_real_attribute)
            print("features shape after renormalization:", features.shape)
            print("attributes shape after renormalization:",
                  attributes.shape)
        np.savez(
            train_path,
            data_feature=features[0: split],
            data_attribute=attributes[0: split],
            data_gen_flag=gen_flags[0: split])
        np.savez(
            test_path,
            data_feature=features[split:],
            data_attribute=attributes[split:],
            data_gen_flag=gen_flags[split:])

        print("Done")