Source code for ngcasa.flagging.summary

#  CASA Next Generation Infrastructure
#  Copyright (C) 2021 AUI, Inc. Washington DC, USA
#  Copyright (C) 2021 European Southern Observatory, ALMA partnership
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#
#  You should have received a copy of the GNU General Public License
#  along with this program.  If not, see <https://www.gnu.org/licenses/>.
import dask
import xarray as xr
from ._flagging_utils._summary_utils import _pol_id_to_corr_type_name, _cast_int


[docs]def summary(mxds, xds_idx, flag_varname='FLAG'): """ Produces a summary with counts of flags from a flag variable of an xarray dataset.The summary is returned as a dictionary with the same structure as the dictionary returned by CASA6/flagdata in summary mode. It includes total counts and partial counts broken down by array, observation, correlation, field, scan, and antenna. Note: per-SPW counts are not included in this prototype version. Handling of multi-SPWs operations is subject to changes depending on interface design at (or around) the ngCASA layer. Parameters ---------- mxds: xarray.core.dataset.Dataset Dataset of xarray datasets xds_idx: int Index of the xarray datasets to get counts from (index in the xds'i' attributes of mxds). This is an oversimplification (early prototyping) flag_varname: str Name of the flag variable to summarize TBD - additional counts like 'spwchan', 'spwcorr', etc. TBD - params to filter percentages in report (minrel, minabs, etc.) TBD - params related to other counts or scores (possibly other functions) Returns: ------- dict Dictionary with counts of flags set and total flags, for the whole dataset also broken down (grouped by) by multiple criteria (scan, field, antenna, correlation, etc.) """ # mxds is required to grab meta-information (names of fields, antennas and # correlations). But anyway for now, working only on 1 SPW (given by index) xds = mxds.attrs['xds' + '{}'.format(xds_idx)] return _summary_groupby(mxds, xds, flag_varname)
def _summary_groupby(mxds, xds, flag_varname): """ An implementation of flagging.summary based mostly on xarray groupby Assumes: xds['presence_baseline'] = xds.DATA.notnull().any(['chan', 'pol']) To use presence_baseline.sum() as count of 'total' """ # Sum up chan dimension, which is not relevant for any of the following: # FIELD, OBS, ANT, etc. flag_var = xds[flag_varname].sum(['chan']).where(xds.presence_baseline > 0, other=0) pres_var = xds['presence_baseline'] # For joint grouping of FLAG data var and presence_baseline # From FLAG vars we'll get the 'flagged' counters. From presence_baseline we # get the 'total' counters. count_xds = xr.Dataset({'flag_var': flag_var, 'presence_baseline': pres_var}) result = {} _count_pol_and_totals(result, xds, count_xds) count_xds['flag_var'] = count_xds.flag_var.sum(['pol']) _count_array_scan_obs(result, xds, count_xds) _count_field(result, xds, count_xds, mxds.fields.values) _count_antenna(result, xds, count_xds, mxds.ANTENNA.NAME.values) res_comp = dask.compute(result) return _cast_int(res_comp[0]) def _count_pol_and_totals(result, xds, count_xds): # Count by pol/correlation + grand totals grand_total = 0 grand_flagged = 0 result['correlation'] = {} flagged_corr = count_xds.flag_var.sum(['time', 'baseline']) # total_corr = count_xds.presence_baseline.sum(['time', 'baseline']) # sum_presence = count_xds.presence_baseline.sum(['time', 'baseline']) total_corr = xds.num_chan * count_xds.presence_baseline.sum() for pol, flagged in zip(flagged_corr.coords['pol'].values, flagged_corr.values): corr_str = _pol_id_to_corr_type_name(pol) result['correlation']['{}'.format(corr_str)] = {'flagged': flagged, 'total': total_corr} grand_flagged += flagged grand_total += total_corr result['total'] = grand_total result['flagged'] = grand_flagged def _count_array_scan_obs(result, xds, count_xds): """ Adds into result counts by data variables that can be done with the same straightforward grouping (ARRAY_ID, SCAN_NUMBER, OBSERVATION_ID) """ grouping_names_vars = [('array', xds.ARRAY_ID), ('scan', xds.SCAN_NUMBER), ('observation', xds.OBSERVATION_ID)] nchan, ncorr = xds.num_chan, xds.num_corr for name, data_var in grouping_names_vars: result[name] = {} groups = count_xds.groupby(data_var) for grp_lbl, grp_ds in groups: if grp_lbl < 0: continue flagged = grp_ds.flag_var.sum() # ? total = nchan * grp_ds.count() <- that would be total non-nan # only 'after last appplyflags' total = grp_ds.presence_baseline.sum() * nchan * ncorr group_str = '{}'.format(grp_lbl) result[name][group_str] = {'flagged': flagged, 'total': total} def _count_field(result, xds, count_xds, field_names): """ Adds into result counts by field. It needs to count by field name, but grouping s done by FIELD_ID. Multiple field IDs can map to a same name (object, target) """ result['field'] = {} nchan, ncorr = xds.num_chan, xds.num_corr groups = count_xds.groupby(xds.FIELD_ID) for grp_idx, grp_ds in groups: if grp_idx < 0: continue flagged = grp_ds.flag_var.sum() # ? total = nchan * grp_ds.count() <- total non-nan 'after last apply' total = grp_ds.presence_baseline.sum() * nchan * ncorr field_str = '{}'.format(field_names[grp_idx]) if field_str in result['field']: result['field'][field_str]['flagged'] += flagged result['field'][field_str]['total'] += total else: result['field'][field_str] = {'flagged': flagged, 'total': total} def _count_antenna(result, xds, count_xds, ant_names): """ Adds into result counts by antenna. An antenna name/index can be ANTENNA1 or ANTENNA2 """ result['antenna'] = {} nchan, ncorr = xds.num_chan, xds.num_corr groups_ant1 = count_xds.groupby(xds.ANTENNA1) groups_ant2 = count_xds.groupby(xds.ANTENNA2) for groups in [groups_ant1, groups_ant2]: for grp_idx, grp_ds in groups: if grp_idx < 0: continue flagged = grp_ds.flag_var.sum() # ? total = nchan * grp_ds.flag_var.count() # <- total non-nan 'after last applyflags' total = grp_ds.presence_baseline.sum() * nchan * ncorr ant_str = '{}'.format(ant_names[grp_idx]) if ant_str in result['antenna']: result['antenna'][ant_str]['flagged'] += flagged result['antenna'][ant_str]['total'] += total else: result['antenna'][ant_str] = {'flagged': flagged, 'total': total}