#!/usr/bin/env python3
"""Post process DR1 source catalog"""
from __future__ import annotations
import logging
import os
from pathlib import Path
import numpy as np
from arrakis.logger import TqdmToLogger, logger
from arrakis.makecat import fix_blank_units, replace_nans, vot
from astropy.coordinates import SkyCoord
from astropy.table import Table
from tqdm.auto import tqdm
[docs]
TQDM_OUT = TqdmToLogger(logger, level=logging.INFO)
logger.setLevel("DEBUG")
[docs]
def write_votable(table: Table, outfile: str) -> None:
# Replace bad column names
fix_columns = {
"catalog": "catalog_name",
"interval": "obs_interval",
}
# CASDA needs v1.3
for col_name, new_name in fix_columns.items():
if col_name in table.colnames:
table.rename_column(col_name, new_name)
# Fix blank units
table = fix_blank_units(table)
vo_table = vot.from_table(table)
vo_table.version = "1.3"
vo_table = add_metadata(vo_table, table, outfile)
vot.writeto(vo_table, outfile)
# Fix NaNs for CASDA
replace_nans(outfile)
[docs]
def main(
source_cat_pth: Path,
spice_cat_pth: Path,
survey_dir: Path,
epoch: int = 0,
):
logger.info(f"Loading source catalog from {source_cat_pth}")
source_cat = Table.read(source_cat_pth)
logger.info(f"Loading SPICE catalog from {spice_cat_pth}")
spice_cat = Table.read(spice_cat_pth)
field_path = survey_dir / "db" / f"epoch_{epoch}" / "field_data.csv"
logger.info(f"Loading field data from {field_path}")
field = Table.read(field_path)
field = field[field["SELECT"] == 1]
field.add_index("FIELD_NAME")
spice_df = spice_cat.to_pandas()
spice_df.set_index("source_id", inplace=True)
spice_df.sort_index(inplace=True)
source_cat.sort("Source_ID")
spice_grp = spice_df.groupby("source_id")
sources = list(set(spice_cat["source_id"]))
src_idx = np.isin(source_cat["Source_ID"].data.astype(str), sources)
source_cut = source_cat[src_idx]
columns_to_fix = {
"SBID": "sbid",
"Obs_Start_Time": "start_time",
"Tile_ID": "tile_id",
}
for source_col, spice_col in columns_to_fix.items():
logger.info(f"Fixing {source_col} with {spice_col}")
source_cut[source_col] = spice_grp[spice_col].apply(lambda x: x[0])
# Fix the separations
logger.info("Fixing separations to field centres")
field_coords = SkyCoord(field["RA_DEG"], field["DEC_DEG"], unit="deg", frame="icrs")
source_coords = SkyCoord(
source_cut["RA"], source_cut["Dec"], unit="deg", frame="icrs"
)
tile_ids = np.unique(source_cut["Tile_ID"].data.astype(str))
for tile_id in tqdm(tile_ids, file=TQDM_OUT):
field_tile_idx = field["FIELD_NAME"] == tile_id
source_tile_idx = source_cut["Tile_ID"] == tile_id
tile_coords = field_coords[field_tile_idx]
seps = source_coords[source_tile_idx].separation(tile_coords)
source_cut["Separation_Tile_Centre"][source_tile_idx] = seps.deg
# Fix bad columns
bad_cols = {
"Min": "Min_axis",
"Maj": "Maj_axis",
"E_Min": "E_Min_axis",
"E_Maj": "E_Maj_axis",
}
for bad_col, good_col in bad_cols.items():
if bad_col in source_cut.colnames:
logger.info(f"Renaming {bad_col} to {good_col}")
source_cut.rename_column(bad_col, good_col)
else:
logger.warning(f"{bad_col} not in source catalog")
# Write the output cat
out_pth = spice_cat_pth.parent / ("RACS_DR1_Sources_" + spice_cat_pth.name)
logger.info(f"Writing corrected catalogue to {out_pth}")
write_votable(source_cut, out_pth.as_posix())
[docs]
def cli():
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
"source",
type=str,
help="RACS Source catalog",
)
parser.add_argument(
"spice",
type=str,
help="SPICE-RACS catalog",
)
parser.add_argument(
"survey",
type=str,
help="Survey directory",
)
parser.add_argument(
"--epoch",
type=int,
default=0,
help="Epoch to read field data from",
)
args = parser.parse_args()
main(
source_cat_pth=Path(args.source),
spice_cat_pth=Path(args.spice),
survey_dir=Path(args.survey),
epoch=args.epoch,
)
if __name__ == "__main__":
cli()