Source code for arrakis.imager

#!/usr/bin/env python3
"""Arrkis imager"""

from __future__ import annotations

import argparse
import hashlib
import logging
import os
import pickle
import shutil
from concurrent.futures import ThreadPoolExecutor
from glob import glob
from pathlib import Path
from subprocess import CalledProcessError
from typing import Any
from typing import NamedTuple as Struct

import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.stats import mad_std
from astropy.table import Table
from astropy.visualization import (
    ImageNormalize,
    SqrtStretch,
)
from fitscube import combine_fits
from fixms.fix_ms_corrs import fix_ms_corrs
from fixms.fix_ms_dir import fix_ms_dir
from prefect import flow, get_run_logger, task
from racs_tools import beamcon_2D
from skimage.transform import resize
from spython.main import Client as sclient
from tqdm.auto import tqdm

from arrakis.logger import TqdmToLogger, UltimateHelpFormatter, logger
from arrakis.utils.exceptions import DivergenceError
from arrakis.utils.io import parse_env_path
from arrakis.utils.meta import my_ceil
from arrakis.utils.msutils import (
    beam_from_ms,
    field_idx_from_ms,
    field_name_from_ms,
    get_pol_axis,
    wsclean,
)
from arrakis.utils.pipeline import (
    logo_str,
    upload_image_as_artifact_task,
    workdir_arg_parser,
)

# Select matplotlib's non-interactive "Agg" backend; figures in this module are
# only written to files via `savefig`, never shown interactively.
mpl.use("Agg")

# File-like object handed to tqdm (`file=TQDM_OUT`) so that progress bars are
# written through the package `logger` at INFO level instead of to stderr.
TQDM_OUT = TqdmToLogger(logger, level=logging.INFO)
class ImageSet(Struct):
    """Container to organise files related to the imaging of a measurement set."""

    # `Struct` is `typing.NamedTuple` (see the imports), so an ImageSet is an
    # immutable tuple; use `._replace(...)` or build a new instance to change it.

    ms: Path
    """Path to the measurement set that was imaged."""
    prefix: str
    """Prefix used for the wsclean output files."""
    image_lists: dict[str, list[str]]
    """Dictionary of lists of images. The keys are the polarisations and the values are the list of images for that polarisation."""
    aux_lists: dict[tuple[str, str], list[str]] | None = None
    """Dictionary of lists of auxiliary images. The keys are a tuple of the polarisation and the image type, and the values are the list of images for that polarisation and image type."""
class MFSImage(Struct):
    """Representation of a multi-frequency synthesis image.

    Iterating an instance yields (image, model, residual) in that order, which
    is how the validation plots pair each array with its panel title. As built
    by ``get_mfs_image`` each array is a resized copy of the full MFS product,
    and ``model`` is the image minus the residual (i.e. the restored model).
    """

    image: np.ndarray
    """The image data."""
    model: np.ndarray
    """The model data."""
    residual: np.ndarray
    """The residual data."""
@task(name="Get pol. axis")
def get_pol_axis_task(
    ms: Path, feed_idx: int | None = None, col: str = "RECEPTOR_ANGLE"
) -> float:
    """Return the polarisation axis angle of a measurement set in degrees.

    Thin Prefect task wrapper around ``get_pol_axis``.

    Args:
        ms (Path): Measurement set to inspect.
        feed_idx (int | None, optional): Feed index forwarded to
            ``get_pol_axis``. Defaults to None.
        col (str, optional): Column forwarded to ``get_pol_axis``.
            Defaults to "RECEPTOR_ANGLE".

    Returns:
        float: The numeric value of the angle after conversion to degrees.
    """
    angle = get_pol_axis(ms=ms, feed_idx=feed_idx, col=col)
    angle_in_degrees = angle.to(u.deg)
    return angle_in_degrees.value
@task(name="Merge ImageSets")
def merge_imagesets(image_sets: list[ImageSet | None]) -> ImageSet:
    """Merge a collection of ImageSets into a single ImageSet.

    All non-``None`` ImageSets must refer to the same measurement set
    (compared by file name) and share the same wsclean prefix. Image lists
    and auxiliary image lists are concatenated per key, in input order.

    Args:
        image_sets (list[ImageSet | None]): Collection of ImageSets to merge.
            ``None`` entries are skipped.

    Returns:
        ImageSet: A single ImageSet containing all the images from the input
        ImageSets. ``aux_lists`` is always a dict (possibly empty).

    Raises:
        ValueError: If there is no non-``None`` ImageSet to merge, or if the
            ImageSets disagree on the measurement set name or the prefix.
    """
    logger = get_run_logger()
    logger.info(f"Merging {len(image_sets)} ImageSets.")

    # Remove any None values
    image_sets_list = [image_set for image_set in image_sets if image_set is not None]
    if not image_sets_list:
        raise ValueError(
            "No ImageSets to merge - the input was empty or contained only None."
        )

    ms = image_sets_list[0].ms
    prefix = image_sets_list[0].prefix
    for image_set in image_sets_list:
        # Explicit checks rather than `assert`, which is stripped under `python -O`
        if image_set.ms.name != ms.name:
            raise ValueError(f"{image_set.ms.name=} does not match {ms.name=}")
        if image_set.prefix != prefix:
            raise ValueError(f"{image_set.prefix=} does not match {prefix=}")

    image_lists: dict[str, list[str]] = {}
    aux_lists: dict[tuple[str, str], list[str]] = {}
    for image_set in image_sets_list:
        for pol, images in image_set.image_lists.items():
            image_lists.setdefault(pol, []).extend(images)
        # aux_lists may be None on individual ImageSets
        for key, aux_images in (image_set.aux_lists or {}).items():
            aux_lists.setdefault(key, []).extend(aux_images)

    return ImageSet(ms=ms, prefix=prefix, image_lists=image_lists, aux_lists=aux_lists)
def _mfs_fits_name(prefix_str: str, pol: str, suffix: str) -> str:
    """Return the wsclean MFS FITS file name for ``pol``/``suffix`` that exists on disk."""
    untagged = f"{prefix_str}-MFS-{suffix}.fits"
    tagged = f"{prefix_str}-MFS-{pol}-{suffix}.fits"
    # Legacy preference first (untagged for Stokes I, pol-tagged otherwise);
    # fall back to the other form, which is what wsclean writes for a joint
    # multi-pol run (Stokes I) or a single-pol run (Q/U/V).
    candidates = (untagged, tagged) if pol == "I" else (tagged, untagged)
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    # Neither exists: return the preferred name so the reader's error mentions it
    return candidates[0]


def get_mfs_image(
    prefix_str: str, pol: str, small_size: tuple[int, int] = (512, 512)
) -> MFSImage:
    """Load the MFS image and residual for one polarisation and downsample them.

    Args:
        prefix_str (str): wsclean output prefix, including the directory.
        pol (str): Polarisation to load, e.g. "I", "Q" or "U". Both the
            untagged (``<prefix>-MFS-image.fits``) and pol-tagged
            (``<prefix>-MFS-<pol>-image.fits``) file names are supported.
        small_size (tuple[int, int], optional): Output shape passed to
            ``resize``. Defaults to (512, 512).

    Returns:
        MFSImage: The resized image, model and residual, where the model is
        the image minus the residual (the restored/convolved model).
    """
    mfs_image_name = _mfs_fits_name(prefix_str, pol, "image")
    mfs_residual_name = _mfs_fits_name(prefix_str, pol, "residual")
    big_image = fits.getdata(mfs_image_name).squeeze()
    big_residual = fits.getdata(mfs_residual_name).squeeze()
    big_model = big_image - big_residual
    small_image = resize(big_image, small_size)
    small_model = resize(big_model, small_size)
    small_residual = resize(big_residual, small_size)
    return MFSImage(image=small_image, model=small_model, residual=small_residual)
@task(name="Make Validation Plots", persist_result=True)
def make_validation_plots(prefix: Path, pols: str) -> None:
    """Make and upload validation plots of the MFS images.

    For each polarisation a three-panel PNG (absolute value of the image,
    the convolved model and the residual, with a square-root stretch) is
    written to the current working directory and uploaded as an artifact.

    Args:
        prefix (Path): Prefix of the images.
        pols (str): Polarisations to make the plots for, one character each
            (e.g. "IQU").
    """
    logger = get_run_logger()
    prefix_str = prefix.resolve().as_posix()
    titles = ("Image", "Model (conv.)", "Residual")
    for stokes in pols:
        mfs_image = get_mfs_image(prefix_str, stokes)
        # One colour scale for all three panels, set from the residual; it does
        # not depend on the panel, so build it once per Stokes parameter.
        norm = ImageNormalize(mfs_image.residual, vmin=0, stretch=SqrtStretch())
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
        try:
            for ax, sub_image, title in zip(axs, mfs_image, titles):
                _ = ax.imshow(
                    np.abs(sub_image), origin="lower", norm=norm, cmap="cubehelix"
                )
                ax.set_title(title)
                ax.get_yaxis().set_visible(False)
                ax.get_xaxis().set_visible(False)
            fig.suptitle(f"abs(Stokes {stokes}) - {prefix.name}")
            # Remove the space between the plots
            fig.subplots_adjust(wspace=0, hspace=0)
            fig_name = Path(f"{prefix.name}_abs_stokes_{stokes}.png")
            fig.savefig(fig_name, bbox_inches="tight", dpi=72)
        finally:
            # Release the figure even if plotting or saving fails
            plt.close(fig)
        uuid = upload_image_as_artifact_task.fn(
            fig_name,
            description=f"abs(Stokes {stokes}) - {prefix.name}",
        )
        logger.info(f"Uploaded {fig_name} to {uuid}")
def get_wsclean(wsclean: Path | str) -> Path:
    """Return a local path to the wsclean container image, pulling it if needed.

    Args:
        wsclean (Path | str): Either a ``Path`` to a local image, a string path
            to an existing local image file, or an image URI to pull
            (e.g. "docker://alecthomson/wsclean:latest").

    Returns:
        Path: Path to the local wsclean image.
    """
    sclient.load(str(wsclean))
    if not isinstance(wsclean, str):
        return wsclean
    # A plain string naming an image that is already on disk needs no pull
    if "://" not in wsclean and Path(wsclean).is_file():
        return Path(wsclean)
    return Path(sclient.pull(wsclean))
def _remove_with_log(path: str, log: logging.Logger) -> None:
    """Delete one file, logging the attempt and tolerating a missing file."""
    log.critical(f"Removing {path}")
    try:
        os.remove(path)
    except FileNotFoundError:
        log.critical(f"{path} not available for deletion. ")


def cleanup_imageset(purge: bool, image_set: ImageSet) -> None:
    """Delete the files listed in an ImageSet.

    Args:
        purge (bool): When False nothing is deleted; a message is logged instead.
        image_set (ImageSet): Its ``image_lists`` files, and its ``aux_lists``
            files when present, are deleted. Missing files are logged and skipped.
    """
    log = logging.getLogger(__name__)
    log.setLevel(logging.INFO)

    if purge:
        for pol, image_list in image_set.image_lists.items():
            log.critical(f"Removing {pol=} images for {image_set.ms}")
            for image in image_list:
                _remove_with_log(image, log)

        # The aux images are shared between the native and the smoothed
        # ImageSets: they were copied across without modification.
        if image_set.aux_lists:
            log.critical("Removing auxiliary images. ")
            for aux_list in image_set.aux_lists.values():
                for aux_image in aux_list:
                    _remove_with_log(aux_image, log)
    else:
        log.info("Not purging intermediate files")
def get_prefix(
    ms: Path,
    out_dir: Path,
) -> Path:
    """Build the wsclean output prefix for a measurement set.

    The prefix has the form ``image.<field>.contcube.beam<NN>`` and is placed
    inside ``out_dir``.

    Args:
        ms (Path): Existing measurement set the field name and beam number are read from.
        out_dir (Path): Directory the wsclean output will finally be written to.

    Returns:
        Path: ``out_dir`` joined with the prefix.
    """
    ms_path = ms.resolve(strict=True).as_posix()
    field = field_name_from_ms(ms_path)
    beam = beam_from_ms(ms_path)
    return out_dir / f"image.{field}.contcube.beam{beam:02}"
def run_wsclean_singuarlity(
    command: str,
    simage: Path,
    out_dir: Path,
    root_dir: Path,
) -> None:
    """Run a wsclean command inside a Singularity/Apptainer image.

    The command's output is streamed line by line to the Prefect run logger.
    (The misspelt function name is kept for backwards compatibility.)

    Args:
        command (str): Full wsclean command line; it is split on whitespace.
        simage (Path): Container image providing wsclean; must exist.
        out_dir (Path): Output directory, bind-mounted at the same absolute path.
        root_dir (Path): Directory holding the input MS, bind-mounted at the
            same absolute path; must exist.

    Raises:
        DivergenceError: If an output line contains "KJy", taken as a sign
            that the deconvolution diverged.
        CalledProcessError: If the container command fails.
    """
    logger = get_run_logger()
    logger.info(f"Running wsclean with command: {command}")
    # Bind destinations must be absolute; resolving out_dir also matches the
    # resolved output prefix paths handed to wsclean.
    out_dir_str = out_dir.resolve().as_posix()
    root_dir_str = root_dir.resolve(strict=True).as_posix()
    try:
        output = sclient.execute(
            image=simage.resolve(strict=True).as_posix(),
            command=command.split(),
            # One entry per bind; avoids the stray ", " of a joined string
            bind=[f"{out_dir_str}:{out_dir_str}", f"{root_dir_str}:{root_dir_str}"],
            return_result=True,
            quiet=False,
            stream=True,
        )
        try:
            for line in output:
                logger.info(line.rstrip())
                # Catch divergence - look for the string 'KJy' in the output
                if "KJy" in line:
                    raise DivergenceError(
                        f"Detected divergence in wsclean output: {line.rstrip()}"
                    )
        finally:
            # On early exit, close the output stream instead of leaving it
            # referenced by the traceback while the caller (e.g.) reruns wsclean.
            close = getattr(output, "close", None)
            if close is not None:
                close()
    except CalledProcessError as e:
        logger.error(f"Failed to run wsclean with command: {command}")
        logger.error(f"Stdout: {e.stdout}")
        logger.error(f"Stderr: {e.stderr}")
        logger.error(f"{e=}")
        raise
def _wsclean_fits_patterns(
    prefix: str, pols: str, pol: str, suffix: str, *, tag_psf: bool = False
) -> tuple[str, str]:
    """Return (per-channel glob, MFS file name) of one wsclean output product.

    Names carry no polarisation tag when a single polarisation is imaged.
    With several polarisations, the PSF is matched untagged unless ``tag_psf``.
    """
    if len(pols) > 1 and (tag_psf or suffix != "psf"):
        return (
            f"{prefix}-*[0-9]-{pol}-{suffix}.fits",
            f"{prefix}-MFS-{pol}-{suffix}.fits",
        )
    return f"{prefix}-*[0-9]-{suffix}.fits", f"{prefix}-MFS-{suffix}.fits"


def _find_wsclean_outputs(
    directory: Path, prefix_name: str, pols: str, suffixes: list[str]
) -> list[Path]:
    """List the wsclean FITS products of ``prefix_name`` in ``directory``, once each.

    Both the tagged and untagged PSF names are searched so that the PSFs later
    collected from the output directory are always copied there.
    """
    found: dict[Path, None] = {}  # insertion-ordered set
    for pol in pols:
        for suffix in suffixes:
            for tag_psf in (True, False):
                chan_glob, mfs_name = _wsclean_fits_patterns(
                    prefix_name, pols, pol, suffix, tag_psf=tag_psf
                )
                for path in (*directory.glob(chan_glob), *directory.glob(mfs_name)):
                    found[path] = None
    return list(found)


@task(name="Image Beam", persist_result=True)
def image_beam(
    ms: Path,
    field_idx: int,
    out_dir: Path,
    prefix: Path,
    simage: Path,
    temp_dir_wsclean: Path,
    temp_dir_images: Path,
    pols: str = "IQU",
    nchan: int = 36,
    scale: float = 2.5,
    npix: int = 4096,
    join_polarizations: bool = True,
    join_channels: bool = True,
    squared_channel_joining: bool = True,
    mgain: float = 0.7,
    niter: int = 100_000,
    auto_mask: float = 3,
    force_mask_rounds: int | None = None,
    auto_threshold: float = 1,
    gridder: str | None = None,
    robust: float = -0.5,
    mem: float = 90,
    absmem: float | None = None,
    taper: float | None = None,
    minuv_l: float = 0.0,
    parallel_deconvolution: int | None = None,
    nmiter: int | None = None,
    local_rms: bool = False,
    local_rms_window: float | None = None,
    multiscale: bool = False,
    multiscale_scale_bias: float | None = None,
    multiscale_scales: str | None = "0,2,4,8,16,32,64,128",
    data_column: str = "CORRECTED_DATA",
    no_mf_weighting: bool = False,
    no_update_model_required: bool = True,
    beam_fitting_size: float | None = 1.25,
    disable_pol_local_rms: bool = False,
    disable_pol_force_mask_rounds: bool = False,
) -> ImageSet:
    """Image a single beam with wsclean and collect the resulting images.

    If ``temp_dir_images`` differs from ``out_dir``, the MS is first copied
    there and wsclean writes its images there; the FITS products are then
    moved to ``out_dir`` and ``temp_dir_images`` is removed. wsclean always
    runs on a further ``<ms>.<pols>.temp.ms`` copy, which is removed after
    wsclean finishes or fails. If divergence is detected, wsclean is rerun
    once with the image size increased by 1024 pixels.

    Args:
        ms (Path): Measurement set to image.
        field_idx (int): Field index passed to the wsclean command builder.
        out_dir (Path): Directory the final images are written to.
        prefix (Path): wsclean output prefix; only its name is kept when a
            temporary image directory is used.
        simage (Path): Container image providing wsclean.
        temp_dir_wsclean (Path): Scratch directory for wsclean; may be given
            via an environment variable.
        temp_dir_images (Path): Directory wsclean writes images to before they
            are moved to ``out_dir``; may be given via an environment variable.
        pols (str, optional): Polarisations to image, one character each.
            Defaults to "IQU".
        Remaining arguments are forwarded to the wsclean command builder
        (after the consistency adjustments made below).

    Returns:
        ImageSet: The input ``ms``, the output prefix in ``out_dir`` and the
        per-polarisation lists of channel images and auxiliary images.
    """
    logger = get_run_logger()

    # The caller's MS. A copy made in temp_dir_images is deleted before we
    # return, so it must not be the path reported in the ImageSet.
    input_ms = ms

    # Evaluate the temp directory if a ENV variable is used
    temp_dir_images = parse_env_path(temp_dir_images)
    # Compare resolved paths: temp_dir_images is deleted at the end, so a
    # different spelling of out_dir must not be treated as a separate directory.
    use_temp_dir = temp_dir_images.resolve() != out_dir.resolve()
    if use_temp_dir:
        # Copy the MS to the temp directory
        ms_local = temp_dir_images / ms.name
        logger.info(f"Copying {ms} to {ms_local}")
        ms_local = ms_local.resolve(strict=False)
        shutil.copytree(ms, ms_local, dirs_exist_ok=True)
        ms = ms_local
        # Update the prefix
        prefix = temp_dir_images / prefix.name

    temp_dir_wsclean = parse_env_path(temp_dir_wsclean)

    # Make temp MS to allow parallel imaging
    ms_temp = ms.with_suffix(f".{pols}.temp.ms")
    logger.info(f"Copying {ms} to {ms_temp}")
    shutil.copytree(ms, ms_temp, dirs_exist_ok=True)

    # Catch mismatched args
    if not local_rms:
        logger.warning(
            f"Local RMS is disabled. Setting local_rms_window to None. Was set to {local_rms_window}."
        )
        local_rms_window = None
    if not multiscale:
        logger.warning(
            f"Multiscale is disabled. Setting multiscale_scale_bias to None. Was set to {multiscale_scale_bias}."
        )
        multiscale_scale_bias = None
        logger.warning(
            f"Multiscale is disabled. Setting multiscale_scales to None. Was set to {multiscale_scales}."
        )
        multiscale_scales = None
    if squared_channel_joining:
        logger.info(
            "Squared channel joining is enabled - scaling auto_mask and auto_threshold by power of 2"
        )
        auto_mask = my_ceil(auto_mask**2, 2)
        auto_threshold = my_ceil(auto_threshold**2, 2)

    if disable_pol_local_rms and pols != "I":
        logger.info("Disabling local RMS for polarisation images")
        local_rms = False
        local_rms_window = None

    if disable_pol_force_mask_rounds and pols != "I":
        logger.info("Disabling force mask rounds for polarisation images")
        force_mask_rounds = None

    command = wsclean(
        mslist=[ms_temp.resolve(strict=True).as_posix()],
        temp_dir=(
            temp_dir_wsclean.resolve(strict=True).as_posix()
            if temp_dir_wsclean is not None
            else None
        ),
        use_mpi=False,
        name=prefix.resolve().as_posix(),
        pol=pols,
        verbose=True,
        channels_out=nchan,
        parallel_gridding=nchan,
        scale=f"{scale}asec",
        size=f"{npix} {npix}",
        join_polarizations=join_polarizations,
        join_channels=join_channels,
        squared_channel_joining=squared_channel_joining,
        mgain=mgain,
        niter=niter,
        auto_mask=auto_mask,
        force_mask_rounds=force_mask_rounds,
        auto_threshold=auto_threshold,
        gridder=gridder,
        weight=f"briggs {robust}",
        log_time=False,
        mem=mem,
        abs_mem=absmem,
        taper_gaussian=f"{taper}asec" if taper else None,
        field=field_idx,
        parallel_deconvolution=parallel_deconvolution,
        minuv_l=minuv_l,
        nmiter=nmiter,
        local_rms=local_rms,
        local_rms_window=local_rms_window,
        # Avoid multiscale when using squared channel joining
        multiscale=multiscale if not squared_channel_joining else False,
        multiscale_scale_bias=multiscale_scale_bias
        if not squared_channel_joining
        else None,
        multiscale_scales=multiscale_scales if not squared_channel_joining else None,
        data_column=data_column,
        no_mf_weighting=no_mf_weighting,
        no_update_model_required=no_update_model_required,
        beam_fitting_size=beam_fitting_size,
    )

    root_dir = ms.parent
    try:
        try:
            run_wsclean_singuarlity(
                command=command,
                simage=simage,
                out_dir=out_dir,
                root_dir=root_dir,
            )
        except DivergenceError as de:
            logger.error(f"Detected divergence in wsclean output: {de}")
            new_pix = npix + 1024
            new_command = command.replace(f"{npix} {npix}", f"{new_pix} {new_pix}")
            logger.critical(
                f"Rerunning wsclean with larger image size: {new_pix}x{new_pix}"
            )
            run_wsclean_singuarlity(
                command=new_command,
                simage=simage,
                out_dir=out_dir,
                root_dir=root_dir,
            )
    finally:
        # Purge ms_temp, also when wsclean fails, so the MS copy does not leak
        shutil.rmtree(ms_temp, ignore_errors=True)

    suffixes: list[str] = ["image", "model", "psf", "residual", "dirty"]
    if use_temp_dir:
        # Copy the images to the output directory
        logger.info(f"Copying images to {out_dir}")
        all_fits_files = _find_wsclean_outputs(
            temp_dir_images, prefix.name, pols, suffixes
        )
        for fits_file in tqdm(all_fits_files, desc="Copying images", file=TQDM_OUT):
            logger.info(f"Copying {fits_file} to {out_dir}")
            shutil.copy(fits_file, out_dir)
            # Purge the temp directory
            fits_file.unlink()
        # Update the prefix
        prefix = out_dir / prefix.name
        # NOTE(review): this removes the whole directory; if several image_beam
        # runs share temp_dir_images concurrently, confirm they cannot remove
        # each other's files.
        shutil.rmtree(temp_dir_images)

    prefix_str = prefix.resolve().as_posix()

    # Check rms of image to check for divergence
    for pol in pols:
        _, mfs_image = _wsclean_fits_patterns(prefix_str, pols, pol, "image")
        rms = mad_std(fits.getdata(mfs_image), ignore_nan=True)
        if rms > 1:
            logger.error(
                f"RMS of {rms} is too high in image {mfs_image}, try imaging with lower mgain {mgain - 0.1}"
            )

    # Get images
    image_lists: dict[str, list[str]] = {}
    aux_lists: dict[tuple[str, str], list[str]] = {}
    aux_suffixes = suffixes[1:]
    for pol in pols:
        imglob, _ = _wsclean_fits_patterns(prefix_str, pols, pol, "image")
        image_list = sorted(glob(imglob))
        image_lists[pol] = image_list
        logger.info(f"Found {len(image_list)} images for {pol=} {input_ms}.")
        for aux in aux_suffixes:
            aux_glob, _ = _wsclean_fits_patterns(prefix_str, pols, pol, aux)
            aux_list = sorted(glob(aux_glob))
            aux_lists[(pol, aux)] = aux_list
            logger.info(f"Found {len(aux_list)} images for {pol=} {aux=} {input_ms}.")

    logger.info("Constructing ImageSet")
    image_set = ImageSet(
        ms=input_ms, prefix=prefix_str, image_lists=image_lists, aux_lists=aux_lists
    )
    logger.debug(f"{image_set=}")
    return image_set
@task(name="Make Cube", persist_result=True)
def make_cube(
    pol: str,
    image_set: ImageSet,
    common_beam_pkl: Path,
    pol_angle_deg: float,
    aux_mode: str | None = None,
) -> tuple[Path, Path]:
    """Make a cube from the per-channel images of one polarisation.

    The cube is written next to the first input image. Its name keeps the
    input name up to "beamNN", appends ``.conv.fits`` and expands "image" to
    ``image.<type>.<pol>``. Header axes 3 and 4 are swapped, the common beam
    is attached, and a text file of inverse-variance channel weights is
    written alongside.

    Args:
        pol (str): Polarisation key in ``image_set.image_lists``.
        image_set (ImageSet): ImageSet holding the per-channel images.
        common_beam_pkl (Path): Pickled common beam; it is unpickled, so it
            must be a trusted file written by this pipeline.
        pol_angle_deg (float): Original polarisation axis rotation angle in
            degrees, stored in the cube header.
        aux_mode (str | None, optional): Image type used in the output names;
            "restored" when None. Defaults to None.

    Returns:
        tuple[Path, Path]: Path of the cube and path of the weights file.

    Raises:
        ValueError: If there are no images for ``pol``.
    """
    logger = get_run_logger()
    logger.info(f"Creating cube for {pol=} {image_set.ms=}")
    image_list = [Path(i) for i in image_set.image_lists[pol]]
    if not image_list:
        raise ValueError(f"No images to combine for {pol=} in {image_set.ms}")
    image_type = "restored" if aux_mode is None else aux_mode

    # Create a cube name: keep everything up to "beamNN", then replace the rest
    old_name = image_list[0]
    old_base = old_name.name
    b_idx = old_base.find("beam") + len("beam") + 2
    new_base = old_base[:b_idx] + ".conv.fits"
    new_base = new_base.replace("image", f"image.{image_type}.{pol.lower()}")
    new_name = old_name.parent / new_base

    # First combine images into cubes
    _ = combine_fits(
        file_list=image_list,
        out_cube=new_name,
        create_blanks=True,
        overwrite=True,
    )

    with fits.open(new_name, mode="denywrite", memmap=True) as hdu_list:
        new_header = hdu_list[0].header
        data_cube = hdu_list[0].data

    # Add pol angle to header
    new_header["INSTRUMENT_RECEPTOR_ANGLE"] = (
        pol_angle_deg,
        "Orig. pol. axis rotation angle in degrees",
    )

    tmp_header = new_header.copy()
    # Need to swap NAXIS 3 and 4 to make LINMOS happy - booo
    for a, b in ((3, 4), (4, 3)):
        new_header[f"CTYPE{a}"] = tmp_header[f"CTYPE{b}"]
        new_header[f"CRPIX{a}"] = tmp_header[f"CRPIX{b}"]
        new_header[f"CRVAL{a}"] = tmp_header[f"CRVAL{b}"]
        new_header[f"CDELT{a}"] = tmp_header[f"CDELT{b}"]
        new_header[f"CUNIT{a}"] = tmp_header[f"CUNIT{b}"]

    # Cube is currently STOKES, FREQ, RA, DEC - needs to be FREQ, STOKES, RA, DEC
    data_cube = np.moveaxis(data_cube, 1, 0)

    # Calculate rms noise
    rmss_arr = mad_std(data_cube, axis=(1, 2, 3), ignore_nan=True)

    # Deserialise beam
    with open(common_beam_pkl, "rb") as f:
        common_beam = pickle.load(f)
    new_header = common_beam.attach_to_header(new_header)
    fits.writeto(new_name, data_cube, new_header, overwrite=True)
    logger.info(f"Written {new_name}")

    # Write out weights
    # Must be of the format:
    # #Channel Weight
    # 0 1234.5
    # 1 6789.0
    # etc.
    new_w_name = Path(
        new_name.as_posix().replace(f"image.{image_type}", f"weights.{image_type}")
    ).with_suffix(".txt")
    # Inverse-variance weights. Channels whose rms is zero or NaN (e.g. blank
    # channels inserted by `create_blanks=True`) get weight 0 instead of inf/NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        weights = 1 / np.asarray(rmss_arr) ** 2
    weights = np.where(np.isfinite(weights), weights, 0.0)
    data = dict(
        Channel=np.arange(len(rmss_arr)),
        Weight=weights,
    )

    tab = Table(data)
    tab.write(new_w_name, format="ascii.commented_header", overwrite=True)

    return new_name, new_w_name
@task(name="Get Beam", persist_result=True)
def get_beam(image_set: ImageSet, cutoff: float | None) -> Path:
    """Derive the common resolution of all images in an ImageSet and pickle it.

    Args:
        image_set (ImageSet): ImageSet whose ``image_lists`` (all
            polarisations) are used to derive the common beam.
        cutoff (float | None): Largest restoring-beam major axis allowed when
            searching for the lowest common beam; images with a larger major
            axis are ignored. Passed straight to ``get_common_beam``.

    Returns:
        Path: Path to the pickled beam object, ``beam_<md5>.pkl`` in the
        current working directory, hashed over the sorted image names.

    Raises:
        ValueError: If the major or minor axis of the common beam is NaN.
    """
    logger = get_run_logger()

    # Flatten every polarisation; sort so the hash is stable between runs
    image_list = sorted(
        image
        for pol_images in image_set.image_lists.values()
        for image in pol_images
    )
    logger.info(f"The length of the image list is: {len(image_list)}")

    # Create a unique hash for the beam log filename
    image_hash = hashlib.md5("".join(image_list).encode()).hexdigest()

    try:
        common_beam, _ = beamcon_2D.get_common_beam(files=image_list, cutoff=cutoff)
    except Exception as exc:
        import sys
        import traceback

        tb_exc = traceback.TracebackException.from_exception(exc)
        logger.error(f"Local {''.join(tb_exc.format())}")
        # Extend the traceback with the calling frames for a full picture
        frame = sys.exc_info()[2].tb_frame.f_back
        while frame is not None:
            tb_exc.stack.append(
                traceback.FrameSummary(
                    frame.f_code.co_filename, frame.f_lineno, frame.f_code.co_name
                )
            )
            frame = frame.f_back
        logger.error(f"Full {''.join(tb_exc.format())}")
        raise exc

    logger.info(f"The common beam is: {common_beam=}")
    if np.isnan(common_beam.major) or np.isnan(common_beam.minor):
        raise ValueError("Common beam is NaN, consider raising the cutoff.")

    # serialise the beam
    common_beam_pkl = Path(f"beam_{image_hash}.pkl")
    with open(common_beam_pkl, "wb") as pkl_file:
        logger.info(f"Creating {common_beam_pkl}")
        pickle.dump(common_beam, pkl_file)

    return common_beam_pkl
@task(name="Smooth ImageSet", persist_result=True)
def smooth_imageset(
    image_set: ImageSet,
    common_beam_pkl: Path,
    cutoff: float | None = None,
    aux_mode: str | None = None,
) -> ImageSet:
    """Smooth all images described within an ImageSet to a common resolution.

    Each image is convolved in a worker thread by
    ``beamcon_2D.beamcon_2d_on_fits`` (``outdir=None``, ``suffix="conv"``);
    the returned lists assume outputs named like the input with ".fits"
    replaced by ".conv.fits".

    Args:
        image_set (ImageSet): Container whose images will be convolved to
            common resolution.
        common_beam_pkl (Path): Location of pickle file with beam description.
        cutoff (float | None, optional): PSF cutoff passed to the beamcon_2D
            worker. Defaults to None.
        aux_mode (str | None, optional): The image type in the ``aux_lists``
            property of ``image_set`` that contains the images to smooth. If
            not set, the ``image_lists`` property is used. Defaults to None.

    Returns:
        ImageSet: A copy of ``image_set`` pointing to the smoothed images. Note
        the ``aux_lists`` property is not carried forward.

    Raises:
        ValueError: If ``aux_mode`` is set but ``image_set.aux_lists`` is None.
        Exception: The error raised by a failing smoothing job, if any.
    """
    # Smooth image
    logger = get_run_logger()

    # Deserialise the beam
    with open(common_beam_pkl, "rb") as f:
        logger.info(f"Loading common beam from {common_beam_pkl}")
        common_beam = pickle.load(f)
    logger.info(f"{common_beam=}")

    logger.info(f"Smoothing {image_set.ms} images")

    images_to_smooth: dict[str, list[str]]
    if aux_mode is None:
        images_to_smooth = image_set.image_lists
    else:
        logger.info(f"Extracting images for {aux_mode=}.")
        # Explicit check rather than `assert`, which is stripped under `python -O`
        if image_set.aux_lists is None:
            raise ValueError(f"{image_set=} has empty aux_lists.")
        images_to_smooth = {
            pol: images
            for (pol, img_type), images in image_set.aux_lists.items()
            if aux_mode == img_type
        }

    sm_images: dict[str, list[str]] = {}
    futures = []
    with ThreadPoolExecutor() as executor:
        for pol, pol_images in images_to_smooth.items():
            logger.info(f"Smoothing {pol=} for {image_set.ms}")
            for img in pol_images:
                logger.info(f"Smoothing {img}")
                futures.append(
                    executor.submit(
                        beamcon_2D.beamcon_2d_on_fits,
                        file=Path(img),
                        outdir=None,
                        new_beam=common_beam,
                        conv_mode="robust",
                        suffix="conv",
                        cutoff=cutoff,
                    )
                )
            sm_images[pol] = [
                image.replace(".fits", ".conv.fits") for image in pol_images
            ]

    # The executor has waited for every job on exit; check all of them so that
    # a failure in any job (not just the last one submitted) is raised.
    for future in futures:
        future.result()

    return ImageSet(
        ms=image_set.ms,
        prefix=image_set.prefix,
        image_lists=sm_images,
    )
@task(name="Cleanup")
def cleanup(
    purge: bool, image_sets: list[ImageSet], ignore_files: list[Any] | None = None
) -> None:
    """Remove all images described by a collection of ImageSets.

    Each ImageSet is handed to ``cleanup_imageset``.

    Args:
        purge (bool): Whether files are actually removed; when False only a
            message is logged.
        image_sets (list[ImageSet]): Collection of ImageSets whose files are removed.
        ignore_files (list[Any] | None, optional): Not used by the function
            body. It only exists so that upstream results can be passed in as a
            dependency (originally for dask dependency tracking). Defaults to None.
    """
    run_logger = get_run_logger()
    run_logger.warning(f"Ignoring files in {ignore_files=}. ")
    if purge:
        for image_set in image_sets:
            cleanup_imageset(purge=purge, image_set=image_set)
    else:
        run_logger.info("Not purging intermediate files")
@task(name="Fix MeasurementSet Directions")
def fix_ms(ms: Path) -> Path:
    """Apply the FEED-table corrections required for ASKAP measurement sets.

    The work is done by ``fix_ms_dir``. The same path is returned, which
    presumably means the MS is corrected in place (to be confirmed against
    ``fix_ms_dir``).

    Args:
        ms (Path): Path to the measurement set to fix; it must exist.

    Returns:
        Path: ``ms``, unchanged.
    """
    resolved_ms = ms.resolve(strict=True)
    fix_ms_dir(resolved_ms.as_posix())
    return ms
@task(name="Fix MeasurementSet Correlations")
def fix_ms_askap_corrs(ms: Path, *args, **kwargs) -> Path:
    """Rotate raw ASKAP polarisation products to the form wsclean expects.

    This is related to ASKAP's third axis and the reorientation of its 'X'
    and 'Y' feeds.

    Args:
        ms (Path): Path to the measurement set to be corrected.
        *args: Extra positional arguments forwarded to ``fix_ms_corrs``
            after ``ms``.
        **kwargs: Extra keyword arguments forwarded to ``fix_ms_corrs``.

    Returns:
        Path: Path of the measurement set containing the corrections.
    """
    logger = get_run_logger()
    logger.info(f"Correcting {str(ms)} correlations for wsclean. ")
    if args:
        # `ms=ms` combined with `*args` binds the first extra positional
        # argument to `ms` as well and raises a TypeError, so pass `ms`
        # positionally (assumes `ms` is the first parameter of fix_ms_corrs).
        fix_ms_corrs(ms, *args, **kwargs)
    else:
        fix_ms_corrs(ms=ms, **kwargs)
    return ms
@flow(name="Imager")
def main(
    msdir: Path,
    out_dir: Path,
    num_beams: int = 36,
    temp_dir_images: Path | None = None,
    temp_dir_wsclean: Path | None = None,
    cutoff: float | None = None,
    robust: float = -0.5,
    pols: str = "IQU",
    nchan: int = 36,
    size: int = 6074,
    scale: float = 2.5,
    mgain: float = 0.8,
    niter: int = 100_000,
    auto_mask: float = 3,
    force_mask_rounds: int | None = None,
    auto_threshold: float = 1,
    taper: float | None = None,
    purge: bool = False,
    minuv: float = 0.0,
    parallel_deconvolution: int | None = None,
    gridder: str | None = None,
    nmiter: int | None = None,
    local_rms: bool = False,
    local_rms_window: float | None = None,
    wsclean_path: Path | str = "docker://alecthomson/wsclean:latest",
    multiscale: bool | None = None,
    multiscale_scale_bias: float | None = None,
    multiscale_scales: str | None = "0,2,4,8,16,32,64,128",
    absmem: float | None = None,
    make_residual_cubes: bool | None = False,
    ms_glob_pattern: str = "scienceData*_averaged_cal.leakage.ms",
    data_column: str = "CORRECTED_DATA",
    skip_fix_ms: bool = False,
    no_mf_weighting: bool = False,
    disable_pol_local_rms: bool = False,
    disable_pol_force_mask_rounds: bool = False,
):
    """Arrakis imager flow

    Args:
        msdir (Path): Path to the directory containing the MS files.
        out_dir (Path): Path to the directory where the images will be written.
        num_beams (int, optional): Number of beams to image. Defaults to 36.
        temp_dir_images (Optional[Path], optional): Path for temporary image files
            to be written. If None, ``out_dir`` is used and purging is disabled.
            Defaults to None.
        temp_dir_wsclean (Optional[Path], optional): Path for temporary files to be
            written by WSClean. If None, ``out_dir`` is used. Defaults to None.
        cutoff (Optional[float], optional): PSF cutoff used when deriving the
            common beam and smoothing. Defaults to None.
        robust (float, optional): WSClean Briggs robust parameter. Defaults to -0.5.
        pols (str, optional): WSClean polarisations. Defaults to "IQU".
        nchan (int, optional): WSClean number of output channels. Defaults to 36.
        size (int, optional): WSClean image size. Defaults to 6074.
        scale (float, optional): WSClean pixel size (arcseconds). Defaults to 2.5.
        mgain (float, optional): WSClean mgain. Defaults to 0.8.
        niter (int, optional): WSClean niter. Defaults to 100_000.
        auto_mask (float, optional): WSClean automatic masking (in SNR). Defaults to 3.
        force_mask_rounds (Union[int, None], optional): WSClean force mask rounds
            (requires modified WSClean). Defaults to None.
        auto_threshold (float, optional): WSClean auto threshold (in SNR). Defaults to 1.
        taper (Union[float, None], optional): WSClean taper (in arcsec). Defaults to None.
        purge (bool, optional): Purge auxiliary files after imaging. Defaults to False.
        minuv (float, optional): WSClean minuv-l. Defaults to 0.0.
        parallel_deconvolution (Optional[int], optional): WSClean parallel
            deconvolution. Defaults to None.
        gridder (Optional[str], optional): WSClean gridder. Defaults to None.
        nmiter (Optional[int], optional): WSClean nmiter. Defaults to None.
        local_rms (bool, optional): WSClean local_rms. Defaults to False.
        local_rms_window (Optional[float], optional): WSClean local_rms_window.
            Defaults to None.
        wsclean_path (Path | str, optional): Path or URL for WSClean container.
            Defaults to "docker://alecthomson/wsclean:latest".
        multiscale (Optional[bool], optional): WSClean multiscale. Defaults to None.
        multiscale_scale_bias (Optional[float], optional): WSClean multiscale bias.
            Defaults to None.
        multiscale_scales (Optional[str], optional): WSClean scales.
            Defaults to "0,2,4,8,16,32,64,128".
        absmem (Optional[float], optional): WSClean absmem usage. Defaults to None.
        make_residual_cubes (Optional[bool], optional): Make residual image cubes.
            Defaults to False.
        ms_glob_pattern (str, optional): Glob pattern for MS files.
            Defaults to "scienceData*_averaged_cal.leakage.ms".
        data_column (str, optional): Data column to image. Defaults to "CORRECTED_DATA".
        skip_fix_ms (bool, optional): Skip applying the FixMS direction and
            correlation corrections. Defaults to False.
        no_mf_weighting (bool, optional): WSClean no_mf_weighting. Defaults to False.
        disable_pol_local_rms (bool, optional): Disable local RMS for polarisation
            images. Defaults to False.
        disable_pol_force_mask_rounds (bool, optional): Disable force mask rounds for
            polarisation images. Defaults to False.

    Raises:
        ValueError: If no MS files are found, or the number found differs from
            ``num_beams``.
    """
    simage = get_wsclean(wsclean=wsclean_path)
    # Resolve the container image once; every imaging round uses the same path.
    simage_resolved = simage.resolve(strict=True)

    logger.info(f"Searching {msdir} for MS matching {ms_glob_pattern}.")
    mslist = sorted(msdir.glob(ms_glob_pattern))
    # An explicit raise (rather than ``assert``) keeps this check active when
    # Python runs with optimisations enabled (``-O``).
    if len(mslist) == 0 or len(mslist) != num_beams:
        raise ValueError(
            f"Incorrect number of MS files found: {len(mslist)} / {num_beams} - glob pattern: {ms_glob_pattern}"
        )

    logger.info(f"Will image {len(mslist)} MS files in {msdir} to {out_dir}")
    cleans = []

    if temp_dir_wsclean is None:
        temp_dir_wsclean = out_dir
    logger.info(f"Using {temp_dir_wsclean} as temp directory for WSClean")
    if temp_dir_images is None:
        # Don't allow purge if temp_dir_images is None
        logger.warning("No temp directory for images specified, disabling purge.")
        purge = False
        temp_dir_images = out_dir
    logger.info(f"Using {temp_dir_images} as temp directory for images")

    # Do this in serial since CASA gets upset
    prefixs = {}
    field_idxs = {}
    for ms in tqdm(mslist, "Getting metadata", file=TQDM_OUT):
        prefixs[ms] = get_prefix(ms, out_dir)
        field_idxs[ms] = field_idx_from_ms(ms.resolve(strict=True).as_posix())

    cube_aux_modes = (None, "residual") if make_residual_cubes else (None,)

    # Optionally apply the FixMS corrections, then read the polarisation axis
    # angle from the appropriate receptor-angle column.
    ms_list_fixed = []
    pol_angles = []
    for ms in mslist:
        logger.info(f"Imaging {ms}")
        # Apply Emil's fix for MSs feed centre
        if not skip_fix_ms:
            ms_fix = fix_ms.submit(ms)
            ms_fix = fix_ms_askap_corrs.submit(
                ms=ms_fix, data_column="DATA", corrected_data_column=data_column
            )
            pol_angle_deg = get_pol_axis_task.submit(
                ms_fix, col="INSTRUMENT_RECEPTOR_ANGLE"
            )
        else:
            ms_fix = ms
            pol_angle_deg = get_pol_axis_task.submit(ms_fix, col="RECEPTOR_ANGLE")
        ms_list_fixed.append(ms_fix)
        pol_angles.append(pol_angle_deg)

    # WSClean options shared by the Stokes I and the polarisation imaging rounds.
    common_wsclean_kwargs = {
        "out_dir": out_dir,
        "temp_dir_wsclean": temp_dir_wsclean,
        "temp_dir_images": temp_dir_images,
        "simage": simage_resolved,
        "robust": robust,
        "nchan": nchan,
        "scale": scale,
        "npix": size,
        "mgain": mgain,
        "niter": niter,
        "auto_mask": auto_mask,
        "force_mask_rounds": force_mask_rounds,
        "auto_threshold": auto_threshold,
        "taper": taper,
        "minuv_l": minuv,
        "parallel_deconvolution": parallel_deconvolution,
        "gridder": gridder,
        "nmiter": nmiter,
        "local_rms": local_rms,
        "local_rms_window": local_rms_window,
        "multiscale": multiscale,
        "multiscale_scale_bias": multiscale_scale_bias,
        "multiscale_scales": multiscale_scales,
        "absmem": absmem,
        "data_column": data_column,
        "no_mf_weighting": no_mf_weighting,
        "disable_pol_local_rms": disable_pol_local_rms,
        "disable_pol_force_mask_rounds": disable_pol_force_mask_rounds,
    }
    # There is no 'I' in polarisation...
    # NOTE(review): when pols == "I" this is empty, yet a polarisation imaging
    # round is still submitted below — confirm image_beam tolerates pols="".
    pol_only = pols.replace("I", "")

    for ms, ms_fix, pol_angle_deg in zip(mslist, ms_list_fixed, pol_angles):
        beam_kwargs = {
            "ms": ms_fix,
            "field_idx": field_idxs[ms],
            "prefix": prefixs[ms],
            **common_wsclean_kwargs,
        }
        # Image with wsclean, splitting out Stokes I and QUV
        if "I" in pols:
            image_set_I = image_beam.submit(
                pols="I",
                join_polarizations=False,  # Only do I
                squared_channel_joining=False,  # Dont want to square I
                **beam_kwargs,
            )
        else:
            image_set_I = None

        image_set_pol = image_beam.submit(
            pols=pol_only,
            # NOTE(review): this counts the full ``pols`` (including "I"), so
            # e.g. pols="IQ" requests joining for a single polarisation.
            join_polarizations=len(pols) > 1,
            squared_channel_joining=True,
            **beam_kwargs,
        )
        image_set = merge_imagesets.submit([image_set_I, image_set_pol])

        make_validation_plots.submit(
            prefix=prefixs[ms],
            pols=pols,
            wait_for=[image_set],
        )

        # Compute the smallest beam that all images can be convolved to.
        # This requires all imaging rounds to be completed, so the total
        # set of ImageSets are first derived before this is called.
        common_beam_pkl = get_beam.submit(
            image_set=image_set,
            cutoff=cutoff,
        )

        # With the final beam each *image* in the ImageSet across IQU are
        # smoothed and then form the cube for each stokes.
        # Per loop containers since we are iterating over image modes
        clean_sm_image_sets = []
        for aux_mode in cube_aux_modes:
            # Smooth the *images* in an ImageSet across all Stokes. This
            # limits the number of workers to 36, i.e. this is operating
            # beamwise
            sm_image_set = smooth_imageset.submit(
                image_set,
                common_beam_pkl=common_beam_pkl,
                cutoff=cutoff,
                aux_mode=aux_mode,
            )

            # Make a cube. This is operating across beams and stokes
            cube_images = [
                make_cube.submit(
                    pol=pol,
                    image_set=sm_image_set,
                    common_beam_pkl=common_beam_pkl,
                    pol_angle_deg=pol_angle_deg,
                    aux_mode=aux_mode,
                    wait_for=[sm_image_set],
                )
                for pol in pols
            ]

            # Clean up smoothed image files only once every cube built from
            # them has finished (expressed through ``wait_for``).
            clean = cleanup.submit(
                purge=purge,
                image_sets=[sm_image_set],
                wait_for=cube_images,
            )
            clean_sm_image_sets.append(clean)

        # Now clean the original output images from wsclean
        clean = cleanup.submit(
            purge=purge,
            image_sets=[image_set],
            wait_for=clean_sm_image_sets,
        )
        cleans.append(clean)

    # ``submit`` only schedules work; block on each beam's final cleanup so the
    # completion message below is logged once the work has actually run.
    for clean in cleans:
        clean.wait()
    logger.info("Imager finished!")
def imager_parser(parent_parser: bool = False) -> argparse.ArgumentParser:
    """Return the argument parser for the imager routine.

    Args:
        parent_parser (bool, optional): Ensure the parser is configured so it can
            be added as a parent to a new parser. This disables the -h/--help
            action from being generated. Defaults to False.

    Returns:
        argparse.ArgumentParser: Arguments required for the imager routine
    """
    # Help string to be shown using the -h option
    descStr = f"""
    {logo_str}
    {__doc__}
    """

    # Parse the command line options
    img_parser = argparse.ArgumentParser(
        add_help=not parent_parser,
        description=descStr,
        formatter_class=UltimateHelpFormatter,
    )
    parser = img_parser.add_argument_group("imaging arguments")

    parser.add_argument(
        "msdir",
        type=Path,
        help="Directory containing MS files",
    )
    parser.add_argument(
        "--temp_dir_wsclean",
        type=Path,
        help="Temporary directory for WSClean to store intermediate files",
    )
    parser.add_argument(
        "--temp_dir_images",
        type=Path,
        help="Temporary directory to store intermediate image files",
    )
    parser.add_argument(
        "--psf_cutoff",
        type=float,
        help="Cutoff for smoothing in units of arcseconds. ",
    )
    parser.add_argument(
        "--robust",
        type=float,
        default=-0.5,
        help="WSClean Briggs robust parameter",
    )
    parser.add_argument(
        "--nchan",
        type=int,
        default=36,
        help="Number of output channels",
    )
    parser.add_argument(
        "--pols",
        type=str,
        default="IQU",
        help="Polarisations to image",
    )
    # NOTE(review): main() documents/defaults size=6074, but the CLI default is 4096.
    parser.add_argument(
        "--size",
        type=int,
        default=4096,
        help="Image size in pixels",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=2.5,
        help="Pixel size in arcseconds",
    )
    parser.add_argument(
        "--mgain",
        type=float,
        default=0.8,
        help="WSClean mgain",
    )
    parser.add_argument(
        "--niter",
        type=int,
        default=100_000,
        help="WSClean niter",
    )
    parser.add_argument(
        "--nmiter",
        type=int,
        default=None,
        help="WSClean nmiter",
    )
    parser.add_argument(
        "--auto_mask",
        type=float,
        default=3.0,
        help="WSClean automatic masking (in SNR)",
    )
    parser.add_argument(
        "--auto_threshold",
        type=float,
        default=1.0,
        help="WSClean auto threshold (in SNR)",
    )
    parser.add_argument(
        "--local_rms",
        action="store_true",
        help="Use WSClean local RMS",
    )
    parser.add_argument(
        "--local_rms_window",
        type=float,
        default=None,
        help="WSClean local RMS window",
    )
    parser.add_argument(
        "--force_mask_rounds",
        type=int,
        default=None,
        help="WSClean force mask rounds (requires modified WSClean)",
    )
    parser.add_argument(
        "--gridder",
        type=str,
        default=None,
        choices=["direct-ft", "idg", "wgridder", "tuned-wgridder", "wstacking"],
        help="WSClean gridder",
    )
    parser.add_argument(
        "--taper",
        type=float,
        default=None,
        help="WSClean taper in arcseconds",
    )
    parser.add_argument(
        "--minuv",
        type=float,
        default=0.0,
        help="WSClean minuv-l",
    )
    parser.add_argument(
        "--parallel",
        type=int,
        default=None,
        help="WSClean parallel deconvolution",
    )
    parser.add_argument(
        "--purge",
        action="store_true",
        help="Purge intermediate files",
    )
    parser.add_argument(
        "--mpi",
        action="store_true",
        help="Use MPI",
    )
    parser.add_argument(
        "--multiscale",
        action="store_true",
        help="Use multiscale clean",
    )
    parser.add_argument(
        "--multiscale_scale_bias",
        type=float,
        default=None,
        help="The multiscale scale bias term provided to wsclean. ",
    )
    parser.add_argument(
        "--multiscale_scales",
        type=str,
        default="0,2,4,8,16,32,64,128",
        help="The scales used in the multiscale clean. ",
    )
    parser.add_argument(
        "--absmem",
        type=float,
        default=None,
        help="Absolute memory limit in GB",
    )
    parser.add_argument(
        "--make_residual_cubes",
        action="store_true",
        help="Create residual cubes as well as cubes from restored images. ",
    )
    parser.add_argument(
        "--ms_glob_pattern",
        type=str,
        default="scienceData*_averaged_cal.leakage.ms",
        help="The pattern used to search for measurement sets. ",
    )
    parser.add_argument(
        "--data_column",
        type=str,
        default="CORRECTED_DATA",
        help="Which column in the measurement set to image. ",
    )
    parser.add_argument(
        "--no_mf_weighting",
        action="store_true",
        help="Do not use multi-frequency weighting. ",
    )
    parser.add_argument(
        "--skip_fix_ms",
        action="store_true",
        default=False,
        help="Do not apply the ASKAP MS corrections from the package fixms. ",
    )
    parser.add_argument(
        "--num_beams",
        type=int,
        help="Number of beams to image",
        default=36,
    )
    parser.add_argument(
        "--disable_pol_local_rms",
        action="store_true",
        help="Disable local RMS for polarisation images",
    )
    parser.add_argument(
        "--disable_pol_force_mask_rounds",
        action="store_true",
        help="Disable force mask rounds for polarisation images",
    )

    # Create this group on the parser itself: calling add_argument_group() on
    # another argument group is deprecated since Python 3.11, and the nested
    # group is not registered in the parser's own list of groups.
    group = img_parser.add_argument_group("wsclean container options")
    mxg = group.add_mutually_exclusive_group()
    mxg.add_argument(
        "--hosted-wsclean",
        type=str,
        default="docker://alecthomson/wsclean:latest",
        help="Docker or Singularity image for wsclean",
    )
    mxg.add_argument(
        "--local_wsclean",
        type=Path,
        default=None,
        help="Path to local wsclean Singularity image",
    )

    return img_parser
def cli():
    """Command-line interface: parse imager and workdir arguments, then run `main`."""
    im_parser = imager_parser(parent_parser=True)
    work_parser = workdir_arg_parser(parent_parser=True)

    parser = argparse.ArgumentParser(
        parents=[im_parser, work_parser],
        formatter_class=UltimateHelpFormatter,
        description=im_parser.description,
    )
    args = parser.parse_args()

    main(
        msdir=args.msdir,
        out_dir=args.datadir,
        num_beams=args.num_beams,
        temp_dir_wsclean=args.temp_dir_wsclean,
        temp_dir_images=args.temp_dir_images,
        cutoff=args.psf_cutoff,
        robust=args.robust,
        pols=args.pols,
        nchan=args.nchan,
        size=args.size,
        scale=args.scale,
        mgain=args.mgain,
        niter=args.niter,
        nmiter=args.nmiter,
        local_rms=args.local_rms,
        local_rms_window=args.local_rms_window,
        auto_mask=args.auto_mask,
        force_mask_rounds=args.force_mask_rounds,
        auto_threshold=args.auto_threshold,
        minuv=args.minuv,
        purge=args.purge,
        taper=args.taper,
        parallel_deconvolution=args.parallel,
        gridder=args.gridder,
        wsclean_path=(
            Path(args.local_wsclean) if args.local_wsclean else args.hosted_wsclean
        ),
        multiscale=args.multiscale,
        multiscale_scale_bias=args.multiscale_scale_bias,
        multiscale_scales=args.multiscale_scales,
        # These two options are defined by imager_parser and were previously
        # parsed but never forwarded, so the flags had no effect.
        absmem=args.absmem,
        make_residual_cubes=args.make_residual_cubes,
        ms_glob_pattern=args.ms_glob_pattern,
        data_column=args.data_column,
        skip_fix_ms=args.skip_fix_ms,
        no_mf_weighting=args.no_mf_weighting,
        disable_pol_local_rms=args.disable_pol_local_rms,
        disable_pol_force_mask_rounds=args.disable_pol_force_mask_rounds,
    )
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    cli()