#!/usr/bin/env python
"""Pipeline and flow utility functions"""
from __future__ import annotations
import argparse
import base64
import logging
import shlex
import subprocess
import time
import warnings
from pathlib import Path
from uuid import UUID
import astropy.units as u
import dask.array as da
import dask.distributed as distributed
import numpy as np
from astropy.utils.exceptions import AstropyWarning
from dask.delayed import Delayed
from dask.distributed import get_client
from distributed.client import futures_of
from distributed.diagnostics.progressbar import ProgressBar
from distributed.utils import LoopRunner
from prefect import Task, task
from prefect.artifacts import create_markdown_artifact
from prefect.concurrency.sync import rate_limit
from prefect.futures import PrefectFuture
from prefect_dask import get_dask_client
from spectral_cube.utils import SpectralCubeWarning
from tornado.ioloop import IOLoop
from tqdm.auto import tqdm, trange
from arrakis.logger import TqdmToLogger, UltimateHelpFormatter, logger
# Silence noisy spectral-cube and astropy warnings for the whole pipeline.
warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True)
warnings.simplefilter("ignore", category=AstropyWarning)

# Image suffixes (without the leading dot) accepted by
# ``upload_image_as_artifact_task``.
SUPPORTED_IMAGE_TYPES = ("png",)

# File-like object handed to tqdm via ``file=``; presumably forwards progress
# output to ``logger`` at INFO level (see ``arrakis.logger.TqdmToLogger``).
TQDM_OUT = TqdmToLogger(logger, level=logging.INFO)

# Help string to be shown using the -h option
logo_str = """
mmm mmm mmm mmm mmm
)-( )-( )-( )-( )-(
( S ) ( P ) ( I ) ( C ) ( E )
| | | | | | | | | |
|___| |___| |___| |___| |___|
mmm mmm mmm mmm
)-( )-( )-( )-(
( R ) ( A ) ( C ) ( S )
| | | | | | | |
|___| |___| |___| |___|
"""
def submit_task_with_rate_limit(task: Task, *args, **kwargs) -> PrefectFuture:
    """Submit a Prefect task after acquiring a rate-limit slot.

    ``rate_limit("settle-right-down", occupy=1)`` is called first, so this
    function may block before the task is submitted.

    Args:
        task (Task): Task to submit
        *args: Arguments to pass to the task
        **kwargs: Keyword arguments to pass to the task

    Returns:
        PrefectFuture: Future object returned by ``task.submit``
    """
    # NOTE(review): "settle-right-down" is presumably a Prefect global
    # concurrency limit that must already exist on the server — confirm.
    rate_limit("settle-right-down", occupy=1)
    return task.submit(*args, **kwargs)
# Stolen from Flint
@task(name="Upload image as artifact")
def upload_image_as_artifact_task(
    image_path: Path, description: str | None = None
) -> UUID:
    """Create and submit a markdown artifact tracked by prefect for an
    input image. Currently supporting png formatted images.

    The input image is converted to a base64 encoding and embedded directly
    within the markdown string as a ``data:`` URI. Therefore, be mindful of
    the image size as this is tracked in the postgres database.

    Args:
        image_path (Path): Path to the image to upload
        description (str | None, optional): A description passed to the
            markdown artifact. Defaults to None.

    Returns:
        UUID: Generated UUID of the registered artifact

    Raises:
        AssertionError: If ``image_path`` does not exist or its suffix is not
            in ``SUPPORTED_IMAGE_TYPES``.
    """
    # Suffix without the leading dot, lower-cased so ``.PNG`` is accepted too.
    image_type = image_path.suffix.lstrip(".").lower()
    assert image_path.exists(), f"{image_path} does not exist"
    assert image_type in SUPPORTED_IMAGE_TYPES, (
        f"{image_path} has type {image_type}, and is not supported. Supported types are {SUPPORTED_IMAGE_TYPES}"
    )

    with image_path.open("rb") as open_image:
        logger.info(f"Encoding {image_path} in base64")
        image_base64 = base64.b64encode(open_image.read()).decode()

    logger.info("Creating markdown tag")
    # Embed the image inline as a data URI so the artifact is self-contained;
    # previously this was an empty string and the encoded image was discarded.
    markdown = f"![{image_path.stem}](data:image/{image_type};base64,{image_base64})"

    logger.info("Registering artifact")
    image_uuid: UUID = create_markdown_artifact(
        markdown=markdown, description=description
    )

    return image_uuid
def workdir_arg_parser(parent_parser: bool = False) -> argparse.ArgumentParser:
    """Build the argument parser for the working-directory option.

    Args:
        parent_parser (bool, optional): If True, ``-h/--help`` is not added,
            presumably so this parser can be reused via ``parents=[...]``
            without a help-option conflict. Defaults to False.

    Returns:
        argparse.ArgumentParser: Parser with a single positional ``datadir``
        argument, converted to ``Path``.
    """
    # Parse the command line options
    work_parser = argparse.ArgumentParser(
        add_help=not parent_parser,
        formatter_class=UltimateHelpFormatter,
    )
    group = work_parser.add_argument_group("workdir arguments")
    group.add_argument(
        "datadir",
        type=Path,
        help="Directory to create/find full-size images and 'cutout' directory",
    )
    return work_parser
def generic_parser(parent_parser: bool = False) -> argparse.ArgumentParser:
    """Build the parser for options shared by the pipeline stages.

    Args:
        parent_parser (bool, optional): If True, ``-h/--help`` is not added,
            presumably so this parser can be reused via ``parents=[...]``
            without a help-option conflict. Defaults to False.

    Returns:
        argparse.ArgumentParser: Parser with the positional ``field`` and the
        ``--sbid``, ``-s/--stokes``, ``-e/--epoch``, ``-v``, ``--host``,
        ``--username``, ``--password``, ``--limit`` and ``--database`` options.
    """
    # The ASCII logo heads the ``-h`` output.
    description = f"\n{logo_str}\nGeneric pipeline options\n"
    # Parse the command line options
    gen_parser = argparse.ArgumentParser(
        add_help=not parent_parser,
        description=description,
        formatter_class=UltimateHelpFormatter,
    )
    group = gen_parser.add_argument_group("generic arguments")
    group.add_argument(
        "field", metavar="field", type=str, help="Name of field (e.g. RACS_2132-50)."
    )
    group.add_argument(
        "--sbid",
        type=int,
        default=None,
        help="SBID of observation.",
    )
    group.add_argument(
        "-s",
        "--stokes",
        dest="stokeslist",
        nargs="+",
        type=str,
        default=["I", "Q", "U"],
        help="List of Stokes parameters to image",
    )
    group.add_argument(
        "-e",
        "--epoch",
        type=int,
        default=0,
        help="Epoch of observation.",
    )
    group.add_argument(
        "-v", dest="verbose", action="store_true", help="Verbose output."
    )
    group.add_argument(
        "--host",
        metavar="host",
        type=str,
        default=None,
        help="Host of mongodb (probably $hostname -i).",
    )
    group.add_argument(
        "--username", type=str, default=None, help="Username of mongodb."
    )
    # NOTE: a password given on the command line is visible in process listings.
    group.add_argument(
        "--password", type=str, default=None, help="Password of mongodb."
    )
    group.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Limit the number of islands to process.",
    )
    group.add_argument(
        "--database", dest="database", action="store_true", help="Add data to MongoDB."
    )
    return gen_parser
def inspect_client(
    client: distributed.Client | None = None,
) -> tuple[str, int, int, u.Quantity, int, u.Quantity]:
    """Inspect a Dask client and summarise the resources of its cluster.

    Args:
        client (distributed.Client | None, optional): Dask client to inspect.
            If None, the current client from ``get_client()`` is used.

    Returns:
        tuple[str, int, int, u.Quantity, int, u.Quantity]: ``(addr, nworkers,
        nthreads, memory, threads_per_worker, memory_per_worker)``.
        ``memory`` is the sum of the workers' ``memory_limit`` values in
        bytes. ``threads_per_worker`` is ``nthreads // nworkers`` (integer
        floor division), and ``memory_per_worker`` is ``memory / nworkers``.

    Raises:
        ZeroDivisionError: If the scheduler reports no workers.
    """
    if client is None:
        client = get_client()
    logger.debug(f"Client: {client}")

    # NOTE(review): ``_scheduler_identity`` is a private attribute of the
    # client; it is read directly rather than refreshed from the scheduler.
    info = client._scheduler_identity
    addr = info.get("address")
    workers = info.get("workers", {})
    nworkers = len(workers)
    nthreads = sum(w["nthreads"] for w in workers.values())
    memory = sum(w["memory_limit"] for w in workers.values()) * u.byte
    threads_per_worker = nthreads // nworkers
    memory_per_worker = memory / nworkers
    return addr, nworkers, nthreads, memory, threads_per_worker, memory_per_worker
def chunk_dask(
    outputs: list,
    batch_size: int = 10_000,
    task_name: str = "",
    progress_text: str = "",
    verbose: bool = True,
) -> list:
    """Persist a list of Dask objects on the current client in batches.

    Each batch of ``batch_size`` items is passed to ``client.persist``, and a
    tqdm progress bar is shown for that batch before the next one starts.

    Args:
        outputs (list): Dask collections / delayed objects to persist.
        batch_size (int, optional): Number of items per batch. Defaults to 10_000.
        task_name (str, optional): Label for the outer per-batch progress bar.
            Defaults to "".
        progress_text (str, optional): Description for each batch's progress
            bar. Defaults to "".
        verbose (bool, optional): Show progress bars if True. Defaults to True.

    Returns:
        list: The persisted objects of every batch, in input order.
    """
    client = get_client()
    chunk_outputs = []
    for i in trange(
        0, len(outputs), batch_size, desc=f"Chunking {task_name}", disable=(not verbose)
    ):
        outputs_chunk = outputs[i : i + batch_size]
        futures = client.persist(outputs_chunk)
        # HACK: fixed sleep after the first batch only, a workaround for
        # https://github.com/dask/distributed/issues/4831
        if i == 0:
            logger.debug("I sleep!")
            time.sleep(10)
            logger.debug("I awake!")
        tqdm_dask(futures, desc=progress_text, disable=(not verbose), file=TQDM_OUT)
        chunk_outputs.extend(futures)
    return chunk_outputs
def delayed_to_da(list_of_delayed: list[Delayed], chunk: int | None = None) -> da.Array:
    """Convert a list of delayed arrays to a single stacked dask array.

    The first element is computed eagerly to learn the shape and dtype. Every
    element is assumed to share them; this is not checked.

    Args:
        list_of_delayed (list[Delayed]): Delayed objects, each evaluating to an
            array with ``.shape`` and ``.dtype``.
        chunk (int | None, optional): Chunk size along the new leading axis.
            Defaults to None, meaning one chunk spanning the whole array.

    Returns:
        da.Array: Array of shape ``(len(list_of_delayed), *sample.shape)``.

    Raises:
        IndexError: If ``list_of_delayed`` is empty.
    """
    sample = list_of_delayed[0].compute()
    full_shape = (len(list_of_delayed),) + sample.shape
    chunks = full_shape if chunk is None else (chunk,) + sample.shape
    darray_list = [
        da.from_delayed(lazy, dtype=sample.dtype, shape=sample.shape)
        for lazy in list_of_delayed
    ]
    # Stacking along axis 0 already yields ``full_shape``; no reshape needed.
    return da.stack(darray_list, axis=0).rechunk(chunks)
# stolen from https://github.com/tqdm/tqdm/issues/278
class TqdmProgressBar(ProgressBar):
    """Tqdm for Dask.

    A ``distributed`` ``ProgressBar`` whose drawing hooks update a ``tqdm``
    bar instead of the default display.

    Args:
        keys: Futures/keys to track; also used as the tqdm iterable.
        scheduler: Forwarded to ``ProgressBar``.
        interval: Forwarded to ``ProgressBar``. Defaults to "100ms".
        loop: Event loop to run on; a new ``IOLoop`` is created if None.
        complete: Forwarded to ``ProgressBar``. Defaults to True.
        start: If True, run ``self.listen`` synchronously in the constructor.
        **tqdm_kwargs: Passed to ``tqdm`` (e.g. ``desc``, ``disable``, ``file``).
    """

    def __init__(
        self,
        keys,
        scheduler=None,
        interval="100ms",
        loop=None,
        complete=True,
        start=True,
        **tqdm_kwargs,
    ):
        super().__init__(keys, scheduler, interval, complete)
        self.tqdm = tqdm(keys, **tqdm_kwargs)
        self.loop = loop or IOLoop()
        if start:
            # ``run_sync`` drives ``listen`` on ``self.loop`` and, presumably,
            # blocks until it finishes — so construction waits for the keys.
            loop_runner = LoopRunner(self.loop)
            loop_runner.run_sync(self.listen)

    def _draw_bar(self, remaining, all, **kwargs):
        """Advance the tqdm bar so that ``all - remaining`` keys show as done."""
        # ``all`` shadows the builtin but is kept: the base class presumably
        # passes it by keyword — confirm before renaming.
        _ = kwargs
        update_ct = (all - remaining) - self.tqdm.n
        self.tqdm.update(update_ct)

    def _draw_stop(self, **kwargs):
        """Close the tqdm bar."""
        _ = kwargs
        self.tqdm.close()
def tqdm_dask(futures_in: distributed.Future, **kwargs) -> None:
    """Show a tqdm progress bar for Dask futures.

    Args:
        futures_in: A future, or any object accepted by
            ``distributed.client.futures_of``, whose futures are tracked.
        **kwargs: Forwarded to ``TqdmProgressBar``. Keywords it does not
            consume itself are passed on to ``tqdm``.
    """
    futures = futures_of(futures_in)
    # Defensive: make sure the progress bar always receives a collection.
    if not isinstance(futures, (set, list)):
        futures = [futures]
    TqdmProgressBar(futures, **kwargs)
def port_forward(port: int, target: str) -> None:
    """Open a reverse SSH tunnel for ``port`` to ``target``.

    Runs ``ssh -N -f -R {port}:localhost:{port} {target}``. Connections to
    ``port`` on ``target`` are forwarded back to ``port`` on this host. The
    subprocess is started without waiting for it to finish.

    Args:
        port (int): port to forward
        target (str): Target host
    """
    logger.info(f"Forwarding {port} from localhost to {target}")
    # Build argv directly (no string splitting) so ``target`` is always exactly
    # one argument and cannot smuggle in extra ssh options.
    command = ["ssh", "-N", "-f", "-R", f"{port}:localhost:{port}", f"{target}"]
    _ = subprocess.Popen(command)
def cpu_to_use(max_cpu: int, count: int) -> int:
    """Find number of cpus to use.

    Find the right number of cpus to use when dividing up a task, such
    that there are no remainders.

    Args:
        max_cpu (int): Maximum number of cores to use for a process.
        count (int): Number of tasks.

    Returns:
        int: The largest divisor of ``count`` that is no greater than
        ``max_cpu``.

    Raises:
        ValueError: If no such divisor exists (e.g. ``count < 1`` or
            ``max_cpu < 1``).
    """
    # No divisor of ``count`` exceeds ``count``, so cap the search there.
    upper = int(min(max_cpu, count))
    # Search downwards so the first divisor found is the largest one.
    for ncpu in range(upper, 0, -1):
        if count % ncpu == 0:
            return ncpu
    raise ValueError(
        f"No divisor of count={count} is <= max_cpu={max_cpu}; both must be >= 1"
    )