This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def build_datasets_vocab(root_file, captions_file, transform, split=0.15):
    """Build the caption vocabulary (and, presumably, the datasets).

    Reads the captions CSV and maps every lower-cased token to a unique
    integer index in ``vocab``.

    Args:
        root_file: root path for the images (unused in the visible portion).
        captions_file: path to a CSV file readable by ``pd.read_csv``.
        transform: image transform (unused in the visible portion).
        split: validation fraction, default 0.15 (unused in the visible
            portion — presumably applied further down; TODO confirm).

    NOTE(review): the original paste carried table residue (" | |") and
    flattened indentation; this restores runnable Python. The snippet is
    truncated — ``create_vocab`` is defined but the code that applies it
    (e.g. ``df["caption"].apply(create_vocab)``) and the return statement
    are not visible in this excerpt.
    """
    df = pd.read_csv(captions_file)
    vocab = {}

    def create_vocab(caption):
        # Tokenize with NLTK's word_tokenize, lower-case, and assign the
        # next free index to each previously unseen token.
        tokens = [token.lower() for token in word_tokenize(caption)]
        for token in tokens:
            if token not in vocab:
                vocab[token] = len(vocab)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Encoder(nn.Module):
    """Encoder half of the model: maps an image toward a latent code.

    NOTE(review): restored from a paste with " | |" table residue and
    flattened indentation; the original snippet is truncated right after
    the pooled-dimension computation, so the actual layer definitions are
    not visible in this excerpt.
    """

    def __init__(self, in_channels, out_channels, image_dim, latent_dim):
        super().__init__()
        # image_dim unpacks to (width, height); hW/hH are the spatial
        # dims after pooling. POOLING_FACTOR is a module-level constant
        # defined elsewhere in the notebook — TODO confirm its value.
        iW, iH = image_dim
        hW, hH = iW // POOLING_FACTOR, iH // POOLING_FACTOR
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# NOTE(review): GitHub-gist extraction residue (" | |") and flattened
# indentation make this snippet non-runnable as pasted; it is also
# truncated mid-definition — the nn.Sequential(...) argument list is cut off.
# Decoder half of the model: presumably maps latent features back to images.
class Decoder(nn.Module): | |
# Constructor — only the start of the layer construction is visible.
def __init__(self, in_channels, out_channels, image_dim): | |
super().__init__() | |
# image_dim unpacks to (width, height); hW/hH are the pooled spatial dims.
# POOLING_FACTOR is a module-level constant defined elsewhere — TODO confirm.
iW, iH = image_dim | |
hW, hH = iW//POOLING_FACTOR, iH//POOLING_FACTOR | |
# Final layer stack — its contents are truncated in this excerpt.
self.layer4 = nn.Sequential( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def calculate_loss(reconstructed, caption_prob, images, captions_transformed, mean, log_std):
    """Compute the VAE-style loss terms for a batch.

    Args:
        reconstructed: reconstructed images, compared against ``images``
            via the module-level ``criterion``.
        caption_prob: per-position token probabilities; assumed shape
            (batch, MAX_CAPTION_LEN, vocab_size) — TODO confirm.
        images: ground-truth images.
        captions_transformed: integer token ids, shape
            (batch, MAX_CAPTION_LEN) — assumed from the indexing below.
        mean: latent Gaussian means.
        log_std: log of the latent standard deviations.

    NOTE(review): restored from a paste with " | |" residue and flattened
    indentation; the snippet is truncated before the three terms are
    combined and returned.
    """
    size = captions_transformed.shape[0]
    # criterion and MAX_CAPTION_LEN are module-level names defined elsewhere.
    reconstruction_error = criterion(reconstructed, images)
    # Probability the model assigned to each ground-truth token at each
    # caption position, gathered per batch element.
    likelihoods = torch.stack([
        caption_prob[i, np.arange(MAX_CAPTION_LEN), captions_transformed[i]]
        for i in range(size)
    ])
    log_likelihoods = -torch.log(likelihoods).sum()
    # KL(N(mean, exp(log_std)) || N(0, 1)) in closed form:
    #   -1/2 * sum(1 + log(var) - mean^2 - var), with log(var) = 2*log_std.
    # The original omitted the 1/2 factor, doubling the KL term relative
    # to the standard expression.
    KL_divergence = -0.5 * (1 + 2 * log_std - mean.pow(2) - torch.exp(2 * log_std)).sum()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Show the 5th image of the current batch: move to CPU and permute
# CHW -> HWC, which is the layout matplotlib's imshow expects.
plt.imshow(images[4].to("cpu").permute(1, 2, 0)) | |
plt.axis("off") | |
# Title the figure with the model's generated caption for that same image.
# get_caption presumably maps token ids back to words — defined elsewhere.
_ = plt.title(get_caption(model.generate_caption(images[4].unsqueeze(0)))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def prep_test_data(med, train_dir, test_dir, n_samples=2000):
    """Copy a random sample of one class's files from train to test.

    Args:
        med: class (sub-directory) name under both ``train_dir`` and
            ``test_dir``; the ``test_dir/med`` folder must already exist.
        train_dir: directory containing one sub-directory per class.
        test_dir: destination directory with matching class folders.
        n_samples: how many files to sample (default 2000, the value the
            original hard-coded). ``random.sample`` raises ValueError if
            the class holds fewer files than this.
    """
    pop = os.listdir(os.path.join(train_dir, med))
    test_data = random.sample(pop, n_samples)
    print(test_data)
    for f in test_data:
        # shutil.copy accepts a directory destination and keeps the filename.
        shutil.copy(os.path.join(train_dir, med, f),
                    os.path.join(test_dir, med))
# Driver: split every class folder found under train_dir into the test set.
# NOTE(review): indentation was flattened by the extraction — the call on
# the second line belongs inside the for-loop body.
for medi in os.listdir(train_dir): | |
prep_test_data(medi, train_dir, test_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# --- training set: enumerate class folders ---
target_classes = os.listdir(train_dir)
num_classes = len(target_classes)
print(f"Number of target classes: {num_classes}")
# Show each class alongside its integer label.
print(list(enumerate(target_classes)))

# --- test set: same count, refreshed from test_dir ---
target_classes = os.listdir(test_dir)
num_classes = len(target_classes)
print(f"Number of target classes: {num_classes}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Plot a grid of MRI images given a list of file paths.
# NOTE(review): " | |" residue and flattened indentation are extraction
# artifacts; the snippet is also truncated mid-loop in this excerpt.
def show_mri(med): | |
num = len(med) | |
# Nothing to draw for an empty list.
if num == 0: | |
return None | |
# Near-square grid sizing. NOTE(review): rows*cols can undercount, e.g.
# num=10 gives rows=3, cols=3 -> 9 axes for 10 images — verify intent.
rows = int(math.sqrt(num)) | |
cols = (num+1)//rows | |
f, axs = plt.subplots(rows, cols) | |
fig = 0 | |
for b in med: | |
# image.load_img is presumably keras.preprocessing.image — TODO confirm.
img = image.load_img(b) |
OlderNewer