Skip to content

Instantly share code, notes, and snippets.

@taldcroft
Created May 6, 2022 11:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save taldcroft/44a88f079f4afe85ae2a14c75ac2795d to your computer and use it in GitHub Desktop.
Save taldcroft/44a88f079f4afe85ae2a14c75ac2795d to your computer and use it in GitHub Desktop.
Table subclass with parameters
import numpy as np
from astropy.table import Table, Column
from astropy.table.ndarray_mixin import NdarrayMixin, NdarrayMixinInfo
from astropy.io.registry import UnifiedReadWriteMethod
from astropy.table.connect import TableRead, TableWrite
from astropy.table.info import serialize_method_as
def fmt_func(val):
    """Format one parameter record as ``"par (pmn, pmx)"``.

    Parameters
    ----------
    val : mapping-like
        Any object supporting ``val["par"]``, ``val["pmn"]`` and
        ``val["pmx"]`` with numeric values (e.g. one element of an array
        with dtype ``PAR_DTYPE``).

    Returns
    -------
    str
        The three values, each rendered with two decimal places.
    """
    return f'{val["par"]:.2f} ({val["pmn"]:.2f}, {val["pmx"]:.2f})'
# Structured dtype identifying a "parameter" column: three float64 fields, the
# same three that fmt_func renders as "par (pmn, pmx)".
# NOTE(review): 'pmn'/'pmx' presumably mean the parameter's min/max - confirm.
PAR_DTYPE = np.dtype([('par', 'f8'), ('pmn', 'f8'), ('pmx', 'f8')])
class ParTableRead(TableRead):
    """Read hook that re-attaches ``fmt_func`` to parameter columns.

    After the parent reader has built the table, every column whose dtype
    equals ``PAR_DTYPE`` gets ``fmt_func`` as its ``info.format``.
    """

    def __call__(self, *args, **kwargs):
        """Read a table and set the display format of ``PAR_DTYPE`` columns.

        All positional and keyword arguments are forwarded unchanged to
        ``TableRead.__call__``.

        Returns
        -------
        The table returned by the parent reader, with ``info.format`` set to
        ``fmt_func`` on each column whose dtype equals ``PAR_DTYPE``.
        """
        out = super().__call__(*args, **kwargs)
        for col in out.columns.values():
            # Columns without a ``dtype`` attribute (possible for non-ndarray
            # mixin columns) are skipped instead of raising AttributeError.
            col_dtype = getattr(col, "dtype", None)
            if col_dtype is not None and col_dtype == PAR_DTYPE:
                col.info.format = fmt_func
        return out
class ParTableWrite(TableWrite):
    """Write hook that temporarily clears ``fmt_func`` formats while writing.

    Columns whose ``info.format`` is ``fmt_func`` have the format set to
    ``None`` for the duration of the write and restored afterwards, whether
    or not the write succeeds.

    The author's original note on this class reported that this approach
    "SHOULD work" but failed for an obscure reason, which is why ``ParTable``
    overrides ``write()`` directly instead of wiring this class in.
    NOTE(review): the restore step used to index the writer (``self[...]``)
    rather than the table, which raises TypeError on every call; that may
    have been the failure, but this has not been confirmed by running it.
    """

    def __call__(self, *args, serialize_method=None, **kwargs):
        """Write the bound table with ``fmt_func`` formats cleared.

        Parameters
        ----------
        *args, **kwargs
            Forwarded unchanged to ``TableWrite.__call__``.
        serialize_method : optional
            Forwarded to ``TableWrite.__call__`` (previously accepted here
            but silently dropped).

        Returns
        -------
        Whatever ``TableWrite.__call__`` returns.
        """
        instance = self._instance
        # Keep the column objects themselves so restoring needs no lookup.
        par_cols = []
        try:
            for col in instance.columns.values():
                if col.info.format is fmt_func:
                    # Record before clearing so the column is always restored.
                    par_cols.append(col)
                    col.info.format = None
            return super().__call__(
                *args, serialize_method=serialize_method, **kwargs
            )
        finally:
            for col in par_cols:
                col.info.format = fmt_func
class ParTable(Table):
    """Table subclass that displays ``PAR_DTYPE`` columns using ``fmt_func``.

    The format is attached when a column is converted for the table and when
    a table is read; it is removed from a plain-``Table`` stand-in before
    writing.
    """

    read = UnifiedReadWriteMethod(ParTableRead)
    # NOTE: ``write = UnifiedReadWriteMethod(ParTableWrite)`` is the intended
    # counterpart to ``read`` but is left disabled (see the ParTableWrite
    # docstring); ``write()`` is overridden directly below instead.

    def _convert_data_to_col(self, *args, **kwargs):
        """Convert input data to a column, tagging ``PAR_DTYPE`` columns.

        All arguments are forwarded unchanged to
        ``Table._convert_data_to_col``; the resulting column is returned with
        ``info.format`` set to ``fmt_func`` when its dtype equals
        ``PAR_DTYPE``.
        """
        col = super()._convert_data_to_col(*args, **kwargs)
        # Columns without a ``dtype`` attribute (possible for non-ndarray
        # mixin columns) are passed through instead of raising AttributeError.
        col_dtype = getattr(col, "dtype", None)
        if col_dtype is not None and col_dtype == PAR_DTYPE:
            col.info.format = fmt_func
        return col

    def write(self, *args, **kwargs):
        """Write the table with ``fmt_func`` formats removed.

        A plain ``Table`` is built from this table and the format reset is
        applied to that table's columns, which is then written. All arguments
        are forwarded to ``Table.write``, whose result is returned.
        NOTE(review): the format is presumably cleared because a callable
        format cannot be serialized to the output file - confirm.
        """
        new_self = Table(self)
        for col in new_self.columns.values():
            if col.info.format is fmt_func:
                col.info.format = None
        return new_self.write(*args, **kwargs)
# Demo: build a ParTable with a string column and a structured parameter
# column, write it to an ECSV file, then read it back as a ParTable.
name = ['par1', 'par2']
# Each tuple supplies the (par, pmn, pmx) fields of PAR_DTYPE for one row.
a = np.array([(np.pi, 2, 4.5), (np.pi / 2, 1, 3.1)], dtype=PAR_DTYPE)
t = ParTable([name, a], names=['name', 'par'])
# Writes 'pars.ecsv' relative to the current working directory, replacing any
# existing file of that name (overwrite=True).
t.write('pars.ecsv', overwrite=True)
t2 = ParTable.read('pars.ecsv')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment