Last active
August 25, 2021 15:42
-
-
Save xmodar/ac28eccbfd12ee3a55c7896aa31dc94b to your computer and use it in GitHub Desktop.
ModelNet40 Dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ssl
import urllib.request
from pathlib import Path

import h5py
import pandas as pd
import torch
from pyntcloud import PyntCloud
from torch.utils.data import Dataset
from torchvision.datasets.utils import extract_archive, check_integrity
from tqdm import tqdm
__all__ = ['ModelNet40'] | |
def download_and_extract_archive(url, path, md5=None, extract_to='data'):
    """Download ``url`` into ``path`` and extract it under ``path/extract_to``.

    The whole operation is skipped if the extraction directory already
    exists, and the download is skipped if the archive file is present and
    matches ``md5``.

    Args:
        url: URL of the archive to fetch.
        path: directory to store the archive in (created if missing).
        md5: optional MD5 checksum used to validate an existing archive;
            ``None`` disables the integrity check.
        extract_to: name of the sub-directory to extract into.

    Returns:
        pathlib.Path: the extraction directory.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    extract_path = path / extract_to
    if extract_path.exists():
        return extract_path
    file_path = path / Path(url).name
    # str() for compatibility with torchvision versions whose
    # check_integrity expects a string path.
    if not file_path.exists() or not check_integrity(str(file_path), md5):
        print(f'{file_path} not found or corrupted')
        print(f'downloading from {url}')
        # NOTE(security): a bare SSLContext skips certificate verification.
        # This is deliberate — the dataset host's certificate is expired —
        # but it does disable MITM protection for this download.
        context = ssl.SSLContext()
        with urllib.request.urlopen(url, context=context) as response:
            with tqdm(total=response.length) as pbar:
                with open(file_path, 'wb') as file:
                    chunk_size = 1024
                    # read() returns bytes, so the sentinel must be b''
                    # (a '' sentinel would never match and never stop).
                    for chunk in iter(lambda: response.read(chunk_size), b''):
                        file.write(chunk)
                        # len(chunk) may be < chunk_size on the last read;
                        # updating by chunk_size would overcount.
                        pbar.update(len(chunk))
    extract_archive(str(file_path), str(extract_path))
    return extract_path
class ModelNet40(Dataset):
    """ModelNet40 point-cloud dataset stored as HDF5 shards.

    Downloads and extracts the official archive on first use, then serves
    ``(point_cloud, label)`` pairs where ``point_cloud`` is a float tensor
    of shape ``(3, max_count)`` and ``label`` is a scalar long tensor.

    Args:
        root: dataset directory; defaults to
            ``<torch hub dir>/datasets/ModelNet40``.
        train: use the train split when True, the test split otherwise.
        transform: optional callable applied as
            ``transform(point_cloud, self.train)``.
        max_count: number of points kept per cloud (shards store 2048).
    """

    dir_name = 'modelnet40_ply_hdf5_2048'
    md5 = 'c9ab8e6dfb16f67afdab25e155c79e59'
    url = f'https://shapenet.cs.stanford.edu/media/{dir_name}.zip'

    def __init__(self, root=None, train=True, transform=None, max_count=1024):
        self.train = bool(train)
        self.transform = transform
        self.max_count = int(max_count)
        if root is None:
            root = Path(torch.hub.get_dir()) / 'datasets/ModelNet40'
        self.root = Path(root)
        path = download_and_extract_archive(self.url, self.root, self.md5)
        path = path / self.dir_name
        # read_text() closes the file, unlike a bare open().read().
        labels_text = (path / 'shape_names.txt').read_text()
        self.labels = labels_text.strip().split('\n')
        split = 'train' if self.train else 'test'
        files_text = (path / f'{split}_files.txt').read_text()
        self.files = files_text.strip().split('\n')
        # Instance count of each HDF5 shard; use a context manager so every
        # file handle is closed (the original list comprehension leaked them).
        self.lengths = []
        for file_name in self.files:
            with h5py.File(self.root / file_name, 'r') as file:
                self.lengths.append(len(file['data']))

    def __getitem__(self, index):
        """Return ``(point_cloud, label)`` for the global ``index``."""
        file_name, index = self.get_file(index)
        with h5py.File(self.root / file_name, 'r') as file:
            point_cloud = file['data'][index, :self.max_count]
            label = file['label'][index]
        # Transpose to (3, max_count): channels-first for conv-style models.
        point_cloud = torch.from_numpy(point_cloud).T.float()
        label = torch.from_numpy(label).squeeze().long()
        if self.transform is not None:
            point_cloud = self.transform(point_cloud, self.train)
        return point_cloud, label

    def __len__(self):
        """Total number of instances across all shards."""
        return sum(self.lengths)

    def __repr__(self):
        return f'{type(self).__name__}(train={self.train})'

    def get_file(self, index):
        """Map a global ``index`` to ``(shard_file_name, local_index)``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        total = 0
        for file_name, length in zip(self.files, self.lengths):
            if index < total + length:
                return file_name, index - total
            total += length
        raise IndexError(f'{index} should be in [0, {len(self)})')

    def show(self, index=None):
        """Plot instance ``index`` (random when None) with PyntCloud."""
        if index is None:
            # .item() yields a plain int so the printed message and the
            # indexing below use a number, not a 0-d tensor.
            index = torch.randint(len(self), ()).item()
        point_cloud, label = self[index]
        print(f'instance #{index}: {self.labels[label]}')
        df = pd.DataFrame(point_cloud.T.numpy(), columns=list('xyz'))
        PyntCloud(df).plot(elev=60, azim=-90)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment