Source code for arrakis.cutout

#!/usr/bin/env python
"""Produce cutouts from RACS cubes"""

from __future__ import annotations

import argparse
import logging
import warnings
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from pprint import pformat
from shutil import copyfile
from threading import Lock
from typing import NamedTuple as Struct
from typing import TypeVar

import astropy.units as u
import numpy as np
import pandas as pd
import pymongo
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.utils import iers
from astropy.utils.exceptions import AstropyWarning
from astropy.wcs.utils import skycoord_to_pixel
from prefect import flow, task
from prefect.task_runners import ConcurrentTaskRunner
from spectral_cube import SpectralCube
from spectral_cube.utils import SpectralCubeWarning
from tqdm.auto import tqdm

from arrakis.logger import TqdmToLogger, UltimateHelpFormatter, logger
from arrakis.utils.database import (
    get_db,
    get_field_db,
    test_db,
    validate_sbid_field_pair,
)
from arrakis.utils.fitsutils import fix_header
from arrakis.utils.pipeline import generic_parser, logo_str, workdir_arg_parser

# Turn off astropy's automatic IERS table download setting.
iers.conf.auto_download = False
# Ignore warnings whose message matches the Stokes-cube notice below.
warnings.filterwarnings(
    "ignore", message="Cube is a Stokes cube, returning spectral cube for I component"
)

# Ignore the SpectralCubeWarning and AstropyWarning categories module-wide,
# plus the numpy "true_divide" invalid-value message.
warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True)
warnings.simplefilter("ignore", category=AstropyWarning)
warnings.filterwarnings("ignore", message="invalid value encountered in true_divide")

# The shared arrakis logger reports at INFO level once this module is imported.
logger.setLevel(logging.INFO)
# File-like object handed to tqdm (as ``file=``) so progress output is routed
# through the module logger at INFO level.
TQDM_OUT = TqdmToLogger(logger, level=logging.INFO)
# Generic type variable available for annotations in this module.
T = TypeVar("T")
class CutoutArgs(Struct):
    """Arguments for cutout function"""

    ra_left: float
    """Upper RA bound in degrees"""
    ra_right: float
    """Lower RA bound in degrees"""
    dec_high: float
    """Upper DEC bound in degrees"""
    dec_low: float
    """Lower DEC bound in degrees"""
    outdir: Path
    """Output directory"""
def cutout_weight(
    image_name: Path,
    source_id: str,
    cutout_args: CutoutArgs | None,
    field: str,
    stoke: str,
    beam_num: int,
    dryrun=False,
) -> pymongo.UpdateOne:
    """Copy the weights text file that accompanies an image into the cutout directory.

    The weights file name is derived from ``image_name`` by replacing
    ``image.restored`` with ``weights.restored`` and switching the suffix to
    ``.txt``; the copy is named the same way with a ``{source_id}.cutout.``
    prefix.

    Args:
        image_name (Path): Path of the image the weights belong to.
        source_id (str): Source identifier, used in the output name and the
            database query.
        cutout_args (CutoutArgs | None): Cutout arguments providing the output
            directory. When None the source is skipped.
        field (str): Field name used in the database key.
        stoke (str): Stokes parameter; lower-cased for the database key.
        beam_num (int): Beam number used in the database key.
        dryrun (bool, optional): Skip the file copy when True. Defaults to False.

    Returns:
        pymongo.UpdateOne: Upsert that stores the weights path (or an empty
        string for a skipped source) on the matching ``Source_ID`` document.
    """
    weight_key = f"beams.{field}.{stoke.lower()}_beam{beam_num}_weight_file"
    selector = {"Source_ID": source_id}

    # Nothing to copy for a source without cutout arguments: blank the entry.
    if cutout_args is None:
        logger.error(f"Skipping {source_id} -- no components found")
        return pymongo.UpdateOne(selector, {"$set": {weight_key: ""}}, upsert=True)

    def _as_weights(path: Path) -> Path:
        # Same directory, image name mapped onto the weights naming scheme.
        renamed = path.name.replace("image.restored", "weights.restored")
        return (path.parent / renamed).with_suffix(".txt")

    cutout_image_path = (
        cutout_args.outdir.absolute() / f"{source_id}.cutout.{image_name.name}"
    )
    image = _as_weights(image_name)
    outfile = _as_weights(cutout_image_path)

    if not dryrun:
        copyfile(image, outfile)
        logger.info(f"Written to {outfile}")

    stored_path = (outfile.parent / outfile.name).absolute().as_posix()
    return pymongo.UpdateOne(selector, {"$set": {weight_key: stored_path}}, upsert=True)
def cutout_image(
    lock: Lock,
    image_name: Path,
    data_in_mem: np.ndarray,
    old_header: fits.Header,
    cube: SpectralCube,
    source_id: str,
    cutout_args: CutoutArgs | None,
    field: str,
    beam_num: int,
    stoke: str,
    pad: float = 3,
    dryrun: bool = False,
) -> pymongo.UpdateOne:
    """Perform a cutout.

    The four corners described by ``cutout_args`` are pushed diagonally
    outwards by ``pad`` times the cube's ``BMAJ`` header value, converted to
    pixel coordinates, and the enclosing pixel box is sliced out of
    ``data_in_mem`` and written to
    ``{outdir}/{source_id}.cutout.{image file name}``.

    Args:
        lock (Lock): Lock held while the FITS file is written.
        image_name (Path): Path of the image being cut; its file name is used
            to build the output name.
        data_in_mem (np.ndarray): Image data, sliced on its last two axes
            (indexed as freq, Stokes, y, x).
        old_header (fits.Header): Header of the full image, passed to
            ``fix_header`` together with the cutout header.
        cube (SpectralCube): Cube of the same image; supplies ``BMAJ``, the
            celestial WCS and the sliced header.
        source_id (str): Source identifier; also written to the ``OBJECT``
            header keyword.
        cutout_args (CutoutArgs | None): Bounds and output directory. When
            None the source is skipped.
        field (str): Field name used in the database key.
        beam_num (int): Beam number used in the database key.
        stoke (str): Stokes parameter; lower-cased for the database key.
        pad (float, optional): Padding in multiples of ``BMAJ``. Defaults to 3.
        dryrun (bool, optional): Skip writing the FITS file when True.
            Defaults to False.

    Returns:
        pymongo.UpdateOne: Update query for MongoDB
    """
    logger.setLevel(logging.INFO)

    # Update database
    myquery = {"Source_ID": source_id}
    image_key = f"beams.{field}.{stoke.lower()}_beam{beam_num}_image_file"

    if cutout_args is None:
        logger.error(f"Skipping {source_id} -- no components found")
        # Blank the *image* entry here; the weights entry is the
        # responsibility of cutout_weight.
        return pymongo.UpdateOne(myquery, {"$set": {image_key: ""}}, upsert=True)

    outfile = cutout_args.outdir.absolute() / f"{source_id}.cutout.{image_name.name}"

    # BMAJ is treated as degrees; the sqrt(2) factor makes a diagonal offset
    # move each corner by roughly `padder` along both axes.
    padder: u.Quantity = cube.header["BMAJ"] * u.deg * pad
    offset = padder * np.sqrt(2)
    celestial_wcs = cube.wcs.celestial

    # (RA, Dec, position angle of the outward diagonal) for each corner.
    corners = (
        (cutout_args.ra_right, cutout_args.dec_high, -45),  # top right
        (cutout_args.ra_left, cutout_args.dec_low, 135),  # bottom left
        (cutout_args.ra_left, cutout_args.dec_high, 45),  # top left
        (cutout_args.ra_right, cutout_args.dec_low, -135),  # bottom right
    )
    x_pix = []
    y_pix = []
    for ra, dec, angle in corners:
        shifted = SkyCoord(ra * u.deg, dec * u.deg).directional_offset_by(
            position_angle=angle * u.deg,
            separation=offset,
        )
        x, y = skycoord_to_pixel(shifted, celestial_wcs)
        x_pix.append(x)
        y_pix.append(y)

    # Compare all points in case of insanity at the poles.
    # Clamp at 0: a negative bound would be read as "count from the end" by
    # the slices below and silently select the wrong region.
    yp_lo_idx = max(0, int(np.floor(min(y_pix))))
    yp_hi_idx = max(0, int(np.ceil(max(y_pix))))
    xp_lo_idx = max(0, int(np.floor(min(x_pix))))
    xp_hi_idx = max(0, int(np.ceil(max(x_pix))))

    # Use subcube for header transformation
    cutout_cube = cube[:, yp_lo_idx:yp_hi_idx, xp_lo_idx:xp_hi_idx]
    new_header = cutout_cube.header
    sub_data = data_in_mem[
        :,
        :,
        yp_lo_idx:yp_hi_idx,
        xp_lo_idx:xp_hi_idx,  # freq, Stokes, y, x
    ]
    fixed_header = fix_header(new_header, old_header)
    # Add source name to header for CASDA
    fixed_header["OBJECT"] = source_id

    if not dryrun:
        # Writes from the worker threads are serialised through the lock.
        with lock:
            fits.writeto(
                outfile,
                sub_data,
                header=fixed_header,
                overwrite=True,
                output_verify="fix",
            )
            logger.info(f"Written to {outfile}")

    newvalues = {"$set": {image_key: outfile.absolute().as_posix()}}
    return pymongo.UpdateOne(myquery, newvalues, upsert=True)
def get_args(
    comps: pd.DataFrame,
    source: pd.Series,
    outdir: Path,
) -> CutoutArgs | None:
    """Get arguments for cutout function

    Builds the RA/Dec bounding box of a source's components, padded
    diagonally by the largest component major axis, and creates the
    per-source output directory ``outdir / Source_ID``.

    Args:
        comps (pd.DataFrame): Components of the source. The ``RA`` and ``Dec``
            columns are read as degrees and ``Maj`` as arcseconds.
        source (pd.Series): Source entry; only ``Source_ID`` is read.
        outdir (Path): Parent directory in which the per-source directory is
            created.

    Raises:
        Exception: Any error raised while computing the bounds is re-raised
            after the coordinates and components are logged at debug level.

    Returns:
        CutoutArgs | None: Cutout arguments for the cutout functions, or None
        when ``comps`` is empty.
    """
    logger.setLevel(logging.INFO)

    island_id = source.Source_ID

    if len(comps) == 0:
        logger.warning(f"Skipping island {island_id} -- no components found")
        return None

    outdir = outdir / island_id
    outdir.mkdir(parents=True, exist_ok=True)

    # Find image size
    ras: u.Quantity = comps.RA.values * u.deg
    decs: u.Quantity = comps.Dec.values * u.deg
    majs: u.Quantity = comps.Maj.values * u.arcsec
    coords = SkyCoord(ras, decs, frame="icrs")
    padder = np.max(majs)

    try:
        dec_max = decs.max()
        dec_min = decs.min()

        # If in South - Northern boundary will have greatest RA range
        # And vice versa
        northern_boundary = SkyCoord(ras, np.ones_like(ras.value) * dec_max)
        southern_boundary = SkyCoord(ras, np.ones_like(ras.value) * dec_min)

        if np.abs(dec_max) > np.abs(dec_min):
            # North - use Southern
            longest_boundary = southern_boundary
        else:
            # South - use Northern
            longest_boundary = northern_boundary

        # Compute separations and position angles between all pairs of coordinates
        separations = longest_boundary[:, np.newaxis].separation(longest_boundary)
        position_angles = longest_boundary[:, np.newaxis].position_angle(
            longest_boundary
        )

        # Find the index of the pair with the largest separation
        largest_i, largest_j = np.unravel_index(
            np.argmax(separations), separations.shape
        )

        # Determine left and right points based on position angle
        if position_angles[largest_i, largest_j] > 180 * u.deg:
            left_point = longest_boundary[largest_i]
            right_point = longest_boundary[largest_j]
        else:
            left_point = longest_boundary[largest_j]
            right_point = longest_boundary[largest_i]

        # Compute coordinates for top right, top left, bottom right, and bottom left points
        top_right = SkyCoord(right_point.ra, dec_max)
        top_left = SkyCoord(left_point.ra, dec_max)
        bottom_right = SkyCoord(right_point.ra, dec_min)
        bottom_left = SkyCoord(left_point.ra, dec_min)

        # Compute offsets for top right, top left, bottom right, and bottom left points
        offset = padder * np.sqrt(2)
        top_right_off = top_right.directional_offset_by(
            position_angle=-45 * u.deg, separation=offset
        )
        bottom_left_off = bottom_left.directional_offset_by(
            position_angle=135 * u.deg, separation=offset
        )
        top_left_off = top_left.directional_offset_by(
            position_angle=45 * u.deg, separation=offset
        )
        bottom_right_off = bottom_right.directional_offset_by(
            position_angle=-135 * u.deg, separation=offset
        )

        # Compute dec_low, dec_high, ra_right, and ra_left
        dec_low = min(bottom_left_off.dec.value, bottom_right_off.dec.value) * u.deg
        dec_high = max(top_left_off.dec.value, top_right_off.dec.value) * u.deg
        ra_right = top_right_off.ra
        ra_left = top_left_off.ra
    except Exception:
        # Record the inputs that triggered the failure, then propagate it
        # with its original traceback.
        logger.debug(f"coords are {coords=}")
        logger.debug(f"comps are {comps=}")
        raise

    return CutoutArgs(
        ra_left=ra_left.to(u.deg).value,
        ra_right=ra_right.to(u.deg).value,
        dec_high=dec_high.to(u.deg).value,
        dec_low=dec_low.to(u.deg).value,
        outdir=outdir,
    )
def make_cutout(
    lock: Lock,
    host: str,
    epoch: int,
    source: pd.Series,
    comps: pd.DataFrame,
    outdir: Path,
    image_name: Path,
    data_in_mem: np.ndarray,
    old_header: fits.Header,
    cube: SpectralCube,
    field: str,
    beam_num: int,
    stoke: str,
    pad: float = 3,
    username: str | None = None,
    password: str | None = None,
    dryrun: bool = False,
) -> list[pymongo.UpdateOne]:
    """Cut out the image and copy the weights for a single source.

    Args:
        lock (Lock): Lock forwarded to ``cutout_image`` for the FITS write.
        host (str): MongoDB host. Unused here; kept for call compatibility.
        epoch (int): Epoch number. Unused here; kept for call compatibility.
        source (pd.Series): Source entry; ``Source_ID`` is read.
        comps (pd.DataFrame): Components belonging to the source.
        outdir (Path): Parent output directory passed to ``get_args``.
        image_name (Path): Path of the image being cut.
        data_in_mem (np.ndarray): Image data already loaded in memory.
        old_header (fits.Header): Header of the full image.
        cube (SpectralCube): Cube of the same image.
        field (str): Field name used in the database keys.
        beam_num (int): Beam number used in the database keys.
        stoke (str): Stokes parameter used in the database keys.
        pad (float, optional): Padding forwarded to ``cutout_image``.
            Defaults to 3.
        username (str | None, optional): Unused here; kept for call
            compatibility. Defaults to None.
        password (str | None, optional): Unused here; kept for call
            compatibility. Defaults to None.
        dryrun (bool, optional): Skip file writes when True. Defaults to False.

    Returns:
        list[pymongo.UpdateOne]: The image update followed by the weight update.
    """
    # No database handle is opened here: this function only builds update
    # operations, and it runs once per source inside a thread pool.
    cut_args = get_args(
        comps=comps,
        source=source,
        outdir=outdir,
    )
    image_update = cutout_image(
        lock=lock,
        image_name=image_name,
        data_in_mem=data_in_mem,
        old_header=old_header,
        cube=cube,
        source_id=source.Source_ID,
        cutout_args=cut_args,
        field=field,
        beam_num=beam_num,
        stoke=stoke,
        pad=pad,
        dryrun=dryrun,
    )
    weight_update = cutout_weight(
        image_name=image_name,
        source_id=source.Source_ID,
        cutout_args=cut_args,
        field=field,
        beam_num=beam_num,
        stoke=stoke,
        dryrun=dryrun,
    )
    return [image_update, weight_update]
@task(
    name="Cutout from big cube",
    retries=3,
    retry_delay_seconds=[1, 10, 100],
    persist_result=True,
)
def big_cutout(
    sources: pd.DataFrame,
    comps: pd.DataFrame,
    beam_num: int,
    stoke: str,
    datadir: Path,
    outdir: Path,
    host: str,
    epoch: int,
    field: str,
    pad: float = 3,
    username: str | None = None,
    password: str | None = None,
    limit: int | None = None,
    dryrun: bool = False,
) -> list[pymongo.UpdateOne]:
    """Cut every requested source out of one beam image.

    The single image in ``datadir`` matching the Stokes parameter and beam
    number is loaded once, then ``make_cutout`` is run for each row of
    ``sources`` on a thread pool.

    Args:
        sources (pd.DataFrame): Sources to cut out, one per row.
        comps (pd.DataFrame): Components, looked up per source with ``.loc``.
        beam_num (int): Beam number, zero-padded to two digits in the file
            pattern.
        stoke (str): Stokes parameter; lower-cased in the file pattern.
        datadir (Path): Directory searched for the beam image.
        outdir (Path): Parent directory for the cutouts.
        host (str): MongoDB host, forwarded to ``make_cutout``.
        epoch (int): Epoch number, forwarded to ``make_cutout``.
        field (str): Field name, forwarded to ``make_cutout``.
        pad (float, optional): Padding, forwarded to ``make_cutout``.
            Defaults to 3.
        username (str | None, optional): Mongo username. Defaults to None.
        password (str | None, optional): Mongo password. Defaults to None.
        limit (int | None, optional): Keep only the first ``limit`` sources.
            Defaults to None.
        dryrun (bool, optional): Skip file writes when True. Defaults to False.

    Raises:
        Exception: If no image, or more than one image, matches the pattern.

    Returns:
        list[pymongo.UpdateOne]: Updates from all sources, in source order.
    """
    wild = f"image.restored.{stoke.lower()}*contcube*beam{beam_num:02}.conv.fits"
    images = list(datadir.glob(wild))
    if not images:
        raise Exception(f"No images found matching '{wild}'")
    if len(images) > 1:
        raise Exception(f"More than one image found matching '{wild}'. Files {images=}")
    (image_name,) = images

    # Load the image once; every cutout slices this in-memory copy.
    logger.info(f"Reading {image_name}")
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", AstropyWarning)
        cube = SpectralCube.read(image_name, memmap=True, mode="denywrite")
    data_in_mem = np.array(fits.getdata(image_name))
    old_header = fits.getheader(image_name)

    if limit is not None:
        logger.critical(f"Limiting to {limit} islands")
        sources = sources[:limit]

    # Keyword arguments that are identical for every source.
    shared = {
        "lock": Lock(),
        "host": host,
        "epoch": epoch,
        "outdir": outdir,
        "image_name": image_name,
        "data_in_mem": data_in_mem,
        "old_header": old_header,
        "cube": cube,
        "field": field,
        "beam_num": beam_num,
        "stoke": stoke,
        "pad": pad,
        "username": username,
        "password": password,
        "dryrun": dryrun,
    }

    updates: list[pymongo.UpdateOne] = []
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(
                make_cutout,
                source=source,
                comps=comps.loc[source],
                **shared,
            )
            for _, source in sources.iterrows()
        ]
        # Collect in submission order so the result order follows `sources`.
        for future in tqdm(futures, file=TQDM_OUT, desc=f"Cutting {image_name}"):
            updates.extend(future.result())

    return updates
@flow(name="Cutout islands")
def cutout_islands(
    field: str,
    directory: Path,
    host: str,
    epoch: int,
    sbid: int | None = None,
    username: str | None = None,
    password: str | None = None,
    pad: float = 3,
    stokeslist: list[str] | None = None,
    dryrun: bool = True,
    limit: int | None = None,
) -> None:
    """Flow to cutout islands in parallel.

    Args:
        field (str): RACS field name.
        directory (Path): Directory holding the beam images; cutouts are
            written to its ``cutouts`` subdirectory.
        host (str): MongoDB host.
        epoch (int): Epoch number passed to the database helpers.
        sbid (int, optional): SBID to restrict the query to. It is validated
            against ``field`` first. Defaults to None.
        username (str, optional): Mongo username. Defaults to None.
        password (str, optional): Mongo password. Defaults to None.
        pad (int, optional): Number of beamwidths to pad cutouts. Defaults to 3.
        stokeslist (List[str], optional): Stokes parameters to cutout.
            Defaults to None, meaning I, Q, U and V.
        dryrun (bool, optional): Do everything except write FITS files and
            update the database. Defaults to True.
        limit (int, optional): Maximum number of sources cut per beam image.
            Defaults to None.

    Raises:
        ValueError: If ``sbid`` is given and does not match ``field``.
    """
    if stokeslist is None:
        stokeslist = ["I", "Q", "U", "V"]

    directory = directory.absolute()
    outdir = directory / "cutouts"

    logger.info("Testing database. ")
    test_db(
        host=host,
        username=username,
        password=password,
    )

    beams_col, _, comp_col = get_db(
        host=host, epoch=epoch, username=username, password=password
    )

    # Check for SBID match
    if sbid is not None:
        field_col = get_field_db(
            host=host,
            epoch=epoch,
            username=username,
            password=password,
        )
        sbid_check = validate_sbid_field_pair(
            field_name=field,
            sbid=sbid,
            field_col=field_col,
        )
        if not sbid_check:
            raise ValueError(f"SBID {sbid} does not match field {field}")

    query = {"$and": [{f"beams.{field}": {"$exists": True}}]}
    if sbid is not None:
        query["$and"].append({f"beams.{field}.SBIDs": sbid})

    unique_beams_nums: set[int] = set(
        beams_col.distinct(f"beams.{field}.beam_list", query)
    )
    source_ids = sorted(beams_col.distinct("Source_ID", query))

    query = {
        "$and": [
            {f"beams.{field}": {"$exists": True}},
            {f"beams.{field}.beam_list": {"$in": list(unique_beams_nums)}},
        ]
    }
    if sbid is not None:
        query["$and"].append({f"beams.{field}.SBIDs": sbid})

    beams_df = pd.DataFrame(
        beams_col.find(query, {"Source_ID": 1, f"beams.{field}.beam_list": 1}).sort(
            "Source_ID"
        )
    )

    # One (source, beam) record per beam a source appears in.
    beam_source_list = [
        {"Source_ID": row.Source_ID, "beam": b}
        for _, row in tqdm(beams_df.iterrows())
        for b in row.beams[field]["beam_list"]
    ]
    beam_source_df = pd.DataFrame(beam_source_list)
    beam_source_df.set_index("beam", inplace=True)

    comps_df = pd.DataFrame(
        comp_col.find({"Source_ID": {"$in": source_ids}}).sort("Source_ID")
    )
    comps_df.set_index("Source_ID", inplace=True)

    # Create output dir if it doesn't exist
    outdir.mkdir(parents=True, exist_ok=True)

    cuts: list[pymongo.UpdateOne] = []
    for stoke in stokeslist:
        for beam_num in unique_beams_nums:
            sources = beam_source_df.loc[beam_num]
            # Force the DataFrame type in the rare case of a single source in
            # a beam. This must happen before the component lookup: indexing
            # with the resulting Source_ID column (rather than a bare string)
            # keeps `comps` a DataFrame even for a single component.
            if isinstance(sources, pd.Series):
                sources = pd.DataFrame(sources).T
            comps = comps_df.loc[sources.Source_ID]

            results = big_cutout.submit(
                sources=sources,
                comps=comps,
                beam_num=beam_num,
                stoke=stoke,
                datadir=directory,
                outdir=outdir,
                host=host,
                epoch=epoch,
                field=field,
                pad=pad,
                username=username,
                password=password,
                limit=limit,
                dryrun=dryrun,
            )
            cuts.append(results)

    if not dryrun:
        updates = [update for future in cuts for update in future.result()]
        if updates:
            logger.info("Updating database...")
            db_res = beams_col.bulk_write(updates, ordered=False)
            logger.info(pformat(db_res.bulk_api_result))
        else:
            # bulk_write rejects an empty request list, so skip it.
            logger.warning("No database updates were produced -- skipping bulk write")

    logger.info("Cutouts Done!")
