Source code for tolteca.simu.mapping.base

#!/usr/bin/env python


from astropy.modeling import Model
from gwcs import coordinate_frames as cf
import astropy.units as u
from astropy.coordinates import SkyCoord
from tollan.utils.log import timeit
from astropy.table import Table
from astropy.utils import indent
from enum import Flag, auto

from .utils import _get_skyoffset_frame, resolve_sky_coords_frame


__all__ = ['OffsetMappingModel', 'TargetedOffsetMappingModel']


class PatternKind(Flag):
    """The available pattern kinds.

    The members are bit flags, so they can be combined with ``|`` and
    tested with ``in``.
    """

    raster = auto()
    lissajous = auto()
    rastajous = auto()
    # Composite flag: matches both the ``raster`` and ``rastajous`` kinds.
    raster_like = raster | rastajous
class MappingModelBase(Model):
    """The base class for all mapping models.

    NOTE(review): `pattern_kind` is a plain method here, whereas
    `TargetedOffsetMappingModel.pattern_kind` is a property -- confirm
    which form callers rely on before unifying the two.
    """

    def pattern_kind(self):
        """Return the value stored in ``self._pattern_kind``.

        ``_pattern_kind`` is not assigned in this class; presumably
        subclasses set it -- TODO confirm.
        """
        return self._pattern_kind


class OffsetMappingModel(Model):
    """The base class for mapping models defined in offset coordinates.

    The model has a single input ``t`` and two outputs, which are named
    after the axes of ``self.frame``.  Neither ``pattern_name`` nor
    ``frame`` is defined in this class; presumably concrete subclasses
    provide them -- TODO confirm.
    """

    n_inputs = 1
    n_outputs = 2
    fittable = False

    def __init__(self, *args, **kwargs):
        """Initialize the model, forwarding all arguments to `Model`."""
        super().__init__(*args, **kwargs)
        # The model name is always replaced by the pattern name, and the
        # input/output labels are set once the base class is initialized.
        self.name = self.pattern_name
        self.inputs = ('t', )
        self.outputs = self.frame.axes_names

    def get_traj_model(self, **kwargs):
        """Return a `TargetedOffsetMappingModel` wrapping this model.

        All keyword arguments are passed on to the
        `TargetedOffsetMappingModel` constructor.
        """
        return TargetedOffsetMappingModel(offset_mapping_model=self, **kwargs)


class TrajMappingModel(Model):
    """The base class for mapping models in absolute coordinates.

    The model has a single input ``t`` and two outputs.
    """

    n_inputs = 1
    n_outputs = 2
    fittable = False

    def __init__(self, *args, **kwargs):
        """Initialize the model, forwarding all arguments to `Model`."""
        super().__init__(*args, **kwargs)
        self.inputs = ('t', )


class TargetedOffsetMappingModel(TrajMappingModel):
    """The class for trajectories rendered using `OffsetMappingModel`
    around a target.
    """

    def __init__(
            self, offset_mapping_model, target,
            ref_frame=None, t0=None, observer=None):
        """Create the trajectory model.

        Parameters
        ----------
        offset_mapping_model
            The model that is called with ``t`` to produce the pair of
            offsets; its ``name`` is used to build this model's name.
        target
            The target coordinates.  It has to provide ``frame`` and
            ``transform_to``.
        ref_frame
            The reference frame with respect to which the offsets are
            interpreted.  When `None`, ``target.frame`` is used.
        t0
            The start time (in UT) of the mapping.
        observer
            The observer of the mapping, passed on when the reference
            frame is resolved.
        """
        if ref_frame is None:
            ref_frame = target.frame
        # TODO maybe we need not hard-code the units here.
        wcs_frame = cf.CelestialFrame(
            name=ref_frame.name,
            reference_frame=ref_frame,
            unit=(u.deg, u.deg),
        )
        # The private state is stored before the base initializer runs.
        self._offset_mapping_model = offset_mapping_model
        self._target = target
        self._ref_frame = ref_frame
        self._frame = wcs_frame
        self._t0 = t0
        self._observer = observer
        super().__init__(name=f'targeted_{offset_mapping_model.name}')
        self.outputs = self.frame.axes_names

    @property
    def offset_mapping_model(self):
        """The offset model."""
        return self._offset_mapping_model

    @property
    def target(self):
        """The target coordinates."""
        return self._target

    @property
    def ref_frame(self):
        """The reference frame with respect to which the offsets are
        interpreted."""
        return self._ref_frame

    @property
    def frame(self):
        """The WCS coordinate frame for `ref_frame`."""
        return self._frame

    @property
    def t0(self):
        """The start time (in UT) of the mapping."""
        return self._t0

    @property
    def observer(self):
        """The observer of the mapping."""
        return self._observer

    @property
    def t_pattern(self):
        """The ``t_pattern`` of the wrapped offset model."""
        return self.offset_mapping_model.t_pattern

    @property
    def pattern_kind(self):
        """The ``pattern_kind`` of the wrapped offset model."""
        return self.offset_mapping_model.pattern_kind

    @timeit
    def evaluate_coords(self, t):
        """Return the mapping pattern coordinates as evaluated at target.

        ``t`` is added to `t0` to obtain the observation time.  The
        offsets produced by the offset model at ``t`` are placed in the
        offset frame built from the target and then transformed to the
        resolved reference frame.

        NOTE(review): `t0` defaults to `None` in the constructor, in
        which case ``self.t0 + t`` is not expected to work -- confirm
        that callers always supply `t0`.
        """
        obs_time = self.t0 + t
        resolved_frame = resolve_sky_coords_frame(
            self.ref_frame, observer=self.observer, time_obs=obs_time)
        target_resolved = self.target.transform_to(resolved_frame)
        offset_frame = _get_skyoffset_frame(target_resolved)
        d_lon, d_lat = self.offset_mapping_model(t)
        offset_coords = SkyCoord(d_lon, d_lat, frame=offset_frame)
        return offset_coords.transform_to(resolved_frame)

    @timeit
    def evaluate(self, t):
        """Return the longitude and latitude at ``t``, both in degrees."""
        coords_data = self.evaluate_coords(t).data
        return coords_data.lon.to(u.deg), coords_data.lat.to(u.deg)

    @timeit
    def evaluate_holdflag(self, t):
        """Return ``evaluate_holdflag(t)`` of the wrapped offset model."""
        return self.offset_mapping_model.evaluate_holdflag(t)

    def __repr__(self):
        """Return a short text showing the offset model, target and t0."""
        text = (
            f'{self.__class__.__name__}({self.offset_mapping_model}, '
            f'target={self.target}, t0={self.t0})'
        )
        return text

    def __str__(self):
        """Return a multi-line summary of the model.

        The summary lists the general properties (entries whose value is
        `None` are left out), followed by a table of the offset model
        parameters when there is any.
        """
        summary_items = (
            ('Model', self.__class__.__name__),
            ('Name', self.name),
            ('Inputs', self.inputs),
            ('Outputs', self.outputs),
            ('Model set size', len(self)),
            ("Target", self.target),
            ("Reference frame", self.ref_frame.__class__.__name__),
            ("Time_0", self.t0),
        )
        lines = []
        for keyword, value in summary_items:
            if value is None:
                continue
            lines.append(f'{keyword}: {value}')
        lines.append('Offset Model Parameters:')

        offset_model = self.offset_mapping_model
        param_names = offset_model.param_names
        # For a single model each value is wrapped in a list so that
        # every parameter fills a one-row table column.
        is_single = len(offset_model) == 1
        columns = []
        for name in param_names:
            param_value = getattr(offset_model, name).value
            columns.append([param_value] if is_single else param_value)
        if not columns:
            return '\n'.join(lines)

        param_table = Table(columns, names=param_names)
        # Set units on the columns
        for name in param_names:
            param_table[name].unit = getattr(offset_model, name).unit
        lines.append(indent(str(param_table), width=4))
        return '\n'.join(lines)