Skip to content

Instantly share code, notes, and snippets.

@hushell
Created August 17, 2022 12:58
Show Gist options
  • Save hushell/1043dda422ed2496a5161eb94c1a0bdf to your computer and use it in GitHub Desktop.
Save hushell/1043dda422ed2496a5161eb94c1a0bdf to your computer and use it in GitHub Desktop.
import torch
from srt.encoder import SRTEncoder
from srt.decoder import SRTDecoder, NerfDecoder
import yaml
# Smoke test for the SRT encoder/decoder pair: encode random input views,
# decode a set of query rays, and score the decoded pixels with an MSE loss.
batch_size = 1
num_views = 2
height, width = 224, 224
num_target_rays = height * width

# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

encoder = SRTEncoder().to(device)
decoder = NerfDecoder().to(device)

# Input views, allocated directly on `device` (no CPU tensor + copy).
# Assumed layouts (TODO confirm against SRTEncoder.forward):
#   images      (batch, views, channels=3, H, W)
#   i_rays      (batch, views, H, W, 3)  -- one ray vector per pixel
#   i_positions (batch, views, 3)        -- one camera position per view
images = torch.rand(batch_size, num_views, 3, height, width, device=device)
i_rays = torch.rand(batch_size, num_views, height, width, 3, device=device)
i_positions = torch.rand(batch_size, num_views, 3, device=device)

# Query the decoder with the rays of input view 0 (every ray shares that
# view's camera position), so each decoded pixel has a ground-truth colour
# in `images`. Both tensors keep the shape (batch, num_target_rays, 3).
t_rays = i_rays[:, 0].reshape(batch_size, num_target_rays, 3)
t_positions = (
    i_positions[:, 0:1].expand(batch_size, num_target_rays, 3).contiguous()
)

z = encoder(images, i_positions, i_rays)
recons = decoder(z, t_positions, t_rays)
# NOTE(review): SRT-style decoders may return (pixels, extras); keep only the
# pixel tensor in that case -- confirm against NerfDecoder.forward.
if isinstance(recons, tuple):
    recons = recons[0]

# Build a per-ray target. The full `images` tensor holds `num_views`
# channel-first images and cannot line up with predictions for
# `num_target_rays` rays. Move view 0 to channel-last and flatten its pixels
# in the same row-major order used for `t_rays`: (batch, num_target_rays, 3).
target = images[:, 0].permute(0, 2, 3, 1).reshape(batch_size, num_target_rays, 3)
loss = torch.nn.functional.mse_loss(recons, target)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment