Skip to content

Instantly share code, notes, and snippets.

@xmodar
Last active August 25, 2021 15:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save xmodar/ac28eccbfd12ee3a55c7896aa31dc94b to your computer and use it in GitHub Desktop.
ModelNet40 Dataset
import ssl
import urllib
from pathlib import Path
import torch
from torch.utils.data import Dataset
from torchvision.datasets.utils import extract_archive, check_integrity
import h5py
import pandas as pd
from tqdm import tqdm
from pyntcloud import PyntCloud
__all__ = ['ModelNet40']
def download_and_extract_archive(url, path, md5=None, extract_to='data'):
    """Download ``url`` into ``path`` and extract it to ``path / extract_to``.

    Works even if the SSL certificate for the link is expired, because the
    download deliberately uses an ``SSLContext`` with no certificate
    verification.

    Args:
        url: URL of the archive to download.
        path: Local directory that receives the downloaded archive.
        md5: Optional MD5 checksum used to validate an existing download.
        extract_to: Subdirectory (under ``path``) for the extracted files.

    Returns:
        ``Path`` of the extraction directory.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    extract_path = path / extract_to
    if extract_path.exists():  # already downloaded and extracted
        return extract_path
    file_path = path / Path(url).name
    if not file_path.exists() or not check_integrity(file_path, md5):
        print(f'{file_path} not found or corrupted')
        print(f'downloading from {url}')
        # NOTE: a bare SSLContext performs no certificate verification on
        # purpose — the host's certificate may be expired.
        context = ssl.SSLContext()
        with urllib.request.urlopen(url, context=context) as response:
            with tqdm(total=response.length) as pbar:
                with open(file_path, 'wb') as file:
                    chunk_size = 1024
                    # FIX: response.read() yields bytes, so the iter()
                    # sentinel must be b'' (it was '', which never matched).
                    for chunk in iter(lambda: response.read(chunk_size), b''):
                        # FIX: advance by the actual chunk length — the final
                        # chunk is usually shorter than chunk_size.
                        pbar.update(len(chunk))
                        file.write(chunk)
    extract_archive(str(file_path), str(extract_path))
    return extract_path
class ModelNet40(Dataset):
    """ModelNet40 point-cloud classification dataset (2048 points per shape).

    The data ships as a set of HDF5 files; the archive is downloaded and
    extracted on first use, and instances are read lazily from the HDF5
    files on every access.

    Args:
        root: Dataset directory; defaults to
            ``<torch hub dir>/datasets/ModelNet40``.
        train: Load the training split if True, else the test split.
        transform: Optional callable ``transform(point_cloud, train)``
            applied to every point cloud before it is returned.
        max_count: Maximum number of points kept per point cloud.
    """

    dir_name = 'modelnet40_ply_hdf5_2048'
    md5 = 'c9ab8e6dfb16f67afdab25e155c79e59'
    url = f'https://shapenet.cs.stanford.edu/media/{dir_name}.zip'

    def __init__(self, root=None, train=True, transform=None, max_count=1024):
        self.train = bool(train)
        self.transform = transform
        self.max_count = int(max_count)
        if root is None:
            root = Path(torch.hub.get_dir()) / 'datasets/ModelNet40'
        self.root = Path(root)
        path = download_and_extract_archive(self.url, self.root, self.md5)
        path = path / self.dir_name
        # FIX: read the metadata files with Path.read_text() — the original
        # used bare open() calls whose handles were never closed.
        labels = path / 'shape_names.txt'
        self.labels = labels.read_text().strip().split('\n')
        split = 'train' if self.train else 'test'
        files = path / f'{split}_files.txt'
        self.files = files.read_text().strip().split('\n')
        # FIX: close each HDF5 file after reading its length — the original
        # list comprehension leaked one open handle per file.
        self.lengths = []
        for f in self.files:
            with h5py.File(self.root / f, 'r') as h5:
                self.lengths.append(len(h5['data']))

    def __getitem__(self, index):
        """Return ``(point_cloud, label)`` for flat dataset ``index``.

        ``point_cloud`` is a float tensor of shape ``(3, <=max_count)``;
        ``label`` is a scalar long tensor indexing into ``self.labels``.
        """
        file_name, index = self.get_file(index)
        with h5py.File(self.root / file_name, 'r') as file:
            point_cloud = file['data'][index, :self.max_count]
            label = file['label'][index]
        point_cloud = torch.from_numpy(point_cloud).T.float()
        label = torch.from_numpy(label).squeeze().long()
        if self.transform is not None:
            point_cloud = self.transform(point_cloud, self.train)
        return point_cloud, label

    def __len__(self):
        """Total number of instances across all HDF5 files of the split."""
        return sum(self.lengths)

    def __repr__(self):
        return f'{type(self).__name__}(train={self.train})'

    def get_file(self, index):
        """Map a flat ``index`` to ``(file_name, index_within_that_file)``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        total = 0
        for file_name, length in zip(self.files, self.lengths):
            if index < total + length:
                return file_name, index - total
            total += length
        raise IndexError(f'{index} should be in [0, {len(self)})')

    def show(self, index=None):
        """Plot instance ``index`` (random if None) with pyntcloud."""
        if index is None:
            index = torch.randint(len(self), ())
        point_cloud, label = self[index]
        print(f'instance #{index}: {self.labels[label]}')
        df = pd.DataFrame(point_cloud.T.numpy(), columns=list('xyz'))
        PyntCloud(df).plot(elev=60, azim=-90)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment