Skip to content

Instantly share code, notes, and snippets.

diff --git a/offline/main.py b/offline/main.py
index 5502749..6647508 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,6 +43,7 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
+ eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore
diff --git a/offline/main.py b/offline/main.py
index 5502749..1581d30 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,6 +43,7 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
+ eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore
diff --git a/offline/main.py b/offline/main.py
index 5502749..1581d30 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,6 +43,7 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
+ eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore
diff --git a/offline/main.py b/offline/main.py
index 1581d30..5502749 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -43,7 +43,6 @@ class Conf(BaseConf):
total_optim_steps: int = attrs.field(default=int(2e5), validator=attrs.validators.gt(0))
log_steps: int = attrs.field(default=250, validator=attrs.validators.gt(0)) # type: ignore
- eval_before_training: bool = False
eval_steps: int = attrs.field(default=20000, validator=attrs.validators.gt(0)) # type: ignore
diff --git a/offline/main.py b/offline/main.py
index 1581d30..6647508 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -58,7 +58,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- wandb_run = cfg.setup_for_experiment() # checking & setup logging
+ cfg.setup_for_experiment() # checking & setup logging
diff --git a/offline/main.py b/offline/main.py
index 1581d30..6647508 100644
--- a/offline/main.py
+++ b/offline/main.py
@@ -58,7 +58,7 @@ cs.store(name='config', node=Conf())
@hydra.main(version_base=None, config_name="config")
def train(dict_cfg: DictConfig):
cfg: Conf = Conf.from_DictConfig(dict_cfg) # type: ignore
- wandb_run = cfg.setup_for_experiment() # checking & setup logging
+ cfg.setup_for_experiment() # checking & setup logging
diff --git a/train_cifar10.py b/train_cifar10.py
index b52de81..0592f67 100644
--- a/train_cifar10.py
+++ b/train_cifar10.py
@@ -65,13 +65,16 @@ def build_discriminator(image_size, latent_code_length):
y = Conv2D(1024, (3, 3), padding="same")(y)
y = LeakyReLU()(y)
y = Flatten()(y)
- y = Dense(1,activation="sigmoid")(y)
+ y = Dense(1)(y)
diff --git a/tf_bigan.py b/tf_bigan.py
index 12cf97b..5d6d88c 100644
--- a/tf_bigan.py
+++ b/tf_bigan.py
@@ -48,8 +48,8 @@ parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help='path to netG (to continue training)')
parser.add_argument('--netD', default='', help='path to netD (to continue training)')
parser.add_argument('--netE', default='', help='path to netE (to continue training)')
-parser.add_argument('--checkpoints', default='/data/vision/torralba/scratch/dxwu/dcbigan/checkpoints/', help='folder to output model checkpoints')
-parser.add_argument('--samples', default='/data/vision/torralba/scratch/dxwu/dcbigan/samples/', help='folder to output images')
import torch
import torch.utils.data
class ChunkDataset(object):
    """Index-routing dataset keyed by ``(chunk_idx, indices_in_chunk)`` pairs.

    This class only records how long each chunk is and forwards every
    lookup to ``self.get_data_from_chunk``. That method is not defined
    here; presumably a subclass provides it (TODO confirm against callers).
    """

    def __init__(self, chunk_lengths):
        """Store the per-chunk lengths.

        Args:
            chunk_lengths: any iterable of chunk lengths. It is copied into
                a tuple, so generators are accepted and the stored value
                cannot be mutated afterwards.
        """
        self.chunk_lengths = tuple(chunk_lengths)

    def __getitem__(self, key):
        """Fetch data for ``key``, a 2-item ``(chunk_idx, indices_in_chunk)`` pair.

        Returns whatever ``self.get_data_from_chunk`` returns. A key that
        cannot be unpacked into exactly two items raises (``TypeError`` or
        ``ValueError``) before any data access happens.
        """
        which_chunk, positions = key
        fetch = self.get_data_from_chunk
        return fetch(which_chunk, positions)
@ssnl
ssnl / view_im2col.py
Created July 14, 2018 10:43
view_im2col.py
# inp:
# N, C, H, W
#
# out:
# N, IH, IW, C, KH, KW
#
# kernel_size:
# KH, KW
def im2col(inp, kernel_size, stride=(1, 1), padding=(0, 0), dilation=(1, 1)):
assert padding == (0, 0)