Prototype of dtype compatibility check
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from __future__ import division, print_function

try:
    from math import gcd  # Python 3.5+; fractions.gcd was removed in 3.9
except ImportError:
    from fractions import gcd  # Python 2 fallback

from numpy import object_
def get_object_offsets(dtype, base_offset=0):
    """Return the byte offsets of all object entries contained in ``dtype``.

    The dtype is walked recursively: structured dtypes contribute the offsets
    found in each of their fields, sub-array dtypes contribute the offsets of
    their base dtype repeated once per element, and a plain object dtype
    contributes a single offset.

    Parameters
    ----------
    dtype : numpy.dtype
        The dtype to inspect.
    base_offset : int, optional
        Offset added to every position reported for ``dtype``; used by the
        recursion to account for where a nested dtype starts.

    Returns
    -------
    list of int
        One offset per object entry.  The list is empty when ``dtype``
        contains no objects.
    """
    offsets = []
    if dtype.fields is not None:
        # Walk ``names`` instead of ``fields.values()``: a field declared
        # with a title is listed in ``fields`` under both its name and its
        # title, so iterating the values would report its objects twice.
        for name in dtype.names:
            field = dtype.fields[name]
            sub_dtype = field[0]
            sub_offset = field[1] + base_offset
            offsets.extend(get_object_offsets(sub_dtype, sub_offset))
    elif dtype.shape:
        sub_offsets = get_object_offsets(dtype.base, base_offset)
        # Nothing to tile when the element type holds no objects; skipping
        # avoids looping over every element of a large sub-array for nothing.
        if sub_offsets:
            count = 1
            for dim in dtype.shape:
                count *= dim
            stride = dtype.base.itemsize
            offsets.extend(off + stride * j
                           for j in range(count)
                           for off in sub_offsets)
    elif dtype.type == object_:
        offsets.append(base_offset)
    return offsets
def dtype_view_is_safe(newtype, oldtype):
    """Check whether object entries line up when viewing ``oldtype`` as ``newtype``.

    When the item sizes differ, the smallest run of bytes that holds a whole
    number of items of both dtypes is considered, and the object offsets of
    each dtype are repeated across that run.  The view is reported as safe
    only when both dtypes place objects at exactly the same byte offsets.

    Parameters
    ----------
    newtype : numpy.dtype
        The dtype the data would be viewed as.
    oldtype : numpy.dtype
        The dtype the data currently has.

    Returns
    -------
    bool
        True when the two sets of object offsets coincide, False otherwise.
    """
    new_offsets = get_object_offsets(newtype)
    old_offsets = get_object_offsets(oldtype)

    # With no objects on either side the comparison below is trivially true;
    # returning here avoids iterating over every repetition for nothing.
    if not new_offsets and not old_offsets:
        return True

    new_size = newtype.itemsize
    old_size = oldtype.itemsize

    # Unless their sizes are equal, old_num occurrences of oldtype will
    # be replaced by new_num of newtype.
    if old_size == new_size:
        new_num = old_num = 1
    else:
        gcd_new_old = gcd(new_size, old_size)
        new_num = old_size // gcd_new_old
        old_num = new_size // gcd_new_old

    # There must be the same total number of objects before and after.
    if len(new_offsets) * new_num != len(old_offsets) * old_num:
        return False

    # Repeat each dtype's offsets across the shared run and compare.
    new_tiled = set(off + new_size * j
                    for j in range(new_num)
                    for off in new_offsets)
    old_tiled = set(off + old_size * j
                    for j in range(old_num)
                    for off in old_offsets)
    return new_tiled == old_tiled
if __name__ == '__main__':
    # Self-test: run this file as a script to exercise both helpers.
    import numpy as np

    obj_size = np.dtype(object_).itemsize

    # Some tests for get_object_offsets
    dta = np.dtype(('O', (3,)))
    dtb = np.dtype((dta, (2,)))
    dtc = np.dtype([('', 'i2'), ('', 'i1'), ('', dtb), ('', 'i1'), ('', 'O')])
    dta_offsets = [obj_size*j for j in range(3)]
    dtb_offsets = [obj_size*j for j in range(3*2)]
    # In dtc the sub-array starts after the i2 and i1 fields (3 bytes); the
    # trailing object follows the sub-array and one more i1.
    dtc_offsets = [off+3 for off in dtb_offsets] + [dtb.itemsize+4]
    assert dta_offsets == sorted(get_object_offsets(dta))
    assert dtb_offsets == sorted(get_object_offsets(dtb))
    assert dtc_offsets == sorted(get_object_offsets(dtc))

    # Some tests for dtype_view_is_safe
    dt_base = np.dtype([('', 'O'), ('', 'i2')])
    for n in range(1, 10):
        for m in range(1, 10):
            # A sub-array dtype is spelled np.dtype((base, shape)).  Passing
            # the shape as a second positional argument would be taken as
            # ``align`` and no sub-array would be built.
            assert dtype_view_is_safe(np.dtype((dt_base, (n,))),
                                      np.dtype((dt_base, (m,))))
            assert dtype_view_is_safe(np.dtype(('O', (m,))),
                                      np.dtype(('O', (n,))))

    # Different layouts and sizes, but objects at matching offsets.
    dta = np.dtype([('', 'i1'), ('', 'i1'), ('', 'O'), ('', 'i4'), ('', 'O'),
                    ('', 'i2')])
    dtb = np.dtype([('', 'i2'), ('', 'O'), ('', 'i2'), ('', 'i2'), ('', 'O'),
                    ('', 'i4'), ('', 'O'), ('', 'i1'), ('', 'i1')])
    assert dtype_view_is_safe(dta, dtb)

    # Object counts cannot match, so these views must be rejected.
    dta = np.dtype([('', 'O'), ('', 'O'), ('', 'i1')])
    dtb = np.dtype([('', 'O'), ('', 'O'), ('', 'i1'), ('', 'O'), ('', 'O')])
    assert not dtype_view_is_safe(dta, dtb)
    dtb = np.dtype([('', 'O'), ('', 'O'), ('', 'i1'), ('', 'O')])
    assert not dtype_view_is_safe(dta, dtb)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment