Skip to content

Instantly share code, notes, and snippets.

@settwi
Created June 22, 2023 01:10
Show Gist options
  • Save settwi/d704bb044cd09503787a6001b35c821c to your computer and use it in GitHub Desktop.
import astropy.units as u
import numpy as np
from sklearn import linear_model
import typing
@u.quantity_input
# NOTE: intentionally no return annotation. @u.quantity_input inspects the
# return annotation as a unit target (older astropy releases try to call
# `.to()` on the returned value), and this function returns a dict.
def decompose(
    counts: u.ct,
    energy_bins: u.keV,
    thermal_energy: u.keV,
    nonthermal_energy: u.keV,
    norm_func: typing.Callable[[u.ct], u.one],
    tolerance: float = 0.05,
    fit_intercept: bool = True,
):
    '''
    Decompose the given counts spectra into "thermal" and "nonthermal"
    components using timing information from two reference energy bands.

    Each row of ``counts`` belongs to one energy bin: row ``i`` is paired
    with the bin between ``energy_bins[i]`` and ``energy_bins[i + 1]``.
    The rows whose bin midpoints are closest to ``thermal_energy`` and
    ``nonthermal_energy`` become the two reference bands. Every other row
    is normalized with ``norm_func`` and fit, by linear least squares with
    non-negative coefficients, as a combination of the two normalized
    reference rows.

    Parameters
    ----------
    counts : Quantity [ct]
        2-D counts, one row per energy bin. Presumably the columns are
        time samples (the decomposition uses "timing information");
        confirm against callers.
    energy_bins : Quantity [keV]
        Energy bin edges. Their midpoints are used to pick the
        reference bands.
    thermal_energy, nonthermal_energy : Quantity [keV]
        Energies that select the thermal and nonthermal reference bands.
    norm_func : callable
        Normalization applied to each row before fitting, e.g.
        ``sum_norm``, ``minmax_norm`` or ``mean_norm``.
    tolerance : float
        "How far off can total counts be": how far each row's coefficient
        sum may deviate from 1 before ``RuntimeError`` is raised. The check
        is applied even when ``fit_intercept`` is True. In that case the
        intercept can absorb part of the signal and the check is of little
        use, so loosen ``tolerance`` accordingly.
    fit_intercept : bool
        Whether each regression also fits a constant offset.

    Returns
    -------
    dict
        'coefficients': (N, 2) array with one row per energy bin.
            Column 0 holds the thermal coefficients and column 1 the
            nonthermal coefficients.
        'intercepts': (N,) array (0 for the two reference bands).
        'indices': (thermal_index, nonthermal_index) of the reference bands.

    Raises
    ------
    ValueError
        If both energies select the same bin. The two predictors would
        then be identical, so the decomposition would be meaningless.
    RuntimeError
        From ``verify_coefs`` when a coefficient sum is not within
        ``tolerance`` of 1.
    '''
    def nearest_index(values, target):
        # Index of the element of `values` closest to `target`.
        return np.abs(values - target).argmin()

    mids = energy_bins[:-1] + np.diff(energy_bins) / 2
    th_idx = nearest_index(mids, thermal_energy)
    nth_idx = nearest_index(mids, nonthermal_energy)
    if th_idx == nth_idx:
        raise ValueError(
            f'thermal_energy ({thermal_energy}) and nonthermal_energy '
            f'({nonthermal_energy}) select the same energy bin (index {th_idx})'
        )

    # Design matrix: one column per reference band, one row per sample.
    predictor = np.array([
        norm_func(counts[th_idx]),
        norm_func(counts[nth_idx]),
    ]).transpose()

    coefs = []
    intercepts = []
    for i, raw_cts in enumerate(counts):
        if i == th_idx:
            # The reference bands are, by construction, purely one component.
            coefs.append([1, 0])
            intercepts.append(0)
        elif i == nth_idx:
            coefs.append([0, 1])
            intercepts.append(0)
        else:
            target = norm_func(raw_cts)
            # positive=True constrains both coefficients to be >= 0.
            lr = linear_model.LinearRegression(
                fit_intercept=fit_intercept, positive=True)
            lr.fit(predictor, target)
            coefs.append(lr.coef_)
            intercepts.append(lr.intercept_)

    coefs = np.array(coefs)
    verify_coefs(coefs, tol=tolerance)
    return {
        'coefficients': coefs,
        'intercepts': np.array(intercepts),
        'indices': (th_idx, nth_idx),
    }
def verify_coefs(c: np.ndarray, tol: float):
    '''
    Check that every row of ``c`` sums to 1 within ``tol``.

    Returns None when all rows pass. Otherwise raises RuntimeError listing
    every row sum. A row fails when its sum differs from 1 by ``tol`` or
    more; a NaN sum also fails, because the comparison with ``tol`` is
    False.
    '''
    summed = c.sum(axis=1)
    deviation = np.abs(summed - 1)
    if np.all(deviation < tol):
        return
    raise RuntimeError(f'some coefficient sums are not within {tol} of 1: {summed}')
@u.quantity_input
def sum_norm(cts: u.ct) -> u.one:
    '''Divide ``cts`` by its total, giving a dimensionless result that sums
    to one (provided the total is nonzero).'''
    total = cts.sum()
    return cts / total
@u.quantity_input
def minmax_norm(cts: u.ct) -> u.one:
    '''Rescale ``cts`` linearly so its minimum maps to 0 and its maximum to 1.

    NOTE(review): a constant input (max == min) divides by zero, so the
    result is NaN. Callers fitting on this output should avoid flat rows.
    '''
    lo, hi = cts.min(), cts.max()
    span = hi - lo
    shifted = cts - lo
    return shifted / span
@u.quantity_input
def mean_norm(cts: u.ct) -> u.one:
    '''Divide ``cts`` by its mean, giving a dimensionless result with mean
    one (provided the mean is nonzero).'''
    average = cts.mean()
    return cts / average
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment