Skip to content

Instantly share code, notes, and snippets.

@rthompsonj
Created January 15, 2017 00:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rthompsonj/547e39a2820a57f733d35a179b1d0c7a to your computer and use it in GitHub Desktop.
Save rthompsonj/547e39a2820a57f733d35a179b1d0c7a to your computer and use it in GitHub Desktop.
import numpy as np
import struct
import sys
# Particle types that receive the re-sorted mass groups, lightest mass first.
# The move step further down aborts if one of these already holds particles
# that did not originate from the same type.
FREE_TYPES = [2, 3, 5]

# Command line: INFILE is mandatory, OUTFILE is optional.
if len(sys.argv) < 2:
    sys.exit('usage: %s INFILE [OUTFILE]' % sys.argv[0])
INFILE = sys.argv[1]
if len(sys.argv) > 2:
    OUTFILE = sys.argv[2]
else:
    # Previously OUTFILE stayed undefined here and the script crashed with a
    # NameError only after all the reading/sorting work was done.  Fall back
    # to a name derived from the input instead.
    OUTFILE = INFILE + '.sorted'
nFREE = len(FREE_TYPES)
def skip(f):
    """Read one unsigned 32-bit record marker from *f* and return it.

    The script calls this immediately before and after every data block and
    compares the two values with ``errorcheck``.  The file position advances
    by four bytes.

    Parameters
    ----------
    f : file object opened in binary mode.

    Returns
    -------
    numpy.uint32
        The marker value (native byte order, as read by ``np.fromfile``).

    Raises
    ------
    EOFError
        If fewer than four bytes remain in the file.
    """
    marker = np.fromfile(f, dtype=np.uint32, count=1)
    # np.fromfile returns a short (empty) array at end of file instead of
    # raising, which used to surface as an unhelpful IndexError below.
    if marker.size != 1:
        raise EOFError('unexpected end of file while reading a block marker')
    return marker[0]
def errorcheck(s1, s2, block):
    """Abort the script when the two markers framing *block* disagree.

    Parameters
    ----------
    s1, s2 : integers
        Marker values read before and after the block.
    block : str
        Label of the block, used only in the diagnostic message.

    Returns nothing when the markers match.  On a mismatch the diagnostic is
    printed to stdout and ``SystemExit`` is raised with status 1, so calling
    shells can tell the run failed (the status used to be 0).
    """
    if s1 != s2:
        print('issue with block %s >> %d vs %d' % (block, s1, s2))
        sys.exit(1)
class ParticleType(object):
    """Container for the data belonging to one particle type.

    Only ``ptype``, ``npart`` and ``mass`` are set here.  The script attaches
    further attributes (``positions``, ``velocities``, ``pid`` and, for type
    0, ``u``) after construction while reading the input file.
    """

    def __init__(self, ptype, npart, massTable):
        """Store the type index, its particle count and its table mass.

        Parameters
        ----------
        ptype : int
            Index of the particle type.
        npart : int
            Number of particles of this type at construction time.
        massTable : float
            Header mass-table entry for this type; wrapped in a 0-d array.
        """
        self.ptype = ptype
        self.npart = npart
        self.mass = np.array(massTable)
print('Scanning %s for masses...' % INFILE)
f = open(INFILE, 'rb')  # closed at the end of the second read pass

## header
# NOTE(review): the field order and the 256-byte record look like a GADGET
# "format 1" snapshot header -- confirm against whatever produced INFILE.
s1 = skip(f)
# Per-type particle counts; used below to size every data block.
npart = np.fromfile(f, dtype=np.uint32, count=6)
# Per-type masses; an entry of 0 with npart > 0 means that type's masses are
# stored per particle in the mass block.
massTable = np.fromfile(f, dtype=np.float64, count=6)
# The remaining fields are read so they can be copied to the output header;
# only npartTotal is adjusted later, in step with npart.
time = np.fromfile(f, dtype=np.float64, count=1)[0]
z = np.fromfile(f, dtype=np.float64, count=1)[0]
flag_sfr = np.fromfile(f, dtype=np.int32, count=1)[0]
flag_fb = np.fromfile(f, dtype=np.int32, count=1)[0]
npartTotal = np.fromfile(f, dtype=np.uint32, count=6)
fcool = np.fromfile(f, dtype=np.int32, count=1)[0]
nfiles = np.fromfile(f, dtype=np.int32, count=1)[0]
boxsize = np.fromfile(f, dtype=np.float64, count=1)[0]
Omega0 = np.fromfile(f, dtype=np.float64, count=1)[0]
OmegaLambda = np.fromfile(f, dtype=np.float64, count=1)[0]
HubbleParam = np.fromfile(f, dtype=np.float64, count=1)[0]
flag_age = np.fromfile(f, dtype=np.int32, count=1)[0]
flag_metals = np.fromfile(f, dtype=np.int32, count=1)[0]
NallHW = np.fromfile(f, dtype=np.uint32, count=6)
flag_entropy = np.fromfile(f, dtype=np.int32, count=1)[0]
flag_doublep = np.fromfile(f, dtype=np.int32, count=1)[0]
# Jump over the unused tail of the header: 4-byte leading marker plus 256
# bytes of header, minus what has been consumed so far.
bytes_left = 256 + 4 - f.tell()
f.seek(bytes_left, 1)
s2 = skip(f)
errorcheck(s1, s2, 'header')
# First pass only needs the masses, so the three blocks in front of the mass
# block are stepped over.  Counts are converted to Python ints before the
# multiplication: npart holds uint32 values and, under NumPy 2 promotion
# rules, ``12 * npart[i]`` would be computed in uint32 and could wrap.

# skip positions (3 x 4 bytes per particle)
s1 = skip(f)
for i in range(0, 6):
    f.seek(4 * 3 * int(npart[i]), 1)
s2 = skip(f)
errorcheck(s1, s2, 'positions')

# skip velocities (3 x 4 bytes per particle)
s1 = skip(f)
for i in range(0, 6):
    f.seek(4 * 3 * int(npart[i]), 1)
s2 = skip(f)
errorcheck(s1, s2, 'velocities')

# skip PIDs (4 bytes per particle)
s1 = skip(f)
for i in range(0, 6):
    f.seek(4 * int(npart[i]), 1)
s2 = skip(f)
errorcheck(s1, s2, 'PIDs')
# read masses
## test if MB is present
# A mass block exists only if at least one populated type has a zero entry
# in the header mass table.
mbpresent = any(massTable[i] == 0 and npart[i] > 0 for i in range(0, 6))
if not mbpresent:
    print('no mass block!')
    sys.exit()

# masses[ptype] -> float32 array of per-particle masses, only for the types
# that actually appear in the mass block.
masses = {}
s1 = skip(f)
for i in range(0, 6):
    if massTable[i] == 0 and npart[i] > 0:
        masses[i] = np.fromfile(f, dtype=np.float32, count=int(npart[i]))
s2 = skip(f)
errorcheck(s1, s2, 'masses')
# gather ALL unique masses
# unique_masses[mass]  -> particle type the mass was found in
# unique_indexes[mass] -> indexes into masses[ptype] of particles with that mass
# NOTE: both dicts are keyed by the mass value alone, so if the same value
# occurs in two different types the later type overwrites the earlier one.
unique_masses = {}
unique_indexes = {}
for k, v in masses.items():
    cur = np.unique(v)
    for i in range(0, len(cur)):
        unique_masses[cur[i]] = k
        unique_indexes[cur[i]] = np.where(v == cur[i])[0]

# sort them from lowest->highest mass
# (list() so this also works where dict.keys() returns a view, not a list)
sorted_keys = np.sort(list(unique_masses))

## make a copy to retain original particle types
orig_unique_masses = dict(unique_masses)
def logFound(mass, origPT, newPT):
    """Print one line reporting that *mass* moves from type *origPT* to *newPT*.

    Parameters
    ----------
    mass : float
        The mass value, printed in ``%0.2e`` notation.
    origPT : int
        Particle type the mass was found in.
    newPT : int
        Particle type the mass is assigned to.
    """
    print('Found mass %0.2e [%d] --> [%d]' % (mass, origPT, newPT))
# Assign every unique mass (lightest first) to a target particle type.
nUNIQUE = len(sorted_keys)
if nUNIQUE == 0:
    print('no mixed-mass particle types')
    sys.exit()
if nUNIQUE > nFREE:
    # More distinct masses than free types: the surplus all lands in the
    # last free type together with the mass that was assigned to it.
    print('WARNING, mixing %d types' % (nUNIQUE - nFREE + 1))
for i in range(0, nUNIQUE):
    if i < nFREE:
        target = FREE_TYPES[i]
    else:
        target = FREE_TYPES[-1]
    logFound(sorted_keys[i], unique_masses[sorted_keys[i]], target)
    # From here on unique_masses maps mass -> NEW type; the original type is
    # kept in orig_unique_masses.
    unique_masses[sorted_keys[i]] = target
## setup objects
PARTICLES = []
for i in range(0, 6):
    PARTICLES.append(ParticleType(i, npart[i], massTable[i]))

print('Reading data from %s...' % INFILE)

## rewind and READ IN DATA
# The header values are already in memory from the first pass, so the
# 256-byte header payload is simply stepped over.
f.seek(0, 0)
s1 = skip(f)
f.seek(256, 1)
s2 = skip(f)
errorcheck(s1, s2, 'header')

# Counts are converted to Python ints before any arithmetic so the element
# and byte counts are not computed in uint32.

## positions
s1 = skip(f)
for i in range(0, 6):
    n = int(npart[i])
    data = np.fromfile(f, dtype=np.float32, count=n * 3)
    data = data.reshape(n, 3)
    PARTICLES[i].positions = data
s2 = skip(f)
errorcheck(s1, s2, 'positions')

## velocities
s1 = skip(f)
for i in range(0, 6):
    n = int(npart[i])
    data = np.fromfile(f, dtype=np.float32, count=n * 3)
    data = data.reshape(n, 3)
    PARTICLES[i].velocities = data
s2 = skip(f)
errorcheck(s1, s2, 'velocities')

## PIDs
s1 = skip(f)
for i in range(0, 6):
    data = np.fromfile(f, dtype=np.uint32, count=int(npart[i]))
    PARTICLES[i].pid = data
s2 = skip(f)
errorcheck(s1, s2, 'pid2')

## mass
# Already read during the first pass (see `masses`); skip over it here.
s1 = skip(f)
for i in range(0, 6):
    if npart[i] > 0 and massTable[i] == 0:
        f.seek(4 * int(npart[i]), 1)
s2 = skip(f)
errorcheck(s1, s2, 'mass')

## u
# Only read when type 0 is populated; any blocks after this one are not read.
if npart[0] > 0:
    s1 = skip(f)
    data = np.fromfile(f, dtype=np.float32, count=int(npart[0]))
    PARTICLES[0].u = data
    s2 = skip(f)
    errorcheck(s1, s2, 'u')
f.close()
#######################################
## sort data into new particle types ##
#######################################
print('Moving particles...')

# Per-particle arrays that travel with a particle when it changes type,
# paired with the tag printed when the receiving type is already occupied.
moved_fields = (('positions', 'pos'), ('velocities', 'vel'), ('pid', 'pid'))

# indexes_to_delete[ptype] -> indexes of particles that left that type
indexes_to_delete = {}
for k, new_ptype in unique_masses.items():
    # k = mass, new_ptype = type assigned to that mass
    orig_ptype = orig_unique_masses[k]
    if orig_ptype == new_ptype:
        # Particles with this mass stay where they are.  The full mass array
        # is attached for now; entries that moved elsewhere are removed in
        # the deletion step below.
        PARTICLES[new_ptype].mass = masses[orig_ptype]
        continue

    ## sort out the masses
    indexes = unique_indexes[k]
    PARTICLES[new_ptype].mass = masses[orig_ptype][indexes]
    if orig_ptype in indexes_to_delete:
        indexes_to_delete[orig_ptype] = np.append(indexes_to_delete[orig_ptype], indexes)
    else:
        indexes_to_delete[orig_ptype] = indexes

    ## sort out positions, velocities and PIDs
    # Each field is checked against its OWN target array (the velocity and
    # PID checks used to look at the position array by mistake).
    for attr, tag in moved_fields:
        moving = getattr(PARTICLES[orig_ptype], attr)[indexes]
        if np.shape(getattr(PARTICLES[new_ptype], attr))[0] != 0:
            # Target type is not empty; merging into it is not supported.
            print('ehhh %s' % tag)
            sys.exit(1)
        setattr(PARTICLES[new_ptype], attr, moving)

    ## now npart
    npart[orig_ptype] -= len(indexes)
    npart[new_ptype] += len(indexes)
    npartTotal[orig_ptype] -= len(indexes)
    npartTotal[new_ptype] += len(indexes)

## delete mass entries
for k, v in indexes_to_delete.items():
    PARTICLES[k].positions = np.delete(PARTICLES[k].positions, v, 0)
    PARTICLES[k].velocities = np.delete(PARTICLES[k].velocities, v, 0)
    PARTICLES[k].pid = np.delete(PARTICLES[k].pid, v)
    # Derive the remaining masses from masses[k].  PARTICLES[k].mass is still
    # the 0-d table value when no mass stayed in type k, and np.delete on
    # that placeholder raises IndexError for out-of-range indexes.
    masses[k] = np.delete(masses[k], v)
    PARTICLES[k].mass = masses[k]

## update massTable
for i in range(0, len(PARTICLES)):
    ptypeMasses = np.unique(PARTICLES[i].mass)
    if len(ptypeMasses) == 1:
        # Single mass: goes into the header table, no per-particle masses.
        massTable[i] = ptypeMasses[0]
        PARTICLES[i].mass = 0.0
    else:
        # Zero or several distinct masses: keep per-particle masses.
        massTable[i] = 0.0
        PARTICLES[i].mass = masses[i]
print('Writing %s...' % OUTFILE)
# `with` guarantees the output file is closed even on the error exit below.
with open(OUTFILE, 'wb') as f:
    ## HEADER ##
    f.write(struct.pack('<I', 256))
    f.write(struct.pack('<6I', *npart))
    f.write(struct.pack('<6d', *massTable))
    f.write(struct.pack('<d', time))
    f.write(struct.pack('<d', z))
    f.write(struct.pack('<i', flag_sfr))
    f.write(struct.pack('<i', flag_fb))
    f.write(struct.pack('<6I', *npartTotal))
    f.write(struct.pack('<i', fcool))
    f.write(struct.pack('<i', nfiles))
    f.write(struct.pack('<d', boxsize))
    f.write(struct.pack('<d', Omega0))
    f.write(struct.pack('<d', OmegaLambda))
    f.write(struct.pack('<d', HubbleParam))
    f.write(struct.pack('<i', flag_age))
    f.write(struct.pack('<i', flag_metals))
    f.write(struct.pack('<6I', *NallHW))
    f.write(struct.pack('<i', flag_entropy))
    f.write(struct.pack('<i', flag_doublep))
    # Zero-fill the rest of the header up to 4 (marker) + 256 bytes.
    header_bytes_left = 260 - f.tell()
    f.write(b'\x00' * header_bytes_left)
    f.write(struct.pack('<I', 256))
    if f.tell() - 8 != 256:
        print(r'ERROR! output header = %d' % (f.tell() - 8))
        sys.exit(1)

    # Python int so the marker sizes below are not computed in a fixed-width
    # NumPy integer type.
    ntot = int(np.sum(npart, dtype=np.uint64))

    ## positions
    f.write(struct.pack('<I', 12 * ntot))
    for i in range(0, 6):
        if npart[i] > 0:
            data = PARTICLES[i].positions.astype('f')
            f.write(data.tobytes())
    f.write(struct.pack('<I', 12 * ntot))

    ## velocities
    f.write(struct.pack('<I', 12 * ntot))
    for i in range(0, 6):
        if npart[i] > 0:
            data = PARTICLES[i].velocities.astype('f')
            f.write(data.tobytes())
    f.write(struct.pack('<I', 12 * ntot))

    ## PID
    f.write(struct.pack('<I', 4 * ntot))
    for i in range(0, 6):
        if npart[i] > 0:
            data = PARTICLES[i].pid.astype('I')
            f.write(data.tobytes())
    f.write(struct.pack('<I', 4 * ntot))

    ## mass
    # Only populated types whose table mass is 0 contribute to the mass
    # block; mbpresent counts their particles.
    mbpresent = 0
    for i in range(0, 6):
        if npart[i] > 0 and massTable[i] == 0:
            mbpresent += int(npart[i])
    if mbpresent > 0:
        f.write(struct.pack('<I', 4 * mbpresent))
        for i in range(0, 6):
            if npart[i] > 0 and massTable[i] == 0:
                data = PARTICLES[i].mass.astype('f')
                f.write(data.tobytes())
        f.write(struct.pack('<I', 4 * mbpresent))

    ## u
    if npart[0] > 0:
        f.write(struct.pack('<I', 4 * int(npart[0])))
        data = PARTICLES[0].u.astype('f')
        f.write(data.tobytes())
        f.write(struct.pack('<I', 4 * int(npart[0])))
print('Complete')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment