Skip to content

Instantly share code, notes, and snippets.

@baileyji
Created November 1, 2017 16:52
Show Gist options
  • Save baileyji/1a1c4762bbc303f9f0a3a10503a5da64 to your computer and use it in GitHub Desktop.
Save baileyji/1a1c4762bbc303f9f0a3a10503a5da64 to your computer and use it in GitHub Desktop.
Scratch file to demonstrate a possible d2c mapping bug
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from jwst import datamodels as dm
from helpers import *
def nanminmax(x):
    """Return (min, max) over the finite entries of array `x`.

    NaN and +/-inf entries are ignored. When `x` has no finite entries at
    all, (None, None) is returned.
    """
    finite_mask = np.isfinite(x)
    if not finite_mask.any():
        return None, None
    finite_vals = x[finite_mask]
    return finite_vals.min(), finite_vals.max()
def pxrange(x):
    """Return integer pixel bounds (floor(min), ceil(max)) of the finite values of `x`.

    Non-finite entries are ignored (the extrema come from `nanminmax`).
    If `x` contains no finite values, (None, None) is returned, matching
    `nanminmax`'s convention, instead of raising.
    """
    mi, ma = nanminmax(x)
    if mi is None:
        # np.floor(None) would raise TypeError; propagate the "empty" marker.
        return None, None
    return np.floor(mi).astype(int), np.ceil(ma).astype(int)
def explicit_str_subband(x):
    """Find an explicit '<channel><BAND>' token (e.g. '3MEDIUM') in string `x`.

    Matching is case-insensitive. Channels are tried from 4 down to 1, and
    within a channel the bands are tried in the order LONG, MEDIUM, SHORT.
    The first token found is returned as an upper-case (channel, band)
    pair. Returns ('ALL', 'ALL') when no such token is present.
    """
    upper = x.upper()
    for channel in '4321':
        for band in ('LONG', 'MEDIUM', 'SHORT'):
            if channel + band in upper:
                return channel, band
    return 'ALL', 'ALL'
def b7cubesubband(x):
    """Return (channel, subband) parsed from a cube filename `x`.

    Fast path: take the third-from-last '.'-separated field of `x`, then
    its last '_'-separated token. That token must be '<channel><letter>'
    followed by four trailing characters (e.g. 'obs_1acube.s3d.fits' ->
    ('1', 'a')). The channel must be 1-4 and the letter one of a, b, c
    (lower case).

    If that parse fails, the result of ``explicit_str_subband(x)`` is
    returned instead. Note that path spells the band as
    'LONG'/'MEDIUM'/'SHORT' (or 'ALL') rather than as a letter.
    """
    subbands = {'{}{}'.format(c, b) for c in range(1, 5) for b in 'abc'}
    try:
        chanband = x.split('.')[-3].split('_')[-1]
        channel, band = chanband[0], chanband[1:-4]
    except IndexError:
        # Fewer than three '.'-separated fields, or an empty token.
        return explicit_str_subband(x)
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, which would let malformed names through unvalidated.
    if channel + band in subbands:
        return channel, band
    return explicit_str_subband(x)
class B7Photom(object):
    """Load a datamodel file and cache its data, metadata and WCS transforms.

    Parameters
    ----------
    file : str
        Path passed to ``dm.open``. The opened model is kept on
        ``self.model`` and is not closed here, since ``self.data`` and
        ``self.meta`` reference it.
    """

    def __init__(self, file):
        self.file = file
        self.model = model = dm.open(file)
        wcs = model.meta.wcs
        # Frame-to-frame transforms, e.g. w2d maps the 'world' frame to the
        # 'detector' frame.
        self.w2ab = wcs.get_transform('world', 'alpha_beta')
        self.w2d = wcs.get_transform('world', 'detector')
        self.ab2d = wcs.get_transform('alpha_beta', 'detector')
        self.meta = model.meta
        self.data = model.data
        # One '<channel char><band>' label per character of instrument.channel.
        self.subbands = [c + model.meta.instrument.band
                         for c in model.meta.instrument.channel]
        w2v23 = wcs.get_transform('world', 'v2v3')
        v232ab = wcs.get_transform('v2v3', 'alpha_beta')
        # The v2v3 -> alpha_beta transform exposes sub-transforms by key.
        # Keys 1/2 are tried first, then 3/4 (presumably the channel numbers
        # of the two channels on this detector -- TODO confirm).
        try:
            v2ab_channelA = v232ab.selector[1]
            v2ab_channelB = v232ab.selector[2]
        except (KeyError, IndexError):
            # Only a missing key selects the 3/4 fallback; any other error
            # (or a KeyboardInterrupt) propagates instead of being masked.
            v2ab_channelA = v232ab.selector[3]
            v2ab_channelB = v232ab.selector[4]
        self.w2ab_sb = {self.subbands[0]: w2v23 | v2ab_channelA,
                        self.subbands[1]: w2v23 | v2ab_channelB}
        self.slicemap = wcs.get_transform('detector', 'alpha_beta').label_mapper.mapper
class B7Cube(object):
    """Load a cube file and sample its wavelength and sky axes via the WCS.

    Attributes set by ``__init__``:
      file, channel, band -- the filename and the (channel, band) pair
                             parsed from it by ``b7cubesubband``
      model                -- the opened datamodel (left open)
      wave                 -- WCS axis-2 value for every plane index
      ra, dec              -- WCS axis-0 along x and axis-1 along y
      alpha, beta          -- ra (wrap-corrected) and dec, times 3600
                              (presumably degrees -> arcsec; confirm)
      cube, wcs            -- model.data and model.meta.wcs
    """

    def __init__(self, sfile):
        self.file = sfile
        channel, band = b7cubesubband(self.file)
        self.channel = channel
        self.band = band
        model = self.model = dm.open(self.file)
        shape = model.data.shape
        nz, ny, nx = shape[0], shape[1], shape[2]

        # Wavelength of every plane, evaluated at the spatial centre pixel.
        self.wave = wave = model.meta.wcs([nx / 2.0] * nz,
                                          [ny / 2.0] * nz,
                                          np.arange(nz))[2]
        wmean = wave.mean()

        # RA/alpha along the x axis, at the central row and mean wavelength.
        ones_x = np.ones(nx)
        self.ra = ra = model.meta.wcs(np.arange(nx), ones_x * ny / 2.0,
                                      ones_x * wmean)[0]

        # Dec/beta along the y axis, at the central column and mean wavelength.
        ones_y = np.ones(ny)
        self.dec = dec = model.meta.wcs(ones_y * nx / 2.0, np.arange(ny),
                                        ones_y * wmean)[1]

        # self.ra keeps the raw WCS values; alpha uses the wrap-corrected copy.
        self.alpha = self._unwrap_ra(ra) * 3600
        self.beta = dec * 3600
        self.cube = model.data
        self.wcs = model.meta.wcs

    @staticmethod
    def _unwrap_ra(ra):
        """Return `ra` with a 0/360 wrap removed, mapping the >180 side below 0.

        A wrap is recognised only as a jump of more than 180 between
        neighbouring samples. This holds whether RA increases or decreases
        along the axis. With no such jump, `ra` is returned unchanged: a
        merely decreasing RA is not treated as a wrap.
        """
        if np.any(np.abs(np.diff(ra)) > 180):
            return np.where(ra > 180, ra - 360, ra)
        return ra
stcube = B7Cube('1Acube.fits')
stphots1 = B7Photom('12Aseq1.photom.fits')
stphots2 = B7Photom('12Aseq2.photom.fits')
stphots3 = B7Photom('12Aseq3.photom.fits')
stphots4 = B7Photom('12Aseq4.photom.fits')
stwave = stcube.wave
stde = stcube.dec
stra = stcube.ra

# Hard-coded wavelength grid to project onto each exposure's detector.
wave = np.array([5.30732315, 5.30750018, 5.30767721, 5.30785424, 5.30803127,
                 5.30820829, 5.30838532, 5.30856235, 5.30873938, 5.30891641,
                 5.30909344, 5.30927047, 5.3094475, 5.30962452, 5.30980155,
                 5.30997858, 5.31015561, 5.31033264, 5.31050967, 5.3106867,
                 5.31086373, 5.31104076, 5.31121778, 5.31139481, 5.31157184,
                 5.31174887, 5.3119259, 5.31210293, 5.31227996, 5.31245699,
                 5.31263401, 5.31281104, 5.31298807, 5.3131651, 5.31334213,
                 5.31351916, 5.31369619, 5.31387322, 5.31405024, 5.31422727,
                 5.3144043])
wlo, whi = wavedom = nanminmax(wave)

ranges = []   # per exposure: [(xmin, xmax), (ymin, ymax)] detector-pixel bounds
immask = []   # per exposure: boolean detector mask of the projected grid
mesh = np.meshgrid(stra, stde, wave)
for i, im in enumerate([stphots1, stphots2, stphots3, stphots4]):
    # This assumes that the wavelength domain will not cross channels in the image.
    # Compute XY locations of RA, Dec, Lambda on the detector.
    xy = im.w2d(*mesh)
    # Keep only points where BOTH coordinates are finite: casting a NaN to
    # int yields a meaningless huge-magnitude index.
    xyfinite = np.isfinite(xy[0]) & np.isfinite(xy[1])
    # Truncate to integer pixel indices.
    xy = xy[0][xyfinite].astype(int), xy[1][xyfinite].astype(int)
    # Create a mask of the region.
    immask.append(np.zeros_like(im.data, bool))
    immask[-1][xy[1], xy[0]] = 1
    ranges.append(list(map(pxrange, xy)))
    plt.subplot(2, 2, i + 1)
    # Saturate the image at 50 and scale it into [0, 1] on slice pixels, so
    # the stacked layers (mask*2 + slice footprint + image) top out at
    # vmax=4. ndarray.clip's first positional argument is the LOWER bound,
    # so clip(50) would instead floor the data at 50.
    normim = im.data.clip(0, 50) / 50 * (im.slicemap != 0)
    plt.imshow(immask[-1] * 2 + (im.slicemap != 0) + normim, vmax=4)
@baileyji
Copy link
Author

baileyji commented Nov 1, 2017

The line `from helpers import *` may (and should) be deleted.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment