Skip to content

Instantly share code, notes, and snippets.

@esheldon
Created November 4, 2021 13:56
Show Gist options
  • Save esheldon/2161c619f4a86b797408bada6222f320 to your computer and use it in GitHub Desktop.
Save esheldon/2161c619f4a86b797408bada6222f320 to your computer and use it in GitHub Desktop.
Scarlet lite examples
def run_scarlet(
    obs,
    cat,
    max_iter=DEFAULT_MAX_ITER,
    min_iter=DEFAULT_MIN_ITER,
    e_rel=DEFAULT_E_REL,
    reweight=False,
    show=False
):
    """Deblend a single-band observation with scarlet lite.

    Parameters
    ----------
    obs: observation
        Object providing 2D ``image``, ``weight`` and ``psf.image`` arrays.
        Each is promoted to a single-band (1, ny, nx) cube for scarlet lite.
    cat: catalog
        Indexable by ``'y'`` and ``'x'``, giving one center per source.
        Centers are rounded to the nearest integer pixel (round half up).
    max_iter, min_iter, e_rel:
        Passed through to ``LiteBlend.fit``.
    reweight: bool, optional
        If True, call ``scarlet.lite.weight_sources`` on the fitted blend.
    show: bool, optional
        If True, display the result via ``show_all_scarlet``.

    Returns
    -------
    blend: LiteBlend
        The fitted blend.
    res: dict
        ``'flags'``: ``DEBLEND_FAIL`` is set if the fit ran to ``max_iter``,
        otherwise 0. ``'numiter'`` and ``'loglike'`` are as returned by
        ``blend.fit``.
    """
    from numpy import floor
    from functools import partial
    import scarlet
    from scarlet.lite import (
        init_all_sources_wavelets, LiteObservation,
        integrated_circular_gaussian, init_fista_component,
        parameterize_sources, LiteBlend,
    )

    flags = 0

    # Model PSF: a circular Gaussian with sigma 0.8, given a leading band axis
    model_psf = integrated_circular_gaussian(sigma=0.8)
    model_psf = model_psf[None, :, :]

    # Build a single-band observation. Images, variance, weights and PSFs
    # are all 3D with shape (bands, y, x).
    weights = obs.weight[None, :, :]
    # NOTE(review): pixels with zero weight become inf variance here (numpy
    # also warns on the division); confirm scarlet lite tolerates that.
    variance = 1/weights
    psf_images = obs.psf.image[None, :, :]
    images = obs.image[None, :, :]

    observation = LiteObservation(
        images=images,
        variance=variance,
        weights=weights,
        psfs=psf_images,
        model_psf=model_psf,
    )

    # Wavelet coefficients used to initialize the sources
    wavelets = scarlet.detect.get_detect_wavelets(images, variance, scales=5)

    # Initialize the sources from the peak centers. This uses an algorithm
    # similar to the main scarlet: the signal-to-noise at the center decides
    # how many components each source gets. By default the bulge is the first
    # two wavelet scales and the disk is the rest.
    def rp(x):
        # round half up to an integer pixel index
        return int(floor(x+0.5))

    centers = [
        (rp(y), rp(x)) for y, x in zip(cat['y'], cat['x'])
    ]
    sources = init_all_sources_wavelets(observation, wavelets, centers)

    # Initialize the optimizer for each source. bg_thresh is roughly the
    # fraction of the background RMS used as a sparsity constraint, so flux
    # below 10% of the background is set to zero.
    sources = parameterize_sources(
        sources, observation, partial(init_fista_component, bg_thresh=0.1)
    )

    # Create the blend and initialize it with the best-fit spectrum of each
    # source.
    # TODO how to check convergence? Just max_iter?
    blend = LiteBlend(sources, observation).fit_spectra()

    # Author's recommendation: at least 11 iterations, so that a blend does
    # not converge too quickly and gets at least one chance to resize.
    numiter, loglike = blend.fit(
        max_iter=max_iter, e_rel=e_rel, min_iter=min_iter,
    )

    # Running to the iteration cap is treated as a failure to converge.
    # OR the bit in so the flag composes with any other bits set above.
    if numiter == max_iter:
        logger.info('Scarlet reached max_iter %s', max_iter)
        flags |= DEBLEND_FAIL

    res = {
        'flags': flags,
        'numiter': numiter,
        'loglike': loglike,
    }

    # Reset scarlet's class-level cache. Presumably this stops it growing
    # across repeated calls; confirm.
    scarlet.cache.Cache._cache = {}

    # Optional (probably not needed for metacal). This step adds a `flux`
    # attribute to each source. `source.get_model()` then gives the scarlet
    # model, and `source.flux` gives the "deblended" model: all sources are
    # used as templates, and the data is redistributed according to the
    # ratio of the templates (similar to the SDSS deblender).
    if reweight:
        scarlet.lite.weight_sources(blend)

    if show:
        show_all_scarlet(
            observation, centers, blend, reweight=reweight,
        )

    return blend, res
def show_scarlet_full_blend(observation, blend, reweight=False):
    """Plot the data, the rendered blend model and their residual side by side.

    Parameters
    ----------
    observation: LiteObservation
        The observation the blend was fit to. Its ``data`` is shown, and its
        ``render`` method maps the model into the observed frame.
    blend: LiteBlend
        The fitted blend.
    reweight: bool, optional
        Forwarded to ``blend.get_model`` as ``use_flux``.
    """
    import scarlet

    to_rgb = scarlet.display.img_to_rgb

    # Model in the model frame, rendered into the observed frame
    rendered = observation.render(blend.get_model(use_flux=reweight))
    data = observation.data
    residual = data - rendered

    # Data and model share a display norm; the residual uses the default
    panel_rgb = {}
    panel_rgb["Model"] = to_rgb(rendered, norm=get_scarlet_norm())
    panel_rgb["Residual"] = to_rgb(residual)
    panel_rgb["Data"] = to_rgb(data, norm=get_scarlet_norm())

    fig = mplt.figure(figsize=(15, 5))
    axes = [fig.add_subplot(1, 3, index) for index in (1, 2, 3)]
    for axis, title in zip(axes, ("Data", "Model", "Residual")):
        axis.imshow(panel_rgb[title])
        axis.set_title(title)

    # Optional source labels (disabled in the original example):
    # for k, src in enumerate(blend):
    #     if hasattr(src, "center"):
    #         y, x = src.center
    #         for axis in axes:
    #             axis.text(x, y, k, color="w")

    mplt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment