import os
from copy import deepcopy
import cmasher as cmr
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import scipy.constants as sc
from astropy.convolution import Gaussian2DKernel, convolve_fft
from astropy.io import fits
from astropy.wcs import WCS
from matplotlib.patches import Ellipse
from scipy import ndimage
from warnings import catch_warnings, simplefilter
# Gaussian FWHM -> standard deviation: FWHM = sigma * sqrt(8 ln 2),
# so sigma = FWHM * FWHM_to_sigma (roughly FWHM / 2.355).
FWHM_to_sigma = 1.0 / np.sqrt(8.0 * np.log(2.0))
# One arcsecond expressed in radians (180 * 3600 arcsec per pi radians).
arcsec = np.pi / (180.0 * 3600.0)
# Colormap used by the plotting routines when no cmap is supplied.
default_cmap = cmr.arctic
[docs]
class Cube:
"""
A class to handle astronomy data cubes. This class provides functionality to read, manipulate,
and analyze 3D FITS data cubes commonly used in (radio) astronomy.
The cube data is expected to have 3 dimensions:
- Two spatial dimensions (RA, Dec)
- One spectral dimension (frequency or velocity)
The class provides methods for:
- Reading FITS files
- Basic cube manipulation (cropping, smoothing, etc.)
- Moment map generation
- Spectral analysis
- Visualization
Attributes:
filename (str): Path to the input FITS file
header (astropy.io.fits.Header): FITS header information
image (numpy masked array): Cube data, shape (nv, ny, nx), or (ny, nx) for single-channel images
wcs (astropy.wcs.WCS or None): Celestial (2D) part of the header WCS, or None if it could not be built
object (str): Name of the astronomical object
unit (str): Units of the data values
"""
[docs]
def __init__(self, filename, only_header=False, correct_factor=None, unit=None, pixelscale=None, restfreq=None, zoom=None, **kwargs):
    """
    Initialize a Cube object by reading a FITS file.

    Args:
        filename (str): Path to the FITS file; ``~`` is expanded and the path normalised.
        only_header (bool): If True, only read the header and not the data.
        correct_factor (array-like): Correction factor(s) multiplied into the data.
        unit (str): Force the data unit instead of the header BUNIT.
        pixelscale (float): Pixel scale in arcsec, used when the header WCS cannot be read.
        restfreq (float): Rest frequency in Hz, used when the header does not provide one.
        zoom (float): Spatial resampling factor applied to the data.
        **kwargs: Extra options forwarded to ``_read`` (e.g. ``instrument``).
    """
    read_options = {
        "only_header": only_header,
        "correct_factor": correct_factor,
        "unit": unit,
        "pixelscale": pixelscale,
        "restfreq": restfreq,
        "zoom": zoom,
    }
    expanded_path = os.path.expanduser(filename)
    self.filename = os.path.normpath(expanded_path)
    self._read(**kwargs, **read_options)
[docs]
def _read(self, only_header=False, correct_factor=None, unit=None, pixelscale=None, instrument=None, restfreq=None, zoom=None):
    """
    Read the FITS file and initialize the Cube attributes.

    Args:
        only_header (bool): If True, only read the header and not the data.
        correct_factor (array-like or float): Correction factor(s) multiplied into the
            data; an array is applied per channel along the first axis.
        unit (str): Force the unit of the data (overrides BUNIT).
        pixelscale (float): The pixel scale in arcsec, used only when the header
            WCS keywords cannot be read.
        instrument (str): SPHERE instrument/filter name (e.g. "IFS", "IRDIS_H2"); when
            the header WCS cannot be read, its tabulated pixel scale overrides ``pixelscale``.
        restfreq (float): Rest frequency in Hz, used when the header has neither
            RESTFRQ nor RESTFREQ.
        zoom (float): Spatial resampling factor applied with ``scipy.ndimage.zoom``.

    Raises:
        OSError: If the FITS file cannot be opened or read.
        ValueError: If the header WCS cannot be read and neither ``pixelscale`` nor a
            known ``instrument`` is given.
    """
    # SPHERE pixel scales in arcsec (Maire et al 2016)
    sphere_pixelscales = {
        "IFS": 7.46e-3,
        "IRDIS_Y2": 12.283e-3,
        "IRDIS_Y3": 12.283e-3,
        "IRDIS_J2": 12.266e-3,
        "IRDIS_J3": 12.261e-3,
        "IRDIS_H2": 12.255e-3,
        "IRDIS_H3": 12.250e-3,
        "IRDIS_K1": 12.267e-3,
        "IRDIS_K2": 12.263e-3,
        "IRDIS_BB_J": 12.263e-3,
        "IRDIS_BB_H": 12.251e-3,
        "IRDIS_BB_Ks": 12.265e-3,
    }
    try:
        # The context manager closes the file on every path, including the
        # ValueError raised below for a missing pixel scale.
        with fits.open(self.filename) as hdu:
            header = hdu[0].header
            self.header = header

            # Read a few keywords in header
            try:
                self.object = header['OBJECT']
            except KeyError:
                print("Warning: could not find header keyword OBJECT")
                self.object = ""
            try:
                self.unit = header['BUNIT']
            except KeyError:
                print("Warning: could not find header keyword BUNIT")
                self.unit = ""
            if unit is not None:
                print("Warning: forcing unit")
                self.unit = unit
            if self.unit == "beam-1 Jy":  # discminer format
                self.unit = "Jy/beam"
            if self.unit == "JY/PIXEL":  # radmc format
                self.unit = "Jy pixel-1"

            # Suppress warnings when creating WCS
            with catch_warnings():
                simplefilter("ignore")  # Ignore all warnings
                try:
                    self.wcs = WCS(header).celestial
                except Exception as e:
                    print(f"Warning: could not create WCS - {e}")
                    self.wcs = None

            # pixel info
            self.nx = header['NAXIS1']
            self.ny = header['NAXIS2']
            self.wcs_rotation = 0.0  # rotation angle in degrees to align north with y-axis
            try:
                try:
                    self.pixelscale = header['CDELT2'] * 3600  # arcsec
                except KeyError:
                    self.pixelscale = abs(header['CD2_2']) * 3600
                has_cd_matrix = {'CD1_1', 'CD1_2', 'CD2_1', 'CD2_2'}.issubset(header)
                if has_cd_matrix:
                    CD1_2 = header['CD1_2']
                    CD2_1 = header['CD2_1']
                    CD2_2 = header['CD2_2']
                    # Pixel scale along y from the CD matrix (degrees -> arcsec)
                    self.pixelscale = np.sqrt(CD2_1 ** 2 + CD2_2 ** 2) * 3600.0
                if 'ORIENTAT' in header:
                    # ORIENTAT is position angle of image y axis (deg. e of n);
                    # rotating by -ORIENTAT aligns north with the y-axis
                    self.wcs_rotation = -header['ORIENTAT']
                elif has_cd_matrix:
                    # PA = arctan2(CD1_2, CD2_2) is the angle of the y-axis east of north;
                    # rotating by -PA aligns north with the y-axis
                    self.wcs_rotation = -np.rad2deg(np.arctan2(CD1_2, CD2_2))
                self.cx = header['CRPIX1']
                self.cy = header['CRPIX2']
                self.x_ref = header['CRVAL1']  # coordinate
                self.y_ref = header['CRVAL2']
            except Exception:
                print("Warning: missing WCS")
                self.cx = self.nx // 2 + 1
                self.cy = self.ny // 2 + 1
                self.x_ref = 0
                self.y_ref = 0
                # A known SPHERE instrument takes precedence over the pixelscale argument
                pixelscale = sphere_pixelscales.get(instrument, pixelscale)
                if pixelscale is None:
                    raise ValueError("Please provide pixelscale (or instrument)")
                self.pixelscale = pixelscale
                print(f"Pixel scale set to {self.pixelscale} arcsec")
            self.FOV = np.maximum(self.nx, self.ny) * self.pixelscale
            # image axes : with 0, 0 assumed as the center of the image
            # (Need to add self.x_ref or y_ref for full coordinates)
            self.xaxis = -(np.arange(1, self.nx + 1) - self.cx) * self.pixelscale
            self.yaxis = (np.arange(1, self.ny + 1) - self.cy) * self.pixelscale

            # velocity axis
            self.nv = header.get('NAXIS3', 1)
            try:
                self.restfreq = header['RESTFRQ']
                self.wl = sc.c / self.restfreq
            except Exception:
                try:
                    self.restfreq = header['RESTFREQ']  # gildas format
                    self.wl = sc.c / self.restfreq
                except Exception:
                    if restfreq is None:
                        print("Warning: missing rest frequency")
                    else:
                        self.restfreq = restfreq
                        self.wl = sc.c / self.restfreq
            try:
                self.velocity_type = header['CTYPE3']
                self.CRPIX3 = header['CRPIX3']
                self.CRVAL3 = header['CRVAL3']
                self.CDELT3 = header['CDELT3']
                channel_offsets = np.arange(1, self.nv + 1) - self.CRPIX3
                if self.velocity_type in ("VELO-LSR", "VRAD"):  # gildas and casa
                    self.CUNIT3 = header.get('CUNIT3')
                    cunit = self.CUNIT3.upper() if self.CUNIT3 else ""
                    if cunit == "M/S":
                        factor = 1e-3
                    elif cunit == "KM/S":
                        factor = 1
                    elif abs(self.CDELT3) < 5:  # abs(): channel widths can be negative
                        print("Assuming velocity axis is in km/s")
                        factor = 1
                    else:
                        factor = 1e-3
                        print("Assuming velocity axis is in m/s")
                    self.velocity = (self.CRVAL3 + self.CDELT3 * channel_offsets) * factor  # km/s
                    # The frequency axis needs a rest frequency; the velocity axis does not
                    if getattr(self, 'restfreq', None) is not None:
                        self.nu = self.restfreq * (1 - self.velocity * 1000 / sc.c)
                elif self.velocity_type in ("FREQ", "FREQ-LSR"):  # Hz
                    self.nu = self.CRVAL3 + self.CDELT3 * channel_offsets
                    self.velocity = -(self.nu - self.restfreq) / self.restfreq * sc.c / 1000.0  # km/s
                else:
                    raise ValueError("Velocity type is not recognised:", self.velocity_type)
                self.is_V = True
            except Exception:
                print("Warning: could not extract velocity")
                self.is_V = False

            # beam
            try:
                self.bmaj = header['BMAJ'] * 3600  # arcsec
                self.bmin = header['BMIN'] * 3600
                self.bpa = header['BPA']
            except Exception:
                try:
                    # Beam table in the second HDU: only its first record is used.
                    # NOTE(review): assumes columns (BMAJ, BMIN, BPA) in arcsec/deg — confirm.
                    first_beam = hdu[1].data[0]
                    self.bmaj = first_beam[0]
                    self.bmin = first_beam[1]
                    self.bpa = first_beam[2]
                except Exception:
                    print("Warning: missing beam")
                    self.bmaj = 0
                    self.bmin = 0
                    self.bpa = 0

            # reading data
            if not only_header:
                image = np.ma.masked_array(hdu[0].data)
                if image.ndim == 4:
                    image = image[0, :, :, :]
                if image.ndim == 3 and self.nv == 1:
                    image = image[0, :, :]
                if correct_factor is not None:
                    correction = np.asarray(correct_factor)
                    if correction.ndim > 0:  # one factor per channel
                        correction = correction[:, np.newaxis, np.newaxis]
                    image *= correction
                if zoom is not None:
                    print('Original size = ', image.shape)
                    image[np.isnan(image)] = 0.
                    # Only the two spatial (last) axes are resampled: works for 2D and 3D
                    zoom_factors = [1] * (image.ndim - 2) + [zoom, zoom]
                    # Re-wrap so self.image is a masked array like on the non-zoom path
                    image = np.ma.masked_array(ndimage.zoom(image, zoom=zoom_factors))
                    print('Resampled size = ', image.shape)
                    nx = image.shape[-1]
                    ny = image.shape[-2]
                    self.pixelscale = self.pixelscale * self.nx / nx
                    self.nx = nx
                    self.ny = ny
                    # NOTE(review): cx, cy, xaxis and yaxis are not rescaled here and still
                    # describe the un-zoomed grid — confirm whether callers rely on them.
                self.image = image
    except OSError:
        print('cannot open', self.filename)
        raise
[docs]
def cutout(self, filename, FOV=None, ix_min=None, ix_max=None, iv_min=None, iv_max=None, vmin=None, vmax=None, no_pola=False, pmin=None, pmax=None, channels=None, **kwargs):
    """
    Cut out a region (in space, frequency and polarization) of the cube and save it as a new FITS file.

    The spatial cut is forced to be square (the y range equals the x range). The
    header reference pixels are updated so the WCS of the cutout stays correct.

    Args:
        filename (str): The path to the output FITS file.
        FOV (float): The field of view in arcsec (centred cut; overrides ix_min/ix_max).
        ix_min (int): The minimum index of the x-axis (0-based, inclusive).
        ix_max (int): The maximum index of the x-axis (0-based, exclusive).
        iv_min (int): The minimum index of the velocity axis.
        iv_max (int): The maximum index of the velocity axis (exclusive).
        vmin (float): The minimum velocity in km/s (overrides iv_min).
        vmax (float): The maximum velocity in km/s (overrides iv_max).
        no_pola (bool): If True, only keep the first polarization.
        pmin (int): The minimum index of the polarization axis.
        pmax (int): The maximum index of the polarization axis.
        channels (list): Currently unused.
        **kwargs: Additional keyword arguments for the fits.writeto function.

    Raises:
        ValueError: If the image is neither 3D nor 4D.
    """
    image = deepcopy(self.image)
    header = deepcopy(self.header)
    ndim = image.ndim
    if ndim not in (3, 4):
        raise ValueError("incorrect dimension in fits file")

    def _crop_header_axis(axis, i_min, size):
        # Update the WCS keywords of FITS axis `axis` for a crop starting at 0-based index i_min.
        crpix_key = f"CRPIX{axis}"
        if header[crpix_key] - 1 >= i_min:  # we keep the same reference pixel, and adjust its index
            header[crpix_key] -= i_min
        else:  # we use the first pixel, and update its coordinates
            # New pixel 1 is old (1-based) pixel i_min + 1: CRVAL + (i_min + 1 - CRPIX) * CDELT
            header[f"CRVAL{axis}"] += (i_min - (header[crpix_key] - 1)) * header[f"CDELT{axis}"]
            header[crpix_key] = 1
        header[f"NAXIS{axis}"] = size

    # ---- Spatial trimming
    if FOV is not None:
        cutout_pix = int(FOV / self.pixelscale) - 1
        while cutout_pix * self.pixelscale < FOV:
            cutout_pix += 1
        excess_pix = int(0.5 * (self.nx - cutout_pix))
        ix_min = excess_pix
        ix_max = excess_pix + cutout_pix
    # We do not trim by default
    if ix_min is None:
        ix_min = 0
    if ix_max is None:
        ix_max = self.nx
    # forcing square image for now
    iy_min = ix_min
    iy_max = ix_max
    _crop_header_axis(1, ix_min, ix_max - ix_min)
    _crop_header_axis(2, iy_min, iy_max - iy_min)

    # ---- Spectral trimming
    if vmin is not None:
        iv_min = np.argmin(np.abs(self.velocity - vmin))
    if vmax is not None:
        iv_max = np.argmin(np.abs(self.velocity - vmax))
    # We do not trim by default
    if iv_min is None:
        iv_min = 0
    if iv_max is None:
        iv_max = self.nv
    # if the velocity axis is flipped
    if iv_min > iv_max:
        iv_min, iv_max = iv_max, iv_min
    _crop_header_axis(3, iv_min, iv_max - iv_min)

    # --- Polarisation trimming
    if ndim > 3 and no_pola:
        pmin = 0
        pmax = 1  # keep only the first polarisation (a 0:0 slice would be empty)
        header['NAXIS4'] = 1

    # trimming cube
    if ndim == 4:
        image = image[pmin:pmax, iv_min:iv_max, iy_min:iy_max, ix_min:ix_max]
    else:
        image = image[iv_min:iv_max, iy_min:iy_max, ix_min:ix_max]
    # getdata works for masked and plain arrays (ndarray.data would be a memoryview)
    fits.writeto(os.path.normpath(os.path.expanduser(filename)), np.ma.getdata(image), header, **kwargs)
    return
[docs]
def tapered_fits(self, filename, taper=None, **kwargs):
    """
    Create a tapered version of a fits file.

    Each channel (or the single 2D image) is convolved with the Gaussian kernel
    that brings the beam (bmaj, bmin, bpa) to a circular beam of FWHM ``taper``.
    The flux is rescaled by the beam-area ratio, and BMAJ/BMIN are updated in the
    written header. If ``taper`` is smaller than bmaj, data and header are written
    unchanged.

    Args:
        filename (str): The path to the output FITS file.
        taper (float): The taper size (FWHM) in arcsec.
        **kwargs: Additional keyword arguments for the fits.writeto function
            (``overwrite`` defaults to True).

    Raises:
        ValueError: If ``taper`` is not provided.
    """
    if taper is None:
        raise ValueError("taper is needed")
    image = deepcopy(self.image)
    header = deepcopy(self.header)
    if taper < self.bmaj:
        print("taper is smaller than bmaj=", self.bmaj)
        print("No taper applied")
    else:
        delta_bmaj = np.sqrt(taper ** 2 - self.bmaj ** 2)
        delta_bmin = np.sqrt(taper ** 2 - self.bmin ** 2)
        sigma_x = delta_bmin / self.pixelscale * FWHM_to_sigma  # in pixels
        sigma_y = delta_bmaj / self.pixelscale * FWHM_to_sigma  # in pixels
        print("beam = ", self.bmaj, self.bmin)
        print("tapper =", delta_bmaj, delta_bmin, self.bpa)
        print("beam = ", np.sqrt(self.bmaj ** 2 + delta_bmaj ** 2), np.sqrt(self.bmin ** 2 + delta_bmin ** 2))
        beam = Gaussian2DKernel(sigma_x, sigma_y, self.bpa * np.pi / 180)
        # Correcting for beam area (assumes per-beam units). Tested ok
        area_ratio = taper ** 2 / (self.bmin * self.bmaj)
        if image.ndim == 2:  # single-channel image
            image[:, :] = convolve_fft(self.image, beam) * area_ratio
        else:
            for iv in range(image.shape[0]):
                image[iv, :, :] = convolve_fft(self.image[iv, :, :], beam) * area_ratio
        # The beam keywords only change when the taper was actually applied
        header['BMAJ'] = taper / 3600
        header['BMIN'] = taper / 3600
    # Default to overwriting, but let the caller override without a duplicate-kwarg error
    kwargs.setdefault('overwrite', True)
    fits.writeto(os.path.normpath(os.path.expanduser(filename)), np.ma.getdata(image), header, **kwargs)
    return
[docs]
def sensitivity(self):
    """
    Return the noise level of the cube expressed as a brightness temperature (K).

    ``get_std`` is called first and is expected to (re)set ``self.std``; that value
    is then converted with ``_Jybeam_to_Tb`` using the Rayleigh-Jeans option.
    """
    self.get_std()
    noise_level = self.std
    return self._Jybeam_to_Tb(noise_level, RJ=True)
[docs]
@staticmethod
def writeto(filename, image, header, **kwargs):
    """
    Write an image array and a header to a FITS file.

    Static because it uses no instance state; this makes both
    ``Cube.writeto(...)`` and ``cube.writeto(...)`` work.

    Args:
        filename (str): Output path (``~`` is expanded, path normalised).
        image (array-like or numpy.ma.MaskedArray): Data to write; for a masked
            array the underlying data (mask ignored) is written.
        header (astropy.io.fits.Header): Header to write.
        **kwargs: Forwarded to ``astropy.io.fits.writeto`` (e.g. ``overwrite=True``).
    """
    # np.ma.getdata handles plain ndarrays too (ndarray.data is a memoryview)
    fits.writeto(os.path.normpath(os.path.expanduser(filename)), np.ma.getdata(image), header, **kwargs)
[docs]
def plot(
    self,
    iv=None,
    v=None,
    colorbar=True,
    plot_beam=True,
    color_scale=None,
    fmin=None,
    fmax=None,
    limit=None,
    limits=None,
    moment=None,
    moment_fname=None,
    vturb=False,
    Tb=False,
    cmap=None,
    v0=None,
    dv=None,
    ax=None,
    no_ylabel=False,
    no_xlabel=False,
    no_vlabel=False,
    title=None,
    alpha=1.0,
    interpolation="bicubic",
    resample=0,
    bmaj=None,
    bmin=None,
    bpa=None,
    taper=None,
    colorbar_label=True,
    colorbar_side="right",
    M0_threshold=None,
    M8_threshold=None,
    threshold=None,
    threshold_value=np.nan,
    vlabel_position="bottom",
    vlabel_color="white",
    vlabel_size=8,
    shift_dx=0,
    shift_dy=0,
    mol_weight=None,
    iv_support=None,
    v_minmax=None,
    axes_unit="arcsec",
    quantity_name=None,
    stellar_mask=None,
    levels=4,
    plot_type="imshow",
    linewidths=None,
    zorder=None,
    per_arcsec2=False,
    colors=None,
    x_beam=0.125,
    y_beam=0.125,
    mJy=False,
    width=None,
    highpass_filter=0,
    normalise=False,
    dynamic_range=None,
    hpf=False,
    RJ=True,
    **kwargs
):
    """
    Plotting routine for continuum image, moment maps and channel maps.

    Args:
        iv (int): The index of the velocity channel to plot.
        v (float): The velocity in km/s (selects the closest channel).
        colorbar (bool): Whether to plot the colorbar.
        plot_beam (bool): Whether to plot the beam.
        color_scale (str): 'log', 'lin', 'sqrt' or 'asinh' (default: log for continuum, lin otherwise).
        fmin (float): The minimum value of the color scale.
        fmax (float): The maximum value of the color scale.
        limit (float): Half-size of the displayed region, i.e. limits = [limit, -limit, -limit, limit].
        limits (list): Axis limits [x1, x2, y1, y2] (RA/Dec in degrees when the axes use WCS).
        moment (int): The moment to plot (see get_moment_map).
        moment_fname (str): FITS file holding a precomputed moment map to plot instead.
        vturb (bool): Whether to plot the turbulent velocity.
        Tb (bool): Whether to plot the brightness temperature.
        cmap (str): The colormap to use.
        v0 (float): The central velocity (km/s).
        dv (float): Velocity offset from v0 used to select the channel (v = v0 + dv).
        ax (matplotlib.axes.Axes): The axes to plot on.
        no_ylabel (bool): Whether to hide the y-axis label.
        no_xlabel (bool): Whether to hide the x-axis label.
        no_vlabel (bool): Whether to hide the velocity label.
        title (str): The title of the plot.
        alpha (float): The alpha value of the plot.
        interpolation (str): The interpolation method passed to imshow.
        resample (int): Zoom factor applied to the map before plotting (0: none).
        bmaj (float): Override for the beam major axis (arcsec).
        bmin (float): Override for the beam minor axis (arcsec).
        bpa (float): Override for the beam position angle (deg).
        taper (float): Taper (FWHM, arcsec) to convolve the map to.
        colorbar_label (bool): Whether to show the colorbar label.
        colorbar_side (str): The side of the colorbar to plot.
        M0_threshold (float): The threshold for the M0 moment map.
        M8_threshold (float): The threshold for the M8 moment map.
        threshold (float): Intensity threshold; map values below it are replaced by threshold_value.
        threshold_value (float): The value used below the threshold.
        vlabel_position (str): 'top', 'top-right', 'top-left' or bottom (default).
        vlabel_color (str): The color of the velocity label.
        vlabel_size (int): The font size of the velocity label.
        shift_dx (float): Shift of the image extent in x (axes units).
        shift_dy (float): Shift of the image extent in y (axes units).
        mol_weight (float): Molecular weight used for vturb.
        iv_support (list): Indices of the velocity channels used for moment maps.
        v_minmax (list): Velocity range (km/s) used for moment maps.
        axes_unit (str): 'arcsec', 'au' or 'pixel(s)'.
        quantity_name (str): The name of the quantity for the colorbar label.
        stellar_mask (float): Radius of a grey disk drawn at the image centre (axes units).
        levels (int): The number of levels for the contour plots.
        plot_type (str): 'imshow', 'contourf' or 'contour'.
        linewidths (list): The linewidths for the contour plot.
        zorder (int): The zorder for the plot.
        per_arcsec2 (bool): Whether to divide by the pixel area (arcsec^2).
        colors (list): The colors for the contour plot (disables cmap).
        x_beam (float): x position of the beam ellipse, in axes fraction.
        y_beam (float): y position of the beam ellipse, in axes fraction.
        mJy (bool): Whether to plot the quantity in mJy.
        width (float): Velocity width (km/s) over which channels are averaged.
        highpass_filter (float): If > 0, subtract the map smoothed by a Gaussian of
            width highpass_filter times the beam sigma.
        normalise (bool): Whether to normalise the map by its maximum.
        dynamic_range (float): If set, fmin = fmax * dynamic_range.
        hpf (bool): Whether to plot a high-pass filtered moment map.
        RJ (bool): Rayleigh-Jeans option passed to the Tb conversion.
        **kwargs: Additional keyword arguments for get_high_pass_filter_map.

    Returns:
        The matplotlib artist created (AxesImage or contour set).

    Raises:
        ValueError: If no channel is selected for a channel map, or for an unknown
            color scale, axes unit or plot type.
    """
    if ax is None:
        ax = plt.gca()
    # Automatically check if the plot has been opened with projection='wcs'
    use_wcs = hasattr(ax, 'coords') and ax.coords is not None
    unit = self.unit

    # --- Select the 2D map to display
    if self.nv == 1:  # continuum image
        is_cont = True
        if self.image.ndim > 2:
            im = self.image[0, :, :]
        else:
            im = self.image
        _color_scale = 'log'
    elif moment is not None:
        is_cont = False
        if moment_fname is not None:
            # Copy the data so the file can be closed right away
            with fits.open(moment_fname) as hdu:
                im = np.array(hdu[0].data)
        elif hpf:
            im = self.get_high_pass_filter_map(moment=moment, v0=v0, M0_threshold=M0_threshold, M8_threshold=M8_threshold, threshold=threshold, iv_support=iv_support, v_minmax=v_minmax, **kwargs)
        else:
            im = self.get_moment_map(moment=moment, v0=v0, M0_threshold=M0_threshold, M8_threshold=M8_threshold, threshold=threshold, iv_support=iv_support, v_minmax=v_minmax)
        _color_scale = 'lin'
    elif vturb:
        is_cont = False
        im = self.get_vturb(M0_threshold=M0_threshold, threshold=threshold, mol_weight=mol_weight)
        _color_scale = 'lin'
    else:
        if self.is_V:
            is_cont = False
            # -- Selecting channel corresponding to a given velocity
            if dv is not None:
                v = v0 + dv
            if v is not None:
                iv = np.abs(self.velocity - v).argmin()
                print("Selecting channel #", iv)
            if iv is None:
                print("Channel or velocity needed")
                raise ValueError("Channel or velocity needed")
        else:
            is_cont = True
        # Averaging multiple channels
        v_offset = 0.0
        try:
            dv = np.diff(self.velocity)[0]  # signed channel width (km/s)
        except Exception:
            dv = None
        if width is not None:
            # abs(): a decreasing velocity axis must not collapse the window to one channel
            n_channels = np.maximum(int(np.round(width / abs(dv))), 1)
            delta_iv = n_channels // 2
            iv_min = iv - delta_iv
            if n_channels % 2:  # odd number of channels, same central channel
                iv_max = iv + delta_iv + 1
            else:  # We will have a small shift compared to initial channel
                iv_max = iv + delta_iv
                v_offset = -0.5 * dv
            # Keep the window inside the cube (a negative start would wrap around)
            iv_min = max(iv_min, 0)
            iv_max = min(iv_max, self.nv)
            print("Averaging between channels", iv_min, "and", iv_max - 1, "(included). Width is", (iv_max - iv_min) * dv, "km/s")
            im = np.average(self.image[iv_min:iv_max, :, :], axis=0)
        else:  # 1 single channel
            im = self.image[iv, :, :]
        _color_scale = 'lin'

    # --- Convolution by taper
    if taper is not None:
        if taper < self.bmaj:
            print("taper is smaller than bmaj=", self.bmaj)
            print("No taper applied")
        else:
            delta_bmaj = np.sqrt(taper ** 2 - self.bmaj ** 2)
            bmaj = taper  # the drawn beam becomes the tapered one
            delta_bmin = np.sqrt(taper ** 2 - self.bmin ** 2)
            bmin = taper
            sigma_x = delta_bmin / self.pixelscale * FWHM_to_sigma  # in pixels
            sigma_y = delta_bmaj / self.pixelscale * FWHM_to_sigma  # in pixels
            print("beam = ", self.bmaj, self.bmin)
            print("tapper =", delta_bmaj, delta_bmin, self.bpa)
            print("beam = ", np.sqrt(self.bmaj ** 2 + delta_bmaj ** 2), np.sqrt(self.bmin ** 2 + delta_bmin ** 2))
            beam = Gaussian2DKernel(sigma_x, sigma_y, self.bpa * np.pi / 180)
            im = convolve_fft(im, beam)
            # Correcting for beam area
            im *= taper ** 2 / (self.bmin * self.bmaj)

    # --- Unit conversions
    if Tb:
        im = self._Jybeam_to_Tb(im, RJ=RJ)
        unit = "K"
        _color_scale = 'lin'
    if mJy:
        im = im * 1e3
        unit = unit.replace("Jy", "mJy")
    if per_arcsec2:
        im = im / self.pixelscale ** 2
        unit = unit.replace("pixel-1", "arcsec-2")
    if normalise:
        im = im / np.nanmax(im)

    # --- resampling (works for masked and plain arrays)
    if resample > 0:
        mask = ndimage.zoom(np.ma.getmaskarray(im) * 1, resample, order=3)
        im = ndimage.zoom(np.ma.getdata(im), resample, order=3)
        im = np.ma.masked_where(mask > 0.0, im)

    # -- default color scale
    if color_scale is None:
        color_scale = _color_scale

    # --- Cuts
    if fmax is None:
        fmax = np.nanmax(im)
    if fmin is None:
        if color_scale == 'log':
            fmin = fmax * 1e-2
        else:
            fmin = 0.0
    if dynamic_range is not None:
        fmin = fmax * dynamic_range

    # -- set up the color scale
    print("fminmax=", fmin, fmax)
    if color_scale == 'log':
        norm = mcolors.LogNorm(vmin=fmin, vmax=fmax, clip=True)
    elif color_scale == 'lin':
        norm = mcolors.Normalize(vmin=fmin, vmax=fmax, clip=True)
    elif color_scale == 'sqrt':
        norm = mcolors.PowerNorm(0.5, vmin=fmin, vmax=fmax, clip=True)
    elif color_scale == 'asinh':
        # NOTE(review): vmin ignores fmin here (0.5 * fmax) — confirm this is intended.
        norm = mcolors.AsinhNorm(vmin=0.5 * fmax, vmax=fmax, clip=True)
    else:
        raise ValueError("Unknown color scale: " + color_scale)

    if cmap is None:
        if moment in [1, 9]:
            cmap = "RdBu_r"
        else:
            cmap = default_cmap
    if colors is not None:
        cmap = None

    # option to use WCS coordinates
    if use_wcs:
        xlabel = 'RA (J2000)'
        ylabel = 'Dec (J2000)'
        ax.coords[0].set_axislabel(xlabel)
        ax.coords[1].set_axislabel(ylabel)
        ax.set_frame_on(False)
        extent = None
    else:
        axes_unit_lower = axes_unit.lower()
        if axes_unit_lower == 'arcsec':
            pix_scale = self.pixelscale
            xlabel = r'$\Delta$ RA (")'
            ylabel = r'$\Delta$ Dec (")'
            xaxis_factor = -1
        elif axes_unit_lower == 'au':
            # NOTE(review): self.P is not set by the reader shown here; this branch
            # presumably relies on an attribute attached elsewhere — confirm.
            pix_scale = self.pixelscale * self.P.map.distance
            xlabel = 'Distance from star (au)'
            ylabel = 'Distance from star (au)'
            xaxis_factor = 1
        elif axes_unit_lower in ('pixels', 'pixel'):
            pix_scale = 1
            xlabel = r'x (pix)'
            ylabel = r'y (pix)'
            xaxis_factor = 1
        else:
            raise ValueError("Unknown unit for axes_units: " + axes_unit)
        halfsize = np.asarray(im.shape) / 2 * pix_scale
        extent = [-halfsize[1] * xaxis_factor - shift_dx, halfsize[1] * xaxis_factor - shift_dx, -halfsize[0] - shift_dy, halfsize[0] - shift_dy]
        if axes_unit_lower in ('pixels', 'pixel'):
            extent = None
    self.extent = extent

    if threshold is not None:
        im = np.where(im > threshold, im, threshold_value)
    if highpass_filter > 0:
        sigma = self.bmaj / self.pixelscale * FWHM_to_sigma * highpass_filter
        kernel = Gaussian2DKernel(sigma, sigma, 0)
        im -= convolve_fft(im, kernel)

    # --- Apply WCS rotation if needed (to align north with y-axis)
    if hasattr(self, 'wcs_rotation') and abs(self.wcs_rotation) > 1e-6:
        # scipy.ndimage.rotate rotates counter-clockwise for positive angles
        im = ndimage.rotate(im, self.wcs_rotation, reshape=False, order=1, mode='constant', cval=0.0)

    if plot_type == "imshow":
        image = ax.imshow(
            im,
            norm=norm,
            extent=extent,
            origin='lower',
            cmap=cmap,
            alpha=alpha,
            interpolation=interpolation,
            zorder=zorder
        )
    elif plot_type == "contourf":
        image = ax.contourf(
            im,
            extent=extent,
            origin='lower',
            levels=levels,
            cmap=cmap,
            linewidths=linewidths,
            alpha=alpha,
            zorder=zorder
        )
    elif plot_type == "contour":
        image = ax.contour(
            im,
            extent=extent,
            origin='lower',
            levels=levels,
            cmap=cmap,
            linewidths=linewidths,
            alpha=alpha,
            zorder=zorder,
            colors=colors
        )
    else:
        raise ValueError("Unknown plot_type: " + str(plot_type))

    if limit is not None:
        limits = [limit, -limit, -limit, limit]
    if limits is not None:
        if use_wcs:
            wcs = WCS(self.header).celestial
            # Convert world coordinates (RA/Dec) to pixel coordinates
            pixels = wcs.wcs_world2pix(
                [[limits[0], limits[2]], [limits[1], limits[3]]],
                0
            )
            x1, y1 = pixels[0]  # First point (ra_max, dec_min)
            x2, y2 = pixels[1]  # Second point (ra_min, dec_max)
            ax.set_xlim(x1, x2)
            ax.set_ylim(y1, y2)
        else:
            ax.set_xlim(limits[0], limits[1])
            ax.set_ylim(limits[2], limits[3])

    if not no_xlabel:
        ax.set_xlabel(xlabel)
    if not no_ylabel:
        ax.set_ylabel(ylabel)
    if title is not None:
        ax.set_title(title)

    # -- Color bar
    if colorbar:
        cb = add_colorbar(image, side=colorbar_side)
        formatted_unit = unit.replace("-1", "$^{-1}$").replace("-2", "$^{-2}$")
        if colorbar_label:
            if moment == 0:
                cb.set_label("Flux (" + formatted_unit + r"$\,$km$\,$s$^{-1}$)")
            elif moment in [1, 9]:
                cb.set_label(r"Velocity (km$\,$s$^{-1})$")
            elif moment == 2:
                cb.set_label(r"Velocity dispersion (km$\,$s$^{-1}$)")
            else:
                if Tb:
                    cb.set_label(r"T$_\mathrm{B}$ (" + formatted_unit + ")")
                else:
                    if quantity_name is None:
                        quantity_name = "Flux"
                    if len(formatted_unit) > 0:
                        formatted_unit = " (" + formatted_unit + ")"
                    cb.set_label(quantity_name + formatted_unit)
        plt.sca(ax)  # we reset the main axis

    # -- Adding velocity
    if vlabel_position == "top":
        y_vlabel = 0.85
        x_vlabel = 0.5
    elif vlabel_position == "top-right":
        y_vlabel = 0.85
        x_vlabel = 0.7
    elif vlabel_position == "top-left":
        y_vlabel = 0.85
        x_vlabel = 0.25
    else:
        y_vlabel = 0.1
        x_vlabel = 0.5
    if not no_vlabel and moment is None and not is_cont and not vturb:
        if v0 is None:
            vlabel = f"v={self.velocity[iv] + v_offset:<4.2f}" r"$\,$km/s"
        else:
            # Same window-centre offset as the absolute label
            vlabel = r"$\Delta$v=" f"{self.velocity[iv] + v_offset - v0:<4.2f}" r"$\,$km/s"
        ax.text(
            x_vlabel,
            y_vlabel,
            vlabel,
            horizontalalignment='center',
            color=vlabel_color,
            transform=ax.transAxes,
            fontsize=vlabel_size
        )

    # --- Adding beam
    if plot_beam:
        # In case the beam is wrong in the header, we can pass the correct one
        if bmaj is None:
            bmaj = self.bmaj
        if bmin is None:
            bmin = self.bmin
        if bpa is None:
            bpa = self.bpa
        beam = Ellipse(
            ax.transLimits.inverted().transform((x_beam, y_beam)),
            width=bmin,
            height=bmaj,
            angle=-bpa,
            fill=True,
            color="grey",
        )
        ax.add_patch(beam)

    # Adding mask to hide star
    if stellar_mask is not None:
        dx = 0.5
        dy = 0.5
        mask = Ellipse(
            ax.transLimits.inverted().transform((dx, dy)),
            width=2 * stellar_mask,
            height=2 * stellar_mask,
            fill=True,
            color='grey',
        )
        ax.add_patch(mask)

    #-- Saving the last plotted quantity
    self.last_image = im
    return image
[docs]
def plot_channels(self, n=20, num=21, ncols=5, iv_min=None, iv_max=None, vmin=None, vmax=None, **kwargs):
    """
    Plot n channels of the cube, evenly spaced between iv_min and iv_max, on a grid.

    Grid cells beyond the n requested channels (when n is not a multiple of
    ncols) are left empty.

    Args:
        n (int): The number of channels to plot.
        num (int): The figure number (the figure is cleared if it exists).
        ncols (int): The number of columns in the plot.
        iv_min (int): The minimum index of the velocity channel to plot.
        iv_max (int): The maximum index of the velocity channel to plot.
        vmin (float): The minimum velocity to plot (overrides iv_min).
        vmax (float): The maximum velocity to plot (overrides iv_max).
        **kwargs: Additional keyword arguments for Cube.plot.
    """
    if vmin is not None:
        iv_min = np.abs(self.velocity - vmin).argmin()
    if vmax is not None:
        iv_max = np.abs(self.velocity - vmax).argmin()
    if iv_min is None:
        iv_min = 0
    if iv_max is None:
        iv_max = self.nv - 1
    # Channel step between consecutive panels (may be fractional or negative)
    dv = (iv_max - iv_min) / n
    nrows = int(np.ceil(n / ncols))
    if plt.fignum_exists(num):
        plt.figure(num)
        plt.clf()
    # squeeze=False keeps axs a 2D array even for a single row/column
    fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(11, 2 * nrows + 1), num=num, clear=False, squeeze=False)
    for i, ax in enumerate(axs.ravel()):
        if i >= n:
            # Only n channels were requested: leave the remaining cells empty
            ax.axis("off")
            continue
        no_ylabel = i % ncols != 0  # y label only on the first column
        no_xlabel = i < ncols * (nrows - 1)  # x label only on the last row
        self.plot(iv=int(iv_min + i * dv), ax=ax, no_xlabel=no_xlabel, no_ylabel=no_ylabel, **kwargs)
    plt.show()
    return
[docs]
def get_line_profile(self, threshold=None, **kwargs):
    """
    Return the spatially integrated spectrum of the cube.

    Each channel is summed over both spatial axes (NaNs ignored) and divided by
    ``self._beam_area_pix()``.

    Args:
        threshold (float): If given, pixels not above this value are set to 0
            before summing.
        **kwargs: Accepted for call compatibility; unused.

    Returns:
        numpy.ndarray: One value per channel.
    """
    data = self.image[:, :, :]
    if threshold is not None:
        data = np.where(data > threshold, data, 0)
    channel_totals = np.nansum(data, axis=(1, 2))
    return channel_totals / self._beam_area_pix()
[docs]
def plot_line(self, x_axis="velocity", threshold=None, ax=None, **kwargs):
    """
    Plot the spatially integrated line profile and print its integrated flux.

    The integrated flux uses the trapezoidal rule with a constant channel width
    taken from the spacing of channels 1 and 2.

    Args:
        x_axis (str): "channel", "freq", or anything else for velocity.
        threshold (float): Threshold forwarded to get_line_profile.
        ax (matplotlib.axes.Axes): The axes to plot on (current axes by default).
        **kwargs: Additional keyword arguments for ax.plot.
    """
    ax = plt.gca() if ax is None else ax
    if x_axis == "channel":
        abscissa = np.arange(self.nv)
    elif x_axis == "freq":
        abscissa = self.nu
    else:
        abscissa = self.velocity
    profile = self.get_line_profile(threshold=threshold)
    ax.plot(abscissa, profile, **kwargs)
    channel_width = np.abs(self.velocity[2] - self.velocity[1])
    trapezoid_total = (np.sum(profile) - 0.5 * (profile[0] + profile[-1])) * channel_width
    print("Integrated line flux =", trapezoid_total)
    return
# -- computing various "moments"
[docs]
def get_moment_map(self, moment=0, v0=0, M0_threshold=None, M8_threshold=None, threshold=None, iv_support=None, v_minmax=None):
    """
    Calculate the moment map of the cube.

    We use the same convention as CASA : moment 8 is peak flux, moment 9 is peak velocity.
    This returns the moment maps in physical units, ie:
    - M0 is the integrated line flux (Jy/beam . km/s)
    - M1 is the average velocity (km/s)
    - M2 is the velocity dispersion (km/s): sqrt(sum I (v - M1)^2 dv / M0)
    - M8 is the peak intensity
    - M9 is the velocity of the peak, relative to v0
    The absolute channel width is used for integration, so M0 is positive for
    positive emission whichever way the velocity axis runs.

    Args:
        moment (int): The moment to calculate (0, 1, 2, 8 or 9).
        v0 (float): The central velocity (km/s).
        M0_threshold (float): Mask pixels where M0 is below this value.
        M8_threshold (float): Mask pixels where M8 is below this value.
        threshold (float): Intensity threshold; values not above it are ignored (NaN).
        iv_support (list): The indices of the velocity channels to use.
        v_minmax (list): Velocity range (km/s) of the channels to use (overrides iv_support).

    Returns:
        numpy.ndarray or numpy.ma.MaskedArray: The 2D moment map.

    Raises:
        ValueError: If ``moment`` is not one of 0, 1, 2, 8, 9.
    """
    if moment not in (0, 1, 2, 8, 9):
        raise ValueError(f"Unsupported moment: {moment} (expected 0, 1, 2, 8 or 9)")
    if v0 is None:
        v0 = 0
    cube = np.copy(self.image)
    dv = abs(self.velocity[1] - self.velocity[0])
    v = self.velocity - v0
    if threshold is not None:
        cube = np.where(cube > threshold, cube, np.nan)
    if v_minmax is not None:
        vmin = np.min(v_minmax)
        vmax = np.max(v_minmax)
        iv_support = np.flatnonzero(np.logical_and(self.velocity > vmin, self.velocity < vmax))
        print("Selecting channels:", iv_support)
    if iv_support is not None:
        v = v[iv_support]
        cube = cube[iv_support, :, :]
    M0 = np.nansum(cube, axis=0) * dv
    # nanmax: channels removed by `threshold` are NaN and must not poison the peak
    with catch_warnings():
        simplefilter("ignore", RuntimeWarning)  # all-NaN pixels legitimately give NaN
        M8 = np.nanmax(cube, axis=0)
    if moment in (1, 2):
        M1 = np.nansum(cube * v[:, np.newaxis, np.newaxis], axis=0) * dv / M0
    if moment == 0:
        M = M0
    elif moment == 1:
        M = M1 + v0
    elif moment == 2:
        # avoid division by 0 or neg values in sqrt
        thr = np.nanpercentile(M0[M0 > 0], 0.01)
        M0[M0 < thr] = np.nan
        M = np.sqrt(np.nansum(cube * (v[:, np.newaxis, np.newaxis] - M1[np.newaxis, :, :]) ** 2, axis=0) * dv / M0)
    elif moment == 8:
        M = M8
    else:  # moment 9: velocity of the peak, ignoring NaN channels
        peak_index = np.argmax(np.where(np.isnan(cube), -np.inf, cube), axis=0)
        M = v[peak_index]
    if M0_threshold is not None:
        M = np.ma.masked_where(M0 < M0_threshold, M)
    if M8_threshold is not None:
        M = np.ma.masked_where(M8 < M8_threshold, M)
    return M
[docs]
def get_high_pass_filter_map(self, moment=None, w0=None, gamma=0, **kwargs):
    """
    Perform a Gaussian high pass filter on a moment map.

    The moment map minus its Gaussian-smoothed version (standard deviation w0,
    in arcsec) is returned. With an optional radial stretch exponent gamma != 0,
    the map is first resampled onto a radially stretched grid: grid point (x, y)
    takes the value at (x, y) * r**gamma, with r measured from the origin of
    self.xaxis / self.yaxis. It is filtered there and resampled back onto the
    original grid (bilinear interpolation, edge values clamped).

    Args:
        moment (int): The moment to use for the high pass filter.
        w0 (float): The width of the Gaussian filter in arcsec.
        gamma (float): The radial stretch exponent (0: no stretch).
        **kwargs: Additional keyword arguments for the moment map.

    Returns:
        numpy.ndarray: The high-pass residual map, same shape as the moment map.

    Raises:
        ValueError: If w0 is missing, or gamma == -1 (degenerate stretch).
    """
    if w0 is None:
        raise ValueError("Please provide w0 for filter in sarcsec")
    # Original data
    M = self.get_moment_map(moment=moment, **kwargs)
    sigma = w0 / self.pixelscale  # filter width in pixels
    if gamma == 0:
        # No stretching
        return M - ndimage.gaussian_filter(M, sigma=sigma)
    if gamma == -1:
        raise ValueError("gamma = -1 collapses the stretched grid")

    data = np.ma.filled(np.ma.asarray(M, dtype=float), np.nan)
    X, Y = np.meshgrid(self.xaxis, self.yaxis)  # shape (ny, nx), like the map
    R = np.hypot(X, Y)
    off_centre = R > 0
    x_step = self.xaxis[1] - self.xaxis[0]
    y_step = self.yaxis[1] - self.yaxis[0]

    def _sample(image, x, y):
        # Bilinearly sample `image` at axis coordinates (x, y), clamping at the edges
        cols = (x - self.xaxis[0]) / x_step
        rows = (y - self.yaxis[0]) / y_step
        return ndimage.map_coordinates(image, [rows, cols], order=1, mode="nearest")

    # Forward stretch: r -> r**(1 + gamma) (factor set to 0 at the origin to avoid 0**negative)
    forward = np.zeros_like(R)
    forward[off_centre] = R[off_centre] ** gamma
    stretched_map = _sample(data, X * forward, Y * forward)
    # Perform convolution on warped data and subtract result to get residual
    highpass_residual = stretched_map - ndimage.gaussian_filter(stretched_map, sigma=sigma)
    # Inverse stretch: the grid point whose stretched position is (X, Y) sits at
    # (X, Y) * R**(-gamma / (1 + gamma))
    inverse = np.zeros_like(R)
    inverse[off_centre] = R[off_centre] ** (-gamma / (1.0 + gamma))
    return _sample(highpass_residual, X * inverse, Y * inverse)
[docs]
def get_fwhm(self, v0=0, M0_threshold=None):
    """
    Get the line FWHM map of the cube.

    The second-moment map is converted to a FWHM assuming a Gaussian line
    profile: FWHM = sqrt(8 ln 2) * M2.

    Args:
        v0 (float): The central velocity, forwarded to get_moment_map.
        M0_threshold (float): The threshold for the M0 moment map, forwarded
            to get_moment_map.

    Returns:
        The FWHM map, in the velocity units of the moment-2 map.
    """
    # Must call the bound method; a bare `get_moment_map(self, ...)` is not
    # a module-level name and raised NameError.
    M2 = self.get_moment_map(moment=2, v0=v0, M0_threshold=M0_threshold)
    return np.sqrt(8 * np.log(2)) * M2
[docs]
def get_vturb(self, v0=0, M0_threshold=None, threshold=None, mol_weight=None):
    """
    Estimate the turbulent line-width map of the cube.

    The moment-2 map is combined with a thermal term derived from the peak
    brightness temperature. The peak is the moment-8 map converted with
    _Jybeam_to_Tb.

    Args:
        v0 (float): The central velocity, forwarded to get_moment_map.
        M0_threshold (float): Threshold on the M0 moment map, forwarded to
            get_moment_map.
        threshold (float): Extra threshold forwarded to get_moment_map.
        mol_weight (float): Mass of the emitting molecule in units of the
            hydrogen-atom mass used below.

    Returns:
        sqrt(8 ln2 * M2**2 - 2 * k * Tb / (mol_weight * mH)).

    Raises:
        ValueError: If mol_weight is not provided.
    """
    if mol_weight is None:
        raise ValueError("mol_weight needs to be provided")

    moment_kwargs = dict(v0=v0, M0_threshold=M0_threshold, threshold=threshold)
    second_moment = self.get_moment_map(moment=2, **moment_kwargs)
    peak_map = self.get_moment_map(moment=8, **moment_kwargs)
    peak_tb = self._Jybeam_to_Tb(peak_map)

    # NOTE(review): 1.007825032231 is the H atomic mass in g/mol, so this is
    # the H mass in grams, while sc.k is in SI units (J/K). Confirm the
    # intended unit of the thermal term against the velocity unit of M2.
    hydrogen_mass = 1.007825032231 / sc.N_A
    thermal_term = sc.k * peak_tb / (mol_weight * hydrogen_mass)

    return np.sqrt(8 * np.log(2) * second_moment**2 - 2 * thermal_term)
[docs]
def get_std(self, taper=0):
    """
    Estimate the noise standard deviation of the cube and store it in self.std.

    The std is computed over the 30x30-pixel corner (rows 0:30, cols 0:30)
    of the first and last planes along the first axis (the first and last
    channels). This assumes those regions are emission-free and that no
    primary-beam correction has been applied.

    If ``taper`` is given, both corners are first convolved with the
    Gaussian kernel that brings the beam to a circular FWHM of ``taper``.
    They are then rescaled by the beam-area ratio
    taper**2 / (bmaj * bmin), so the Jy/beam values refer to the tapered beam.
    The cube's own beam (self.bmaj, self.bmin) is NOT modified, since the
    image itself is not tapered.

    Args:
        taper (float): Target beam FWHM in arcsec. 0 (or below 1e-6) or None
            disables tapering.

    Returns:
        None. The result is stored in ``self.std`` and printed.
    """
    im1 = self.image[0, 0:30, 0:30]
    im2 = self.image[-1, 0:30, 0:30]

    # --- Convolution by taper
    if taper is not None and taper < 1e-6:  # 0 taper
        taper = None

    if taper is not None:
        if taper < self.bmaj:
            print("taper is smaller than bmaj=", self.bmaj)
            delta_bmaj = self.pixelscale * FWHM_to_sigma
        else:
            delta_bmaj = np.sqrt(taper ** 2 - self.bmaj ** 2)
        if taper < self.bmin:
            print("taper is smaller than bmin=", self.bmin)
            delta_bmin = self.pixelscale * FWHM_to_sigma
        else:
            delta_bmin = np.sqrt(taper ** 2 - self.bmin ** 2)

        sigma_x = delta_bmin / self.pixelscale * FWHM_to_sigma  # in pixels
        sigma_y = delta_bmaj / self.pixelscale * FWHM_to_sigma  # in pixels
        print("Original beam = ", self.bmaj, self.bmin)
        print("tapper =", delta_bmaj, delta_bmin, self.bpa)
        print("Tapered beam = ", taper, taper)

        beam = Gaussian2DKernel(sigma_x, sigma_y, self.bpa * np.pi / 180)
        # Correct for the beam area, using the ORIGINAL beam. Overwriting
        # self.bmaj/self.bmin first (as before) made this ratio always 1.
        area_ratio = taper ** 2 / (self.bmin * self.bmaj)
        im1 = convolve_fft(im1, beam) * area_ratio
        im2 = convolve_fft(im2, beam) * area_ratio

    # Compute the std deviation in the image; assumes that no primary beam
    # correction has been applied.
    self.std = np.nanstd([im1, im2])
    print("Std=", self.std)
    return
# -- Functions to deal with the synthesized beam.
[docs]
def _beam_area(self):
"""Beam area in arcsec^2"""
return np.pi * self.bmaj * self.bmin / (4.0 * np.log(2.0))
[docs]
def _beam_area_str(self):
    """Beam solid angle in steradian.

    The arcsec^2 beam area is converted to sr using ``arcsec`` (one arcsecond
    in radians).
    """
    return self._beam_area() * arcsec ** 2
[docs]
def _pixel_area(self):
"""Pixel area in arcsec^2"""
return self.pixelscale ** 2
[docs]
def _beam_area_pix(self):
"""Beam area in pix^2."""
return self._beam_area() / self._pixel_area()
@property
def beam(self):
    """Beam as a tuple: (major FWHM ["], minor FWHM ["], position angle [deg])."""
    major, minor, angle = self.bmaj, self.bmin, self.bpa
    return major, minor, angle
[docs]
def _Jybeam_to_Tb(self, im, RJ=False):
"""Convert flux converted from Jy/beam to K using full Planck law."""
im2 = np.nan_to_num(im)
nu = self.restfreq
if RJ:
Tb = abs(im2) * sc.c**2 / (2*nu**2*sc.k * 1e26 * self._beam_area_str())
else:
exp_m1 = 1e26 * self._beam_area_str() * 2.0 * sc.h * nu ** 3 / (sc.c ** 2 * abs(im2))
hnu_kT = np.log1p(exp_m1 + 1e-10)
Tb = sc.h * nu / (sc.k * hnu_kT)
return np.ma.where(im2 >= 0.0, Tb, -Tb)
[docs]
def make_cut(self, x0, y0, x1, y1, z=None, num=None):
    """
    Extract the values of image ``z`` along the line from (x0, y0) to (x1, y1).

    Coordinates are pixel coordinates: x is the column index and y the row
    index of ``z``.

    Args:
        x0, y0, x1, y1 (float): Pixel coordinates of the end points.
        z (numpy.ndarray): 2D image to sample; defaults to ``self.image``.
        num (int): If given, sample ``num`` equally spaced points with cubic
            spline interpolation. Otherwise sample int(line length) points
            and read the pixel at the coordinates truncated to int.

    Returns:
        tuple: (x, y, zi), the sample coordinates and the sampled values.
    """
    if z is None:
        z = self.image
    if num is not None:
        # Extract the values along the line using cubic interpolation
        # (map_coordinates default order=3); coordinates are given as (row, col).
        x, y = np.linspace(x0, x1, num), np.linspace(y0, y1, num)
        zi = ndimage.map_coordinates(z, np.vstack((y, x)))
    else:
        # Extract the values along the line at (roughly) the pixel spacing
        length = int(np.hypot(x1 - x0, y1 - y0))
        x, y = np.linspace(x0, x1, length), np.linspace(y0, y1, length)
        # np.int was removed in NumPy 1.24; the builtin int is the equivalent.
        zi = z[y.astype(int), x.astype(int)]
    return x, y, zi
[docs]
def explore(self, num=None):
    """
    Quickly explore the cube with a two-panel preview figure.

    - Left panel: line profile (``self.plot_line``), annotated with the
      object name, the identified line, the rest frequency and the spw tag
      parsed from the filename.
    - Right panel: peak brightness (moment 8, as brightness temperature).

    The line is identified by matching ``self.restfreq`` (converted to GHz)
    to the nearest entry of ``line_list``, within a relative tolerance of
    1e-6.

    Side effects: sets ``self.line``, ``self.mol`` and ``self.object``.

    Args:
        num (int): The figure number.

    Raises:
        KeyError: If no entry of ``line_list`` lies within the tolerance.
    """
    # Quick preview function of a cube
    filename = self.filename
    k = filename.find('spw')
    # 'spwNN' tag; empty when the filename has no 'spw' substring
    spw = filename[k:k+5] if k >= 0 else ""
    freq = self.restfreq / 1e9

    # Store line, molecule and object names. An exact float key lookup
    # fails for any slightly different rest frequency, so match the nearest
    # catalogue entry within a relative tolerance instead.
    nearest = min(line_list, key=lambda f: abs(f - freq))
    if abs(nearest - freq) > 1e-6 * abs(freq):
        raise KeyError(f"No line in line_list within 1e-6 of {freq} GHz")
    self.line = line_list[nearest]
    self.mol = self.line.split(" ", 1)[0]
    self.object = self.header['OBJECT']

    # Make the preview plot
    _, axes = plt.subplots(nrows=1, ncols=2, figsize=(20, 6), num=num, clear=False)
    plt.subplots_adjust(hspace=0.01, wspace=0.01)

    # Line profile
    ax = axes[0]
    self.plot_line(ax=ax)
    label = f"$\\nu=${freq:<10.6f}Ghz"
    annotations = [
        (0.02, 0.95, self.object, 'left'),
        (0.02, 0.9, self.line, 'left'),
        (0.98, 0.95, label, 'right'),
        (0.98, 0.9, spw, 'right'),
    ]
    for x_text, y_text, text, align in annotations:
        ax.text(x_text, y_text, text,
                horizontalalignment=align, size=12,
                color="black",
                transform=ax.transAxes,
                )

    # Moment 8
    ax = axes[1]
    self.plot(ax=ax, moment=8, Tb=True)
    ax.text(0.02, 0.95, "Peak brightness",
            horizontalalignment='left', size=12,
            color="white",
            transform=ax.transAxes,
            )
    return
# The Splatalogue astropy API is currently broken; the intended query was:
#Splatalogue.query_lines(f*(1-1e-6)*u.GHz, f*(1+1e-6)*u.GHz)
# so the line list below is maintained by hand for now.
# Rest frequency in GHz -> transition label, looked up by Cube.explore
# using restfreq / 1e9.
line_list = {
    265.886431: "HCN v=0 J=3-2",
    265.759438: "c-HCCCH v=0 4(4,1)-3(3,0)",
    265.289650: "CH3OH v t=0 6(1,5)-5(2,3)",
    264.2701294: "H2CO 10(1,9)-10(1,10)",
    262.2569057: "SO2 v=0 11(3,9)-11(2,10)",
    262.064986: "CCH v=0 N=3-2, J=5/2-3/2, F=3-2",
    262.004260: "CCH v=0 N=3-2, J=7/2-5/2, F=4-3",
    261.843721: "SO 3Σ v=0 7(6)-6(5)",
    251.900495: "CH3OH v t=0 4(3,2)-4(2,3) +-",
    251.825770: "SO 3Σ v=0 5(6)-4(5)",
    251.2105851: "SO2 v=0 8(3,5)-8(2,6)",
    250.816954: "NO J=5/2-3/2, Ω=1/2-, F3/2-1/2",
    330.5879653: "13CO J=3-2",
    345.7959899: "12CO J=3-2",
    # 249.000000: "Continuum",
}
[docs]
def add_colorbar(mappable, shift=None, width=0.05, ax=None, trim_left=0, trim_right=0, side="right",**kwargs):
    """
    Add a color bar next to one or several axes without shrinking them.

    A new axes is created in figure coordinates, next to the bounding box of
    ``ax`` (a single axes, or a list/array of axes), and the color bar is
    drawn in it.

    Args:
        mappable (matplotlib.cm.ScalarMappable): The mappable the color bar
            describes.
        shift (float): Gap between the axes and the color bar. It is a
            fraction of the bounding-box height for side="top" (default 0.2)
            or of its width for side="right" (default 0.05).
        width (float): Thickness of the color bar, as a fraction of the
            bounding-box height ("top") or width ("right").
        ax (matplotlib.axes.Axes or sequence/array of Axes): The axes to
            attach to. Defaults to ``mappable.axes`` or, failing that, the
            current axes.
        trim_left (float): Amount removed from the left end of a "top" bar,
            in figure coordinates.
        trim_right (float): Amount removed from the right end of a "top" bar,
            in figure coordinates.
        side (str): "top" (horizontal bar) or "right" (vertical bar).
        **kwargs: Forwarded to ``Figure.colorbar``.

    Returns:
        matplotlib.colorbar.Colorbar: The created color bar.
    """
    # Creates a color bar that does not shrink the main plot or panel.
    if ax is None:
        # Not every mappable has an `axes` attribute, and it may be None.
        ax = getattr(mappable, "axes", None)
        if ax is None:
            ax = plt.gca()

    # Collect the positions of all target axes (single axes, list, or N-d array)
    if hasattr(ax, "get_position"):
        axes_list = [ax]
    else:
        axes_list = list(np.asarray(ax, dtype=object).ravel())
    fig = axes_list[0].figure
    # Each row: [x0, y0, x1, y1] in figure coordinates
    p = np.array([a.get_position().get_points().flatten() for a in axes_list])

    xmin = np.amin(p[:, 0]) ; xmax = np.amax(p[:, 2]) ; dx = xmax - xmin
    ymin = np.amin(p[:, 1]) ; ymax = np.amax(p[:, 3]) ; dy = ymax - ymin

    if side=="top":
        if shift is None:
            shift = 0.2
        cax = fig.add_axes([xmin + trim_left, ymax + shift * dy, dx - trim_left - trim_right, width * dy])
        cax.xaxis.set_ticks_position('top')
        return fig.colorbar(mappable, cax=cax, orientation="horizontal",**kwargs)
    elif side=="right":
        if shift is None:
            shift = 0.05
        cax = fig.add_axes([xmax + shift*dx, ymin, width * dx, dy])
        cax.xaxis.set_ticks_position('top')
        return fig.colorbar(mappable, cax=cax, orientation="vertical",**kwargs)