def main(args: argparse.Namespace) -> None:
    """Run the cutout flow from parsed command-line arguments.

    Args:
        args (argparse.Namespace): Command-line args
    """
    # Same flow, configured to run its tasks on the concurrent task runner.
    concurrent_flow = cutout_islands.with_options(task_runner=ConcurrentTaskRunner)
    concurrent_flow(
        field=args.field,
        directory=args.datadir,
        host=args.host,
        epoch=args.epoch,
        sbid=args.sbid,
        username=args.username,
        password=args.password,
        pad=args.pad,
        stokeslist=args.stokeslist,
        dryrun=args.dryrun,
        limit=args.limit,
    )

    logger.info("Done!")
def cutout_parser(parent_parser: bool = False) -> argparse.ArgumentParser:
    """Build the argument parser for the cutout stage.

    Args:
        parent_parser (bool, optional): When True the parser is created
            without its own help option, so it can be listed in another
            parser's ``parents``. Defaults to False.

    Returns:
        argparse.ArgumentParser: Parser with a "cutout arguments" group
        holding ``--pad`` and ``--dryrun``.
    """
    descStr = f"""
    {logo_str}
    Arrakis Stage 1:
    Produce cubelets from a RACS field using a Selavy table.
    If Stokes V is present, it will be squished into RMS spectra.

    To use with MPI:
       mpirun -n $NPROCS python -u cutout.py $cubedir $tabledir
       $outdir --mpi
    """

    top_level = argparse.ArgumentParser(
        description=descStr,
        formatter_class=UltimateHelpFormatter,
        add_help=not parent_parser,
    )

    group = top_level.add_argument_group("cutout arguments")
    group.add_argument(
        "-p",
        "--pad",
        dest="pad",
        type=float,
        default=3,
        help="Number of beamwidths to pad around source [3].",
    )
    group.add_argument(
        "-d",
        "--dryrun",
        action="store_true",
        help="Do a dry-run [False].",
    )

    return top_level
def cli() -> None:
    """Command-line interface

    Combines the generic, working-directory and cutout parsers, checks the
    database with ``test_db`` and hands the parsed arguments to ``main``.
    """
    cut_parser = cutout_parser(parent_parser=True)
    parent_parsers = [
        generic_parser(parent_parser=True),
        workdir_arg_parser(parent_parser=True),
        cut_parser,
    ]
    combined = argparse.ArgumentParser(
        description=cut_parser.description,
        parents=parent_parsers,
        formatter_class=UltimateHelpFormatter,
    )
    args = combined.parse_args()

    if args.verbose:
        logger.setLevel(logging.INFO)

    test_db(host=args.host, username=args.username, password=args.password)

    main(args)
# Run the command-line interface when this file is executed as a script.
if __name__ == "__main__": cli()