#!/usr/bin/env python3
from __future__ import annotations
import logging
import astropy
import astropy.units as units
import matplotlib.pyplot as plt
import numpy as np
from arrakis.logger import TqdmToLogger, logger
from arrakis.utils.database import get_db
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from tqdm.auto import tqdm, trange
[docs]
# File-like object handed to tqdm via ``file=`` in this module.
# NOTE(review): presumably forwards progress-bar output to ``logger`` at INFO
# level instead of writing to stderr - confirm against arrakis.logger.
TQDM_OUT = TqdmToLogger(logger, level=logging.INFO)
[docs]
def _plot_leakage_map(leakage_map, label):
    """Draw one fractional-leakage map on a new figure and return the figure."""
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    im = ax.imshow(
        leakage_map, origin="lower", vmin=-0.05, vmax=0.05, cmap="coolwarm"
    )
    fig.colorbar(im, label=label)
    ax.set_aspect("equal", "box")
    ax.invert_xaxis()
    return fig


def makesurf(
    start, stop, field, datadir, save_plots=True, data=None, wcs_file=None
):
    """Build gridded Q/I and U/I maps for a field from per-component spectra.

    For every component returned by the database query, the Stokes I, Q and U
    spectra are summed over the channel slice ``[start:stop]``. The ratios Q/I
    and U/I are then averaged with a trimmed mean inside a circular sliding
    aperture evaluated on a regular pixel grid.

    This function reads the module-level globals ``beams_col`` and
    ``comp_col``; they must be set (``main`` does this) before calling.

    Args:
        start: First channel index of the slice (passed to ``[start:stop]``).
        stop: End channel index of the slice (exclusive; ``None`` means "to
            the end").
        field: Field name. Used in the database query key ``beams.<field>``,
            in the default WCS file name and in the output plot names.
        datadir: Directory holding ``cutouts/<Source_ID>/<Gaussian_ID>.dat``.
        save_plots: If True, write ``q_leakage_<field>.png`` and
            ``u_leakage_<field>.png`` to the working directory. The figures
            are created either way and are left open.
        data: Optional list of spectra as returned by a previous call (the
            fourth return value). Entry ``i`` must belong to component ``i``
            of the database query; ``None`` entries are skipped. When given,
            nothing is read from ``datadir``.
        wcs_file: Optional path of the FITS file providing the WCS. Defaults
            to the original hard-coded mosaic path for ``field``.

    Returns:
        Tuple ``(freq, q_leakage_map, u_leakage_map, specs, wcs)`` where
        ``freq`` is the mean frequency over the slice and all used components,
        the two maps are 2-D float arrays (NaN where fewer than 5 sources fall
        in the aperture), ``specs`` is the list of loaded spectra aligned with
        the component list (``None`` where loading failed; empty when ``data``
        was supplied) and ``wcs`` is the WCS object.

    Raises:
        ValueError: If no component has finite Q/I and U/I.
    """
    from scipy import stats
    from scipy.spatial import distance_matrix

    # Components belonging to islands that have a beam entry for this field.
    query = {"$and": [{f"beams.{field}": {"$exists": True}}]}
    island_ids = sorted(beams_col.distinct("Source_ID", query))
    query = {"Source_ID": {"$in": island_ids}}
    components = list(comp_col.find(query).sort("Source_ID"))

    ras, decs, freqs, stokeis, stokeqs, stokeus = [], [], [], [], [], []
    specs = []
    for i, comp in enumerate(tqdm(components, file=TQDM_OUT)):
        if data is None:
            iname = comp["Source_ID"]
            cname = comp["Gaussian_ID"]
            spectra = f"{datadir}/cutouts/{iname}/{cname}.dat"
            try:
                freq, iarr, qarr, uarr, rmsi, rmsq, rmsu = np.loadtxt(spectra).T
            except Exception as e:
                logger.warning(f"Could not find '{spectra}': {e}")
                # Keep ``specs`` aligned with ``components`` so that a later
                # call with ``data=specs`` pairs each spectrum with the right
                # sky position.
                specs.append(None)
                continue
            specs.append([freq, iarr, qarr, uarr, rmsi, rmsq, rmsu])
        else:
            try:
                spectrum = data[i]
            except IndexError:
                continue
            if spectrum is None:
                # Placeholder for a component whose spectrum failed to load.
                continue
            freq, iarr, qarr, uarr, *_ = spectrum
        ras.append(comp["RA"])
        decs.append(comp["Dec"])
        freqs.append(np.nanmean(freq[start:stop]))
        stokeis.append(np.nansum(iarr[start:stop]))
        stokeqs.append(np.nansum(qarr[start:stop]))
        stokeus.append(np.nansum(uarr[start:stop]))

    ras = np.array(ras)
    decs = np.array(decs)
    stokeis = np.array(stokeis)
    stokeqs = np.array(stokeqs)
    stokeus = np.array(stokeus)
    mean_freq = np.nanmean(np.array(freqs))
    logger.debug("freq is %s", mean_freq)

    coords = SkyCoord(ras * units.deg, decs * units.deg)
    if wcs_file is None:
        wcs_file = f"/group/askap/athomson/projects/RACS/CI0_mosaic_1.0/RACS_test4_1.05_{field}.fits"
    wcs = WCS(wcs_file)
    x_raw, y_raw = wcs.celestial.world_to_pixel(coords)

    # Fractional polarisation per component.
    q_raw = stokeqs / stokeis  # Q/I
    u_raw = stokeus / stokeis  # U/I

    # Drop NaN *and* +/-inf (the latter arise when the summed I is zero).
    good_idxs = np.isfinite(q_raw) & np.isfinite(u_raw)
    x = x_raw[good_idxs]
    y = y_raw[good_idxs]
    q = q_raw[good_idxs]
    u = u_raw[good_idxs]
    if x.size == 0:
        raise ValueError(
            f"No components with finite Q/I and U/I for field '{field}'"
        )

    # Settings. x and y are pixel coordinates, so the aperture radius and the
    # grid spacing below are in pixels as well.
    # NOTE(review): assumes the WCS pixel scale is in degrees - confirm.
    pixelscales = astropy.wcs.utils.proj_plane_pixel_scales(wcs)
    dperpix = pixelscales[1]
    d = (1 / 60) * 15 / dperpix  # Radius of the circular sliding aperture
    trim_mean_frac = 0.2  # Fraction of points cut from each end of the dist
    # Minimum number of sources in the aperture for a non-NaN estimate
    min_data_points_in_aperture = 5
    grid_point_sep = d / 4

    # Positions of measured leakages
    pos_measurements = np.column_stack((x, y))
    # Positions of grid points to derive leakage estimates at
    xnew = np.arange(np.min(x), np.max(x) + grid_point_sep, grid_point_sep)
    ynew = np.arange(np.min(y), np.max(y) + grid_point_sep, grid_point_sep)
    logger.debug("grid size: %d x %d", len(xnew), len(ynew))
    # Flattened x-major: row index = ix * len(ynew) + iy
    pos_estimator_grid = np.stack(
        np.meshgrid(xnew, ynew, indexing="ij"), axis=-1
    ).reshape(-1, 2)

    # Calculate pair-wise distances between the two sets of coordinate pairs
    logger.info("\nDeriving pair-wise distance matrix...")
    pair_dist = distance_matrix(pos_estimator_grid, pos_measurements)
    logger.info("Done.\n")

    # Robust (trimmed-mean) estimate at each grid point; NaN where too few
    # sources fall inside the aperture.
    q_estimates_arr = np.full(len(pos_estimator_grid), np.nan)
    u_estimates_arr = np.full(len(pos_estimator_grid), np.nan)
    num_points_in_aperture_list = []
    logger.info("\nDeriving robust leakage estimates for interpolation grid...")
    for row_idx, row in enumerate(tqdm(pair_dist, file=TQDM_OUT)):
        # A boolean mask keeps the selections 1-D so trim_mean returns a scalar
        in_aperture = row < d
        n_in_aperture = int(np.count_nonzero(in_aperture))
        if n_in_aperture < min_data_points_in_aperture:
            continue
        q_estimates_arr[row_idx] = stats.trim_mean(q[in_aperture], trim_mean_frac)
        u_estimates_arr[row_idx] = stats.trim_mean(u[in_aperture], trim_mean_frac)
        num_points_in_aperture_list.append(n_in_aperture)

    if num_points_in_aperture_list:
        mean_points = np.mean(num_points_in_aperture_list)
    else:
        mean_points = np.nan
    logger.info(
        "The mean number of points in each aperture of %.2f pixels was %s",
        d,
        mean_points,
    )

    # Plot results
    grid_shape = (len(xnew), len(ynew))
    q_leakage_map = np.rot90(q_estimates_arr.reshape(grid_shape), k=3)
    q_fig = _plot_leakage_map(q_leakage_map, "Q/I")
    if save_plots:
        q_fig.savefig(f"q_leakage_{field}.png")

    u_leakage_map = np.rot90(u_estimates_arr.reshape(grid_shape), k=3)
    u_fig = _plot_leakage_map(u_leakage_map, "U/I")
    if save_plots:
        u_fig.savefig(f"u_leakage_{field}.png")

    return mean_freq, q_leakage_map, u_leakage_map, specs, wcs
[docs]
def main(field, datadir, username="admin", password=None, n_chan=288, n_bands=6):
    """Make full-band and per-sub-band Q/I and U/I leakage plots for a field.

    Connects to the database, runs ``makesurf`` once over all channels (which
    saves the full-band plots), then once per sub-band reusing the spectra
    loaded by the first call, and finally saves a ``2 x n_bands`` summary
    figure to ``<field>_leakages.png``.

    Sets the module-level globals ``beams_col``, ``island_col`` and
    ``comp_col`` from ``get_db``.

    Args:
        field: Field name, forwarded to ``makesurf`` and used in file names.
        datadir: Directory containing the cutouts, forwarded to ``makesurf``.
        username: Database username.
        password: Database password.
        n_chan: Number of spectral channels to divide into sub-bands.
        n_bands: Number of contiguous sub-bands to plot.

    Raises:
        ValueError: If ``n_bands`` is less than 1 or exceeds ``n_chan``.
    """
    global beams_col
    global island_col
    global comp_col

    if n_bands < 1 or n_chan < n_bands:
        raise ValueError(
            f"Need 1 <= n_bands <= n_chan, got n_bands={n_bands}, n_chan={n_chan}"
        )

    beams_col, island_col, comp_col = get_db(
        "146.118.68.63", username=username, password=password
    )

    # Full band. ``stop=None`` slices through the final channel; ``-1`` would
    # silently drop it.
    _, _, _, data, _ = makesurf(0, None, field, datadir, save_plots=True)

    # Contiguous, non-overlapping sub-bands. Slice ends are exclusive, so each
    # band stops exactly where the next one starts; the last band absorbs any
    # remainder channels.
    band_width = n_chan // n_bands
    f_big, q_big, u_big = [], [], []
    for band in trange(n_bands):
        start = band * band_width
        stop = n_chan if band == n_bands - 1 else start + band_width
        f, q, u, _, _ = makesurf(
            start, stop, field, datadir, save_plots=False, data=data
        )
        f_big.append(f)
        q_big.append(q)
        u_big.append(u)

    # Keep the per-band maps in plain lists: their shapes can differ between
    # bands, which np.array() cannot stack.
    fig, ax = plt.subplots(2, n_bands, figsize=(3 * n_bands, 6), squeeze=False)
    lim = 0.1
    for i, (f, q, u) in enumerate(zip(f_big, q_big, u_big)):
        ax[0, i].imshow(q, origin="lower", vmin=-lim, vmax=lim, cmap=plt.cm.coolwarm)
        ax[0, i].axis("off")
        ax[0, i].invert_xaxis()
        ax[0, i].set_title(f"{f / 1e6:0.1f}MHz")
        ax[1, i].imshow(u, origin="lower", vmin=-lim, vmax=lim, cmap=plt.cm.coolwarm)
        ax[1, i].axis("off")
        ax[1, i].invert_xaxis()
    # Row labels
    ax[0, 0].text(-20, 50, "Q")
    ax[1, 0].text(-20, 50, "U")
    fig.savefig(f"{field}_leakages.png")
[docs]
def cli():
    """Command-line entry point: parse arguments, prompt for a password, run ``main``."""
    import argparse
    import getpass

    # Help string to be shown using the -h option
    description = "Make leakage plots for a field."

    # Parse the command line options
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "field",
        metavar="field",
        type=str,
        help="RACS field to make leakage plots for - e.g. 2132-50A.",
    )
    parser.add_argument(
        "datadir",
        metavar="datadir",
        type=str,
        help="Directory containing cutouts (in subdir datadir/cutouts).",
    )
    parser.add_argument(
        "-u",
        "--username",
        type=str,
        default="admin",
        help="Username for the database.",
    )
    args = parser.parse_args()

    # Prompt interactively so the password never appears on the command line
    password = getpass.getpass()
    main(
        field=args.field,
        datadir=args.datadir,
        username=args.username,
        password=password,
    )
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    cli()