# Source code for ngcasa.deconvolution.deconvolve_point_clean

#  CASA Next Generation Infrastructure
#  Copyright (C) 2021 AUI, Inc. Washington DC, USA
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
"""
this module will be included in the api
"""

import numpy as np
from numba import jit
import numba as nb
import itertools
from copy import deepcopy
import time

def _ndim_list(shape):
    return [_ndim_list(shape[1:]) if len(shape) > 1 else None for _ in range(shape[0])]

def deconvolve_point_clean(img_xds, deconvolve_parms, sel_parms):
    """
    .. todo::
        This function is not yet implemented

    An iterative solver to construct a model from an observed image(set) and psf(set).

    Sky model : Point source
    Algorithm : CLEAN (a greedy algorithm for chi-square minimization)
    Options   : Hogbom, Clark
    Input     : Requires an input cube (mfs is a cube with nchan=1)
    Output    : Cube model image

    Only a lazy dask graph is built here; nothing is computed until the returned
    dataset's model/residual variables are evaluated.

    Parameters
    ----------
    img_xds : xarray.core.dataset.Dataset
        Input image dataset. The image and psf data variables (named by
        ``sel_parms['data_group_in']['image']`` and ``['psf']``) are read as 5-D
        dask-backed arrays with dimension order (l, m, time, chan, pol).
    deconvolve_parms : dict
        Deconvolution parameters, deep-copied and forwarded to ``_clean_wrap``
        (which reads ``'threshold'``, ``'n_iter'``, ``'gain'``, ``'decon_kernel'``).
    sel_parms : dict
        ``'data_group_in'`` maps ``'image'``/``'psf'`` to input variable names and
        ``'data_group_out'`` maps ``'model'``/``'residual'`` to output variable names.

    Returns
    -------
    img_dataset : xarray.core.dataset.Dataset
        A deep copy of ``img_xds`` with lazily evaluated model and residual data
        variables added (dims ``['l', 'm', 'time', 'chan', 'pol']``).
    """
    print('######################### Start deconvolve_point_clean #########################')
    # Imported locally (as in the rest of the package) so that importing this module
    # does not require dask/xarray.
    import copy

    import dask
    import dask.array as da
    import xarray as xr

    _img_xds = img_xds.copy(deep=True)
    _sel_parms = copy.deepcopy(sel_parms)
    # Deep copy so later mutation of the caller's dict cannot leak into the lazy graph.
    _deconvolve_parms = copy.deepcopy(deconvolve_parms)

    # TODO: parameter checking / defaults (e.g. cngi's _check_sel_parms) is not wired in yet.

    image_da = _img_xds[_sel_parms['data_group_in']['image']]
    psf_da = _img_xds[_sel_parms['data_group_in']['psf']]

    n_chunks_in_each_dim = image_da.data.numblocks
    chunk_sizes = image_da.chunks
    image_shape = image_da.shape

    # Dimension order is (l, m, time, chan, pol). Only chunking along time and chan is
    # honoured: each task receives every l, m and pol block of one (time, chan) chunk,
    # so the block grid has a single entry along l, m and pol.
    n_time_chunks = n_chunks_in_each_dim[2]
    n_chan_chunks = n_chunks_in_each_dim[3]
    block_grid_shape = (1, 1, n_time_chunks, n_chan_chunks, 1)
    model_list = _ndim_list(block_grid_shape)
    residual_list = _ndim_list(block_grid_shape)

    delayed_parms = dask.delayed(_deconvolve_parms)
    for c_time, c_chan in itertools.product(range(n_time_chunks), range(n_chan_chunks)):
        deconv_graph_node = dask.delayed(_clean_wrap)(
            image_da.data.partitions[:, :, c_time, c_chan, :],
            psf_da.data.partitions[:, :, c_time, c_chan, :],
            delayed_parms)
        block_shape = (image_shape[0], image_shape[1],
                       chunk_sizes[2][c_time], chunk_sizes[3][c_chan], image_shape[4])
        model_list[0][0][c_time][c_chan][0] = da.from_delayed(
            deconv_graph_node[0], block_shape, dtype=np.double)
        residual_list[0][0][c_time][c_chan][0] = da.from_delayed(
            deconv_graph_node[1], block_shape, dtype=np.double)

    model = da.block(model_list)
    residual = da.block(residual_list)

    dims = ['l', 'm', 'time', 'chan', 'pol']
    _img_xds[_sel_parms['data_group_out']['residual']] = xr.DataArray(residual, dims=dims)
    _img_xds[_sel_parms['data_group_out']['model']] = xr.DataArray(model, dims=dims)
    print('######################### Created graph for deconvolve_point_clean #########################')
    return _img_xds
def _clean_wrap(dirty, psf, deconvolve_parms):
    """
    Perform Hogbom CLEAN on every (time, chan, pol) plane of ``dirty`` given ``psf``.

    Parameters
    ----------
    dirty : np.ndarray
        Dirty image, 5-D with axes (l, m, time, chan, pol). It is not modified.
    psf : np.ndarray
        Point Spread Function, 5-D, indexed as ``psf[x, y, i_time, i_chan, i_pol]``
        with the same (time, chan, pol) extents as ``dirty``. Its l/m extents may
        differ from ``dirty``; the PSF centre is taken as ``psf.shape // 2``.
        Assumed to be normalised to 1 at its centre -- TODO confirm with the PSF producer.
    deconvolve_parms : dict
        ``'gain'`` : float
            Loop gain: fraction of the current peak moved into the model per iteration.
        ``'threshold'`` : float
            Relative stopping threshold: a plane stops being cleaned once its absolute
            residual peak is no longer greater than ``threshold`` times that plane's
            initial absolute peak.
        ``'n_iter'`` : int
            Maximum number of minor-cycle iterations per plane.
        ``'decon_kernel'`` : int
            ``0`` selects the numba-compiled loop kernel; any other value selects the
            NumPy-slicing kernel.

    Returns
    -------
    np.ndarray float64
        clean-component model image, same shape as ``dirty``
    np.ndarray float64
        residual image, same shape as ``dirty``
    """
    # Work on a float64 copy: keeps the caller's dirty image intact and makes the
    # residual dtype match the float64 model (the dask wrapper declares np.double).
    residual = np.array(dirty, dtype=np.float64)
    model = np.zeros(residual.shape)

    threshold = deconvolve_parms['threshold']
    niter = deconvolve_parms['n_iter']
    gain = deconvolve_parms['gain']

    start_time = time.time()
    if deconvolve_parms['decon_kernel'] == 0:
        _clean_jit(residual, model, psf, gain, threshold, niter)
    else:
        _clean_jit_vec(residual, model, psf, gain, threshold, niter)
    print('Time ', time.time() - start_time)
    return model, residual


@jit(nopython=True, nogil=True, cache=True)
def _clean_jit(residual, model, psf, gain, threshold, niter):
    """Hogbom CLEAN kernel using explicit loops (numba); updates ``residual`` and ``model`` in place.

    Each (time, chan, pol) plane is cleaned independently. Every iteration finds the
    pixel with the largest absolute residual, adds ``gain * peak`` to the model there,
    and subtracts ``gain * peak * psf`` (PSF centre aligned with that pixel) from the
    residual.
    """
    peak_pos = np.zeros((2,), dtype=nb.u4)
    psf_shape = np.array(psf.shape)
    psf_center = psf_shape // 2
    res_shape = np.array(residual.shape)

    for i_time in range(res_shape[2]):
        for i_chan in range(res_shape[3]):
            for i_pol in range(res_shape[4]):
                i = 0
                _abs_image_peaks(residual[:, :, i_time, i_chan, i_pol], peak_pos)
                peak = residual[peak_pos[0], peak_pos[1], i_time, i_chan, i_pol]
                if np.isnan(peak) or (peak == 0.0):
                    # Nothing to clean (all-zero / NaN plane): skip the minor cycle.
                    i = niter
                peak_abs = np.abs(peak)
                # The stopping threshold is relative to this plane's initial peak.
                scaled_threshold = threshold * peak_abs

                while peak_abs > scaled_threshold and i < niter:
                    component = gain * peak
                    model[peak_pos[0], peak_pos[1], i_time, i_chan, i_pol] += component

                    # Residual pixels covered by the PSF when its centre sits on the
                    # peak, clipped to the image. End indices are exclusive.
                    res_start_x = peak_pos[0] - psf_center[0]
                    if res_start_x < 0:
                        res_start_x = 0
                    res_start_y = peak_pos[1] - psf_center[1]
                    if res_start_y < 0:
                        res_start_y = 0
                    res_end_x = peak_pos[0] + (psf_shape[0] - psf_center[0])
                    if res_end_x > res_shape[0]:
                        res_end_x = res_shape[0]
                    res_end_y = peak_pos[1] + (psf_shape[1] - psf_center[1])
                    if res_end_y > res_shape[1]:
                        res_end_y = res_shape[1]

                    # Subtract the PSF scaled by the component flux (gain * peak), so the
                    # residual stays consistent with what was added to the model.
                    for i_x in range(res_start_x, res_end_x):
                        psf_i_x = psf_center[0] + (i_x - peak_pos[0])
                        for i_y in range(res_start_y, res_end_y):
                            psf_i_y = psf_center[1] + (i_y - peak_pos[1])
                            residual[i_x, i_y, i_time, i_chan, i_pol] -= (
                                component * psf[psf_i_x, psf_i_y, i_time, i_chan, i_pol])

                    _abs_image_peaks(residual[:, :, i_time, i_chan, i_pol], peak_pos)
                    peak = residual[peak_pos[0], peak_pos[1], i_time, i_chan, i_pol]
                    if np.isnan(peak) or (peak == 0.0):
                        i = niter
                    peak_abs = np.abs(peak)
                    i += 1


def _nan_abs_peak_pos(plane):
    """Return the (x, y) index of the largest ``|plane|`` value ignoring NaNs, or None if there is none.

    ``None`` is returned for an all-NaN (or empty) plane, where ``np.nanargmax`` would raise.
    Ties resolve to the first occurrence in row-major order.
    """
    abs_plane = np.abs(plane)
    if np.isnan(abs_plane).all():
        return None
    return np.unravel_index(np.nanargmax(abs_plane), abs_plane.shape)


def _clean_jit_vec(residual, model, psf, gain, threshold, niter):
    """Hogbom CLEAN kernel using NumPy slicing (not numba-compiled); same contract as ``_clean_jit``.

    Updates ``residual`` and ``model`` in place. All-NaN planes are left untouched,
    matching the loop kernel, instead of letting ``np.nanargmax`` raise.
    """
    psf_shape = np.array(psf.shape)
    psf_center = psf_shape // 2
    res_shape = np.array(residual.shape)

    for i_time in range(res_shape[2]):
        for i_chan in range(res_shape[3]):
            for i_pol in range(res_shape[4]):
                # Basic slicing yields a view, so in-place updates on `plane` reach `residual`.
                plane = residual[:, :, i_time, i_chan, i_pol]
                i = 0
                peak_pos = _nan_abs_peak_pos(plane)
                peak = np.nan if peak_pos is None else plane[peak_pos]
                if np.isnan(peak) or (peak == 0.0):
                    i = niter
                peak_abs = np.abs(peak)
                # The stopping threshold is relative to this plane's initial peak.
                scaled_threshold = threshold * peak_abs

                while peak_abs > scaled_threshold and i < niter:
                    peak_x, peak_y = peak_pos
                    component = gain * peak
                    model[peak_x, peak_y, i_time, i_chan, i_pol] += component

                    # Overlap of the peak-centred PSF with the image (exclusive ends).
                    res_start_x = max(peak_x - psf_center[0], 0)
                    res_start_y = max(peak_y - psf_center[1], 0)
                    res_end_x = min(peak_x + (psf_shape[0] - psf_center[0]), res_shape[0])
                    res_end_y = min(peak_y + (psf_shape[1] - psf_center[1]), res_shape[1])

                    psf_start_x = psf_center[0] - (peak_x - res_start_x)
                    psf_start_y = psf_center[1] - (peak_y - res_start_y)
                    psf_end_x = psf_center[0] + (res_end_x - peak_x)
                    psf_end_y = psf_center[1] + (res_end_y - peak_y)

                    # Subtract the PSF scaled by the component flux (gain * peak).
                    plane[res_start_x:res_end_x, res_start_y:res_end_y] -= (
                        component
                        * psf[psf_start_x:psf_end_x, psf_start_y:psf_end_y, i_time, i_chan, i_pol])

                    peak_pos = _nan_abs_peak_pos(plane)
                    peak = np.nan if peak_pos is None else plane[peak_pos]
                    if np.isnan(peak) or (peak == 0.0):
                        i = niter
                    peak_abs = np.abs(peak)
                    i += 1


@jit(nopython=True, nogil=True, cache=True)
def _abs_image_peaks(image, peak_pos):
    """Store in ``peak_pos`` the (x, y) index of the largest absolute value of the 2-D ``image``.

    Ties resolve to the first occurrence in row-major order. NaNs never compare greater,
    so they are ignored. If no element has a non-zero absolute value (all zeros / NaNs),
    ``peak_pos`` is left unchanged; callers re-read the value at ``peak_pos`` and stop
    when it is zero or NaN.
    """
    peak_abs = 0.0
    for i_x in range(image.shape[0]):
        for i_y in range(image.shape[1]):
            value_abs = np.abs(image[i_x, i_y])
            if value_abs > peak_abs:
                peak_abs = value_abs
                peak_pos[0] = i_x
                peak_pos[1] = i_y