Created
June 22, 2023 01:10
-
-
Save settwi/d704bb044cd09503787a6001b35c821c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import astropy.units as u | |
import numpy as np | |
from sklearn import linear_model | |
import typing | |
@u.quantity_input
def decompose(
    counts: u.ct,
    energy_bins: u.keV,
    thermal_energy: u.keV,
    nonthermal_energy: u.keV,
    norm_func: typing.Callable[[u.ct], u.one],
    tolerance: float = 0.05,
    fit_intercept: bool = True,
) -> dict:
    '''
    Decompose each energy row of ``counts`` into a nonnegative combination
    of a "thermal" and a "nonthermal" reference row.

    The reference rows are the rows of ``counts`` whose energy-bin midpoints
    lie closest to ``thermal_energy`` and ``nonthermal_energy``. Both are
    passed through ``norm_func`` to form the two predictor columns. Every
    other row is normalized with ``norm_func`` as well and regressed on those
    predictors with ``sklearn.linear_model.LinearRegression(positive=True)``.

    Parameters
    ----------
    counts
        2-D counts array. Axis 0 is indexed by energy bin (matched against
        the midpoints of ``energy_bins``). The entries along axis 1 are the
        regression samples (presumably time bins -- confirm with caller).
    energy_bins
        Energy bin edges; ``len(energy_bins) - 1`` midpoints are derived.
    thermal_energy, nonthermal_energy
        Scalar energies; each selects the bin with the nearest midpoint.
    norm_func
        Normalization applied to each row before fitting (e.g. ``sum_norm``).
    tolerance
        Maximum allowed ``|coef_thermal + coef_nonthermal - 1|`` for every
        row, enforced by ``verify_coefs``. The check also runs when
        ``fit_intercept`` is True, even though an intercept can absorb part
        of the signal and pull the coefficient sums away from 1.
    fit_intercept
        Whether each per-row regression also fits an intercept.

    Returns
    -------
    dict with keys
        'coefficients': (N, 2) array, N = number of rows in ``counts``;
                        column 0 = thermal coefs, column 1 = nonthermal coefs
        'intercepts':   1-D array of length N
        'indices':      (thermal_index, nonthermal_index) into axis 0 of
                        ``counts``

    Raises
    ------
    ValueError
        If both reference energies resolve to the same energy bin, which
        would make the two predictor columns identical.
    RuntimeError
        From ``verify_coefs`` when a row's coefficient sum is out of
        tolerance.
    '''
    def nearest(grid, value):
        # Index of the grid point closest to ``value``.
        return np.abs(grid - value).argmin()

    mids = energy_bins[:-1] + np.diff(energy_bins) / 2
    th_idx = nearest(mids, thermal_energy)
    nth_idx = nearest(mids, nonthermal_energy)
    if th_idx == nth_idx:
        raise ValueError(
            f'thermal_energy ({thermal_energy}) and nonthermal_energy '
            f'({nonthermal_energy}) both map to energy bin {th_idx}; '
            'the two reference bands must be distinct'
        )

    # Design matrix: one row per sample along axis 1 of ``counts``;
    # column 0 = normalized thermal reference, column 1 = nonthermal.
    predictor = np.array([
        norm_func(counts[th_idx]),
        norm_func(counts[nth_idx]),
    ]).transpose()

    coefs = []
    intercepts = []
    for i, raw_cts in enumerate(counts):
        # Reference rows are assigned unit weight on themselves, not fitted.
        if i == th_idx:
            coefs.append([1, 0])
            intercepts.append(0)
            continue
        if i == nth_idx:
            coefs.append([0, 1])
            intercepts.append(0)
            continue
        lr = linear_model.LinearRegression(
            fit_intercept=fit_intercept, positive=True)
        lr.fit(predictor, norm_func(raw_cts))
        coefs.append(lr.coef_)
        intercepts.append(lr.intercept_)

    coefs = np.array(coefs)
    verify_coefs(coefs, tol=tolerance)
    return {
        'coefficients': coefs,
        'intercepts': np.array(intercepts),
        'indices': (th_idx, nth_idx),
    }
def verify_coefs(c: np.ndarray, tol: float):
    '''
    Check that every row of ``c`` sums to 1 within ``tol``.

    Parameters
    ----------
    c
        2-D coefficient array; each row is summed along axis 1.
    tol
        Strict bound on ``|row_sum - 1|``.

    Raises
    ------
    RuntimeError
        If any row sum deviates from 1 by ``tol`` or more. A NaN sum also
        fails, because a comparison against NaN is False.
    '''
    summed = c.sum(axis=1)
    deviation = np.abs(summed - 1)
    if np.all(deviation < tol):
        return
    raise RuntimeError(f'some coefficient sums are not within {tol} of 1: {summed}')
@u.quantity_input
def sum_norm(cts: u.ct) -> u.one:
    '''Divide ``cts`` by its total, returning a dimensionless array.'''
    total = cts.sum()
    normalized = cts / total
    return normalized
@u.quantity_input
def minmax_norm(cts: u.ct) -> u.one:
    '''
    Linearly rescale ``cts`` so its minimum maps to 0 and its maximum to 1.

    A constant input has zero range. Dividing by it would produce NaN (with a
    RuntimeWarning), so in that case every element maps to 0 instead. This
    follows scikit-learn's ``MinMaxScaler`` handling of zero-range features.
    '''
    lo = cts.min()
    span = cts.max() - lo
    shifted = cts - lo
    if span == 0:
        # Zero range: all elements are equal, so return dimensionless zeros.
        return np.zeros(shifted.shape) * u.one
    return shifted / span
@u.quantity_input
def mean_norm(cts: u.ct) -> u.one:
    '''Divide ``cts`` by its mean, returning a dimensionless array.'''
    average = cts.mean()
    scaled = cts / average
    return scaled
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment