-
-
Save esheldon/2161c619f4a86b797408bada6222f320 to your computer and use it in GitHub Desktop.
Scarlet lite examples
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def run_scarlet(
    obs,
    cat,
    max_iter=DEFAULT_MAX_ITER,
    min_iter=DEFAULT_MIN_ITER,
    e_rel=DEFAULT_E_REL,
    reweight=False,
    show=False,
):
    """Deblend a single-band observation with scarlet lite.

    Parameters
    ----------
    obs : observation object
        Must expose 2D ``image`` and ``weight`` arrays and a ``psf.image``
        array (presumably an ngmix-style Observation -- confirm with callers).
    cat : catalog
        Indexable by ``'x'`` and ``'y'``; these peak positions are rounded
        to the nearest pixel and used to initialize one source each.
    max_iter, min_iter : int
        Iteration limits forwarded to ``LiteBlend.fit``.
    e_rel : float
        Relative-error convergence criterion forwarded to ``LiteBlend.fit``.
    reweight : bool
        If True, call ``scarlet.lite.weight_sources`` on the fitted blend.
    show : bool
        If True, display the result via ``show_all_scarlet``.

    Returns
    -------
    blend : scarlet.lite.LiteBlend
        The fitted blend.
    res : dict
        ``'flags'`` (``DEBLEND_FAIL`` set if ``max_iter`` was reached),
        ``'numiter'`` and ``'loglike'`` as returned by ``blend.fit``.
    """
    from numpy import floor
    from functools import partial
    import scarlet
    from scarlet.lite import (
        init_all_sources_wavelets, LiteObservation,
        integrated_circular_gaussian, init_fista_component,
        parameterize_sources, LiteBlend,
    )

    flags = 0

    # Model PSF: a circular Gaussian with sigma=0.8 (presumably pixels),
    # given a leading band axis to match the (bands, y, x) convention.
    model_psf = integrated_circular_gaussian(sigma=0.8)
    model_psf = model_psf[None, :, :]

    # scarlet lite expects images, variance, weights and psfs to be 3D
    # (bands, y, x); this wrapper handles a single band.
    weights = obs.weight[None, :, :]
    # NOTE(review): pixels with zero weight give infinite variance here
    # (with a numpy divide-by-zero warning); confirm scarlet lite copes.
    variance = 1/weights
    psf_images = obs.psf.image[None, :, :]
    images = obs.image[None, :, :]

    observation = LiteObservation(
        images=images,
        variance=variance,
        weights=weights,
        psfs=psf_images,
        model_psf=model_psf,
    )

    # Wavelet coefficients used to initialize the sources.
    wavelets = scarlet.detect.get_detect_wavelets(images, variance, scales=5)

    # Round the catalog positions to the nearest pixel (halves round up),
    # as (y, x) tuples.
    def rp(x):
        return int(floor(x+0.5))

    centers = [
        (rp(y), rp(x)) for y, x in zip(cat['y'], cat['x'])
    ]

    # Initialize the sources from the peak centers. As in main scarlet, the
    # signal to noise at the center decides how many components a source
    # gets. By default the bulge uses the first two wavelet scales and the
    # disk uses the rest.
    sources = init_all_sources_wavelets(observation, wavelets, centers)

    # Initialize the optimizer for each source. bg_thresh is roughly the
    # fraction of the background RMS used as a sparsity constraint, so flux
    # below 10% of the background is set to zero.
    sources = parameterize_sources(
        sources, observation, partial(init_fista_component, bg_thresh=0.1)
    )

    # Create the blend and initialize it with the best-fit spectra for each
    # source.
    # TODO: how to check convergence? Currently only max_iter is checked.
    blend = LiteBlend(sources, observation).fit_spectra()

    # Use at least 11 iterations so a blend doesn't converge too quickly
    # and gets at least one chance to resize.
    numiter, loglike = blend.fit(
        max_iter=max_iter, e_rel=e_rel, min_iter=min_iter,
    )
    if numiter >= max_iter:
        logger.info(f'Scarlet reached max_iter {max_iter}')
        # OR the bit in so any future flags set earlier are not clobbered.
        flags |= DEBLEND_FAIL

    res = {
        'flags': flags,
        'numiter': numiter,
        'loglike': loglike,
    }

    # Reset scarlet's class-level cache (presumably so it does not grow
    # across repeated calls -- confirm).
    scarlet.cache.Cache._cache = {}

    # Optional (probably not necessary for metacal). This adds a `flux`
    # attribute to each source: `source.get_model()` gives the scarlet
    # model, while `source.flux` gives the "deblended" model. The data are
    # redistributed using all sources as templates, in proportion to the
    # templates (similar to the SDSS deblender).
    if reweight:
        scarlet.lite.weight_sources(blend)

    if show:
        show_all_scarlet(
            observation, centers, blend, reweight=reweight,
        )

    return blend, res
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def show_scarlet_full_blend(observation, blend, reweight=False):
    """Plot the data, the rendered blend model and their residual side by side.

    Parameters
    ----------
    observation : scarlet.lite.LiteObservation
        Supplies the observed ``data`` and the ``render`` method used to
        bring the model into the observed frame.
    blend : scarlet.lite.LiteBlend
        Fitted blend whose model is displayed.
    reweight : bool
        Forwarded as ``use_flux`` to ``blend.get_model``.
    """
    import scarlet

    to_rgb = scarlet.display.img_to_rgb

    # Model rendered in the observed frame, and what the data leaves over.
    rendered = observation.render(blend.get_model(use_flux=reweight))
    data = observation.data
    resid = data - rendered

    # Model and data are normalised with get_scarlet_norm(); the residual
    # uses img_to_rgb's default normalisation.
    model_rgb = to_rgb(rendered, norm=get_scarlet_norm())
    resid_rgb = to_rgb(resid)
    data_rgb = to_rgb(data, norm=get_scarlet_norm())

    figure = mplt.figure(figsize=(15, 5))
    axes = [figure.add_subplot(1, 3, pos) for pos in (1, 2, 3)]
    panels = (
        ("Data", data_rgb),
        ("Model", model_rgb),
        ("Residual", resid_rgb),
    )
    for axis, (title, rgb) in zip(axes, panels):
        axis.imshow(rgb)
        axis.set_title(title)

    mplt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment