from __future__ import print_function
import os
import os.path as op
import shutil
from pprint import pprint
from timeit import default_timer

import h5py
import numpy as np
from numpy.testing import assert_allclose as ac

import phy
from phy.cluster.manual.store import DiskStore
from phy.io.h5 import open_h5
from phy.cluster.manual._utils import _spikes_per_cluster
from phy.utils.array import _index_of

phy.debug()

_store_path = '_store'

n_spikes = 200000
n_channels = 200
n_clusters = 100
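# At these sizes the synthetic array is 200000 spikes x 200 channels of
# float32, i.e. 200000 * 200 * 4 bytes, roughly 160 MB pushed through the store.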

# Generate the dataset.
def _gen_arr():
    arr = np.random.rand(n_spikes, n_channels).astype(np.float32)
    with open_h5('test', 'w') as f:
        f.write('/test', arr)


def _gen_spike_clusters():
    sc = np.random.randint(size=n_spikes, low=0, high=n_clusters)
    with open_h5('sc', 'w') as f:
        f.write('/sc', sc)


def _load_spike_clusters():
    with open_h5('sc', 'r') as f:
        return f.read('/sc')[...]


def _reset_store():
    for path in (_store_path, '_flat'):
        if op.exists(path):
            shutil.rmtree(path)
        os.mkdir(path)

# _gen_spike_clusters()
# _gen_arr()

f = open_h5('test', 'r')
sc = _load_spike_clusters()
arr = f.read('/test')
spikes = np.arange(n_spikes)
spc = _spikes_per_cluster(spikes, sc)
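# `arr` remains an h5py.Dataset here, so the data is only read lazily, chunk by
# chunk, inside _gen_store_1() below. `spc` maps each cluster id to the array
# of spike indices assigned to that cluster.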

def _flat_file(cluster):
    return op.join('_flat', str(cluster))
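
# _gen_store_1() builds the cluster store in two passes:
#   1. read the HDF5 array chunk by chunk and append each cluster's rows to a
#      flat binary file in the `_flat` directory;
#   2. read each flat file back and put it into the DiskStore.
# The bare name `profile` is not defined in this script: it is expected to be
# injected by a profiler (e.g. line_profiler's kernprof); running the script
# without one raises a NameError unless the decorator is removed.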
@profile
def _gen_store_1():
    _reset_store()

    chunk_size = 10000
    # print("chunks")
    for i in range(n_spikes // chunk_size):
        # print(i, end='\r')
        a, b = i * chunk_size, (i + 1) * chunk_size

        # Load a chunk from HDF5.
        assert isinstance(arr, h5py.Dataset)
        sub_arr = arr[a:b]
        assert isinstance(sub_arr, np.ndarray)

        sub_sc = sc[a:b]
        sub_spikes = np.arange(a, b)

        # Split the spikes.
        sub_spc = _spikes_per_cluster(sub_spikes, sub_sc)

        # Go through the clusters.
        clusters = sorted(sub_spc.keys())
        for cluster in clusters:
            idx = _index_of(sub_spc[cluster], sub_spikes)

            # Save part of the array to a binary file.
            with open(_flat_file(cluster), 'ab') as f:
                sub_arr[idx].tofile(f)
    # print()

    ds = DiskStore(_store_path)

    # Next, put the flat binary files back to HDF5.
    # print("flat to HDF5")
    for cluster in range(n_clusters):
        # print(cluster, end='\r')
        data = np.fromfile(_flat_file(cluster),
                           dtype=np.float32).reshape((-1, n_channels))
        ds.store(cluster, data=data)
    # print()

    # Test.
    cluster = 0
    arr2 = ds.load(cluster, 'data')
    ac(arr[spc[cluster], :], arr2)

_gen_store_1()
f.close()
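
# To get per-line timings, run this file under line_profiler's kernprof, e.g.:
#   kernprof -l -v <this_script.py>
# (assumes the line_profiler package is installed; otherwise drop @profile).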