"""
Functions for cross-matching AGASC and Gaia
"""
import tqdm
import itertools
from astropy.table import Table, join, MaskedColumn, vstack
import numpy as np
from agasc_gaia import utils, datasets
from astropy.coordinates import SkyCoord
from astropy import units as u
from pathlib import Path
import agasc_gaia.gaia_queries as gaia_queries
import scipy
import networkx as nx
from agasc_gaia import config
# Column layout of a cross-match table, as (column name, dtype) pairs.
# The list is grouped into: Gaia photometry/astrometry, quantities produced by the direct
# AGASC-Gaia cross-match, and quantities produced by the Tycho2/GSC2.3 cross-match.
# NOTE(review): nothing in the visible part of this module references XMATCH_DTYPE;
# presumably it is applied where the tables are stored/cached -- confirm before changing it.
XMATCH_DTYPE = [
    ("agasc_id", np.int64),
    ("gaia_id", np.int64),
    ("ave_gaia_mag", np.float16),  # this could be removed and calculated downstream
    ("has_mag", np.int16),
    ("g_mag", np.float16),
    ("g_mag_error", np.float16),
    ("rp_mag", np.float16),
    ("rp_mag_error", np.float16),
    ("bp_mag", np.float16),
    ("bp_mag_error", np.float16),
    ("ra", np.float64),
    ("dec", np.float64),
    ("ra_error", np.float64),
    ("dec_error", np.float64),
    ("epoch", np.float16),
    ("pm_dec", np.float16),
    ("pm_ra", np.float16),
    ("phot_variable_flag", bool),
    ("mad_mag_g_fov", np.float16),
    # the following are not used downstream, are used only in the direct cross-match
    ("d2d", np.float16),
    ("mag_pred", np.float16),
    ("d_mag", np.float16),
    ("p_match", np.float32),
    ("log_p_match", np.float16),
    ("p_relative", np.float32),
    ("log_p_relative", np.float16),
    ("p_value", np.float32),
    ("log_p_value", np.float16),
    ("best_match", bool),
    # the following are not used downstream, were used only using the Tycho2/GSC2.3 cross-match
    # flag to denote the brightest star among the match candidates
    ("brightest_neighbor", bool),
    # the magnitude of the brightest match candidate
    ("min_neighbor_mag", np.float16),
    # the ratio of the intensities of the brightest match candidate and each star
    ("neighbor_mag_weight", np.float16),
    # number of potentially unresolved neighbors that significantly contribute to the magnitude
    ("n_mag_neighbors", int),
    # the following come from the agasc summary, so they are added later if needed
    # ('x_val_sample', np.int16),
    # ('mag_aca', np.float16),
    # ('mag', np.float16),
    # ('tycho_id', '<U12'),
    # ('tycho_tdsc_id', '<U12'),
    # ('gsc2.3', '<U10'),
    # ('guide', bool),
    # ('acq', bool),
    # ('mag_aca_err', np.float16),
    # ('mag_aca_obs', np.float16),      # used when fitting
    # ('mag_aca_err_obs', np.float16),  # used when fitting
    # ('mag_catid', np.float16),
    # ('mag_band', np.float16),
]
"""
The result of the following query was saved in `data/agasc_gsc23_gaiadr3-result.fits.gz`:
```
SELECT
a.agasc_id, a.ra as ra_agasc, a.dec as dec_agasc, a.epoch as epoch_agasc, a.pm as pm_agasc,
g.source_id as gaia_id, g.ra, g.dec, g.pmra as pm_ra, g.pmdec as pm_dec, g.ref_epoch,
g.random_index,
g.phot_g_mean_mag as g_mag, g.phot_bp_mean_mag as bp_mag, g.phot_rp_mean_mag as rp_mag,
1./phot_g_mean_flux_over_error as g_mag_error,
1./phot_bp_mean_flux_over_error as bp_mag_error,
1./phot_rp_mean_flux_over_error as rp_mag_error,
g.phot_variable_flag, g.phot_proc_mode, g.non_single_star,
v.range_mag_g_fov, v.std_dev_mag_g_fov, v.mad_mag_g_fov,
bn.original_ext_source_id as gsc2_3, bn.angular_distance, bn.number_of_neighbours
FROM agasc AS a
JOIN gaiadr3.gsc23_best_neighbour as bn ON a.gsc2_3 = bn.original_ext_source_id
JOIN gaiadr3.gaia_source as g ON g.source_id = bn.source_id
LEFT JOIN gaiadr3.vari_summary as v on g.source_id = v.source_id
```
The result of the following query was saved in `data/agasc_tycho_gaiadr3-result.fits.gz`
(note the first left join, to potentially add TDSC data):
```
SELECT
a.agasc_id, a.ra as ra_agasc, a.dec as dec_agasc, a.epoch as epoch_agasc, a.pm as pm_agasc,
g.source_id as gaia_id, g.ra, g.dec, g.pmra as pm_ra, g.pmdec as pm_dec, g.ref_epoch,
g.random_index,
g.phot_g_mean_mag as g_mag, g.phot_bp_mean_mag as bp_mag, g.phot_rp_mean_mag as rp_mag,
1./phot_g_mean_flux_over_error as g_mag_error,
1./phot_bp_mean_flux_over_error as bp_mag_error,
1./phot_rp_mean_flux_over_error as rp_mag_error,
g.phot_variable_flag, g.phot_proc_mode, g.non_single_star,
v.range_mag_g_fov, v.std_dev_mag_g_fov, v.mad_mag_g_fov,
bn.original_ext_source_id as tycho_id, bn.angular_distance, bn.number_of_neighbours
FROM
gaiadr3.tycho2tdsc_merge_best_neighbour as bn
LEFT JOIN agasc AS a ON a.tycho_id = bn.original_ext_source_id
JOIN gaiadr3.gaia_source as g ON g.source_id = bn.source_id
LEFT JOIN gaiadr3.vari_summary as v on g.source_id = v.source_id
```
"""
@utils.cached(name="agasc-gsc-tycho-gaia-x-match-all")
def get_agasc_tycho_gsc_gaia_x_match_all(
    gsc23="agasc_gsc23_gaiadr3-result",
    tycho2="agasc_tycho2tdsc_gaiadr3-result",
):
    """
    Build the indirect AGASC-Gaia cross-match that goes through GSC 2.3 and Tycho-2.

    Two pre-computed query results are read (their paths are looked up in ``config.FILES``),
    each one is inner-joined with the AGASC summary (restricted to rows with ``class == 0``),
    and the two are stacked. Rows coming from the Tycho-2 table take precedence: GSC 2.3 rows
    are only kept for AGASC IDs that do not appear in the joined Tycho-2 table.
    The stacked table is then passed through ``do_before``, ``do_neighbors`` and ``do_after``.

    Parameters
    ----------
    gsc23 : str
        Key into ``config.FILES`` for the AGASC/GSC2.3/Gaia query result.
    tycho2 : str
        Key into ``config.FILES`` for the AGASC/Tycho-2/Gaia query result.

    Returns
    -------
    astropy.table.Table
        All match candidates (an AGASC ID can appear in more than one row).
    """
    agasc_gsc_gaia_file = Path(config.FILES[gsc23])
    agasc_tycho2_gaia_file = Path(config.FILES[tycho2])
    # both input files must have been produced beforehand
    assert agasc_gsc_gaia_file.exists()
    assert agasc_tycho2_gaia_file.exists()
    agasc_all = datasets.get_agasc_summary()
    # only AGASC entries with class == 0 are considered
    agasc_all = agasc_all[agasc_all["class"] == 0]
    # GSC 2.3
    print("Reading GSC2.3-Gaia data")
    data_gsc = Table.read(agasc_gsc_gaia_file)
    data_gsc.remove_column(
        "gsc2_3"
    )  # gsc2.3 will be added from the summary so it will be duplicated
    data_gsc.rename_columns(["ref_epoch"], ["epoch"])
    data_gsc = join(
        agasc_all[
            [
                "agasc_id",
                "tycho_id",
                "tycho_tdsc_id",
                "in_tdsc",
                "gsc2.3",
                "mag_diff_gsc2",
                "separation_gsc2",
            ]
        ],
        data_gsc,
        keys=["agasc_id"],
        join_type="inner",
        table_names=["agasc", "gaia"],
    )
    # keep only GSC 2.3 associations below the magnitude-difference and separation thresholds
    data_gsc = data_gsc[
        (data_gsc["mag_diff_gsc2"] < 4) & (data_gsc["separation_gsc2"] < 4)
    ]
    # Tycho-2
    print("Reading Tycho2-Gaia data")
    data_tycho_update = Table.read(agasc_tycho2_gaia_file)
    data_tycho_update.rename_columns(["ref_epoch"], ["epoch"])
    # these should have been removed, so I am just checking
    bad = [
        "2673-3845-1",
        "2673-3845-2",
        "2673-3845-3",
        "6461-1120-1",
        "6461-1120-2",
        "6461-1120-3",
    ]
    assert len(data_tycho_update[np.in1d(data_tycho_update["tycho_id"], bad)]) == 0
    # the ID in the query result is matched against the summary's tycho_tdsc_id column
    data_tycho_update.rename_column("tycho_id", "tycho_tdsc_id")
    print("Joining Tycho2-GSC2.3-Gaia data")
    # rows without a tycho_tdsc_id (summary) or without an agasc_id (query result) are masked
    # and excluded before joining on both keys
    data_tycho_update = join(
        agasc_all[
            [
                "agasc_id",
                "tycho_id",
                "tycho_tdsc_id",
                "in_tdsc",
                "gsc2.3",
                "separation_tdsc",
                "mag_diff_tdsc",
            ]
        ][~agasc_all["tycho_tdsc_id"].mask],
        data_tycho_update[~data_tycho_update["agasc_id"].mask],
        join_type="inner",
        keys=["agasc_id", "tycho_tdsc_id"],
        table_names=["agasc", "gaia"],
    )
    # the union of the two
    # (common columns plus the catalog-specific separation/mag-difference columns;
    # GSC 2.3 rows are dropped when the AGASC ID already has a Tycho-2 row)
    cols = [col for col in data_gsc.colnames if col in data_tycho_update.colnames]
    data = vstack(
        [
            data_tycho_update[cols + ["separation_tdsc", "mag_diff_tdsc"]],
            data_gsc[~np.in1d(data_gsc["agasc_id"], data_tycho_update["agasc_id"])][
                cols + ["mag_diff_gsc2", "separation_gsc2"]
            ],
        ],
        metadata_conflicts="silent",
    )
    data = do_before(data)
    data = do_neighbors(data)
    data = do_after(data)
    return data
@utils.cached(name="agasc-gsc-tycho-gaia-x-match")
def get_agasc_tycho_gsc_gaia_x_match():
    """
    Return the unambiguous subset of the Tycho-2/GSC2.3-based AGASC-Gaia cross-match.

    Starting from ``get_agasc_tycho_gsc_gaia_x_match_all()``, only AGASC IDs that appear in
    exactly one row are kept, and of those only the rows with ``neighbor_mag_weight > 0.1``.

    Returns
    -------
    astropy.table.Table
    """
    all_matches = get_agasc_tycho_gsc_gaia_x_match_all()

    # AGASC IDs that appear more than once are dropped; we do not deal with them. This happens
    # when an AGASC star is matched to more than one Gaia counterpart, which can be the case if
    # the AGASC star is actually several unresolved stars, or if the star is not there at all.
    by_agasc_id = all_matches.group_by("agasc_id")
    boundaries = by_agasc_id.groups.indices
    first_rows = boundaries[:-1]
    group_sizes = boundaries[1:] - boundaries[:-1]
    # the first row of a group of size one is the only row of that group
    single_matches = by_agasc_id[first_rows[group_sizes == 1]]

    # finally, keep only the rows whose magnitude weight exceeds 0.1
    significant = single_matches["neighbor_mag_weight"] > 0.1
    return single_matches[significant]
def d_mag_probability(d_mag):
    """
    Probability density of a magnitude difference between an AGASC star and a Gaia candidate.

    The density is a scaled Student-t distribution with fixed (hard-coded) parameters.

    Parameters
    ----------
    d_mag : float or array-like
        Magnitude difference(s).

    Returns
    -------
    float or np.ndarray
        ``amplitude * t.pdf(d_mag, nu, loc=x0, scale=dx)``, same shape as ``d_mag``.
    """
    # Import the submodule explicitly: a bare `import scipy` does not guarantee that
    # `scipy.stats` is available as an attribute on every SciPy version.
    from scipy import stats

    amplitude = 1.0244744572765188
    nu = 1.9565734848257785
    x0 = -0.005441123244800838
    dx = 0.16857590603333622
    return amplitude * stats.t.pdf(d_mag, nu, loc=x0, scale=dx)
def gaussian_d2d_probability_(d2d, sigma_d2d=1.5):
    """
    Evaluate a zero-centered 1-d Gaussian density at the given angular separation(s).

    Parameters
    ----------
    d2d : float or array-like
        Angular separation(s), in the same units as ``sigma_d2d``.
    sigma_d2d : float
        Standard deviation of the Gaussian. Default: 1.5

    Returns
    -------
    float or np.ndarray
    """
    normalization = np.sqrt(2 * np.pi) * sigma_d2d
    exponent = -0.5 * (d2d / sigma_d2d) ** 2
    return np.exp(exponent) / normalization
def d2d_probability(d2d):
    """
    Probability density of the angular separation between an AGASC star and a Gaia candidate.

    The density is a weighted sum of four zero-centered Gaussians with fixed weights
    (which add up to approximately one) and widths.

    Parameters
    ----------
    d2d : float or array-like
        Angular separation(s).

    Returns
    -------
    float or np.ndarray
    """
    # (weight, sigma) of each Gaussian component
    components = (
        (0.89020772, 0.25),
        (0.09891197, 0.5),
        (0.00791296, 0.8),
        (0.00296736, 1.2),
    )
    weight, sigma = components[0]
    p_d2d = weight * gaussian_d2d_probability_(d2d, sigma_d2d=sigma)
    for weight, sigma in components[1:]:
        p_d2d = p_d2d + weight * gaussian_d2d_probability_(d2d, sigma_d2d=sigma)
    return p_d2d
def agasc_gaia_match_probability(d_mag, d2d):
    """
    Match probability density for a magnitude difference and an angular separation.

    This is the product of ``d_mag_probability(d_mag)`` and ``d2d_probability(d2d)``.
    """
    p_mag = d_mag_probability(d_mag)
    p_d2d = d2d_probability(d2d)
    return p_mag * p_d2d
def agasc_gaia_match_probability_prelim(d_mag, d2d):
    """
    Preliminary match probability density, using a single Gaussian for the separation.

    This is the product of ``d_mag_probability(d_mag)`` and
    ``gaussian_d2d_probability_(d2d)`` (with its default width).
    """
    p_mag = d_mag_probability(d_mag)
    p_d2d = gaussian_d2d_probability_(d2d)
    return p_mag * p_d2d
@utils.cached(name="agasc-gaia-x-match-all")
def get_agasc_gaia_x_match_all(gaia_result="agasc_gaia-result"):
    """
    Build the direct AGASC-Gaia cross-match, flagging the best candidate of each AGASC star.

    The query result is read from the file registered in ``config.FILES``, passed through
    ``do_before`` (which computes, among others, ``log_p_match``), and a boolean ``best_match``
    column is added that is True for the candidate with the smallest ``log_p_match`` (i.e. the
    largest ``p_match``) of each AGASC ID. The table is finally passed through ``do_after``.

    Parameters
    ----------
    gaia_result : str
        Key into ``config.FILES`` for the AGASC/Gaia query result.

    Returns
    -------
    astropy.table.Table
        All match candidates, one row per (AGASC star, Gaia candidate) pair.
    """
    print("Reading AGASC and Gaia data")
    # Wrapped in Path (as done in get_agasc_tycho_gsc_gaia_x_match_all) so that the existence
    # check also works when the configured entry is a plain string.
    agasc_gaia_file = Path(config.FILES[gaia_result])
    assert agasc_gaia_file.exists(), f"File not found: {agasc_gaia_file}"
    agasc_gaia = Table.read(agasc_gaia_file)
    data = do_before(agasc_gaia)
    print("Finding the best matches")
    # match probability
    g = data[["agasc_id", "log_p_match"]].group_by("agasc_id")
    group_starts = g.groups.indices[:-1]
    # NOTE(review): the group boundaries come from `g` but are applied to `data`, so this
    # assumes `data` is already ordered by agasc_id (presumably guaranteed by the join in
    # do_before, which is keyed on agasc_id) -- confirm.
    best_in_group = np.array(
        [
            np.argmin(data["log_p_match"][start:stop])
            for start, stop in np.lib.stride_tricks.sliding_window_view(
                g.groups.indices, 2
            )
        ]
    )
    i_best = group_starts + best_in_group
    data["best_match"] = False
    data["best_match"][i_best] = True
    data = do_after(data)
    return data
@utils.cached(name="agasc-gaia-xmatch-difficult")
def get_agasc_gaia_x_match_difficult():
    """
    Redo cross-match for "difficult" stars.

    Difficult stars are AGASC stars that are matched to the same Gaia star.

    The procedure is to:

    1. Start with the original cross-match.
    2. Find a "collision graph" (sets of AGASC stars with the same Gaia star match).
    3. Recalculate the cross-matches for each of these sets.
    4. Determine if there still are collisions.
    5. If there are, repeat from step 2.

    At most 10 iterations are attempted.

    Returns
    -------
    astropy.table.Table
        Table of corrected matches (only rows of AGASC stars involved in a collision), with
        columns:

        agasc_id, gaia_id, best_match, d2d, d_mag, p_match, p_value, p_relative,
        idx,   # index into the table returned by get_agasc_gaia_x_match_all
        mag, mag_band, mag_catid, pos_catid,
        best_match_{i},  # value of best_match after iteration i (0 is the original)
        latest_pos_cat,  # POS_CATID used to resolve the collision (-1 if never set)
        group  # index of the collision group
    """
    print("get_agasc_gaia_x_match_difficult")
    agasc_summary = datasets.get_agasc_summary()
    # take a subset of the columns to save memory and so the original is not modified
    matches = get_agasc_gaia_x_match_all()[
        "agasc_id",
        "gaia_id",
        "best_match",
        "d2d",
        "d_mag",
        "p_match",
        "p_value",
        "p_relative",
    ]
    # this index will be used later to join this table with subsets
    matches["idx"] = np.arange(len(matches))

    # look up the AGASC summary row of each match (the summary must be sorted by agasc_id)
    i_summary = np.searchsorted(agasc_summary["agasc_id"], matches["agasc_id"])
    assert np.all(agasc_summary["agasc_id"][i_summary] == matches["agasc_id"])
    for name in ("mag", "mag_band", "mag_catid", "pos_catid"):
        matches[name] = agasc_summary[name][i_summary]
    matches["best_match_0"] = matches["best_match"]

    graphs = [get_collision_graph(matches)]
    results = []
    matches["latest_pos_cat"] = -1
    n_iter = 0
    connected_components = []
    for n_iter in range(1, 11):  # will try at most 10 times, but 2 should be enough
        results.append(fix_collisions(matches, nx.compose_all(graphs)))
        latest = results[-1]

        # update best_match in the original table and sanity check
        idx = np.searchsorted(matches["idx"], latest["idx"])
        assert np.all(matches["agasc_id"][idx] == latest["agasc_id"])
        assert np.all(matches["gaia_id"][idx] == latest["gaia_id"])
        best_match = np.array(matches["best_match"])
        best_match[idx] = latest["best_match"]
        matches[f"best_match_{n_iter}"] = best_match
        matches["best_match"] = best_match
        matches["latest_pos_cat"][idx] = latest["latest_pos_cat"]

        # Get the new collision graph.
        # New collisions can happen if we "fix" a collision by changing a cross-match,
        # and now set it to a Gaia star matched to another AGASC star with no previous collision
        new_graph = get_collision_graph(matches)
        connected_components = [
            comp for comp in nx.connected_components(new_graph) if len(comp) > 1
        ]
        if not connected_components:
            # we are good, no more collisions
            break
        # need to fix these new collisions too, so we _add_ them to the previous ones and repeat
        graphs.append(new_graph)

    print(f"get_agasc_gaia_x_match_difficult done in {n_iter} iterations")
    assert len(connected_components) == 0, "collisions remain after the last iteration"

    # Finally, add the `group` column to the result and discard groups with only one star
    graph = nx.compose_all(graphs)
    matches["group"] = get_group(matches, graph)
    difficult = matches[matches["group"] >= 0]
    return difficult
def fix_collisions(table, collision_graph):
    """
    Find the best matches for a collection of "difficult" stars.

    This is called by get_agasc_gaia_x_match_difficult.

    Difficult stars are those that had a collision in the cross-match with Gaia (i.e. were
    cross-matched with the same Gaia star as another AGASC star). The collisions are represented by
    a collision graph. AGASC stars that are not in the collision graph are not considered.

    Parameters
    ----------
    table : astropy.table.Table
        Table of matches, as returned by agasc_gaia.cross_match.get_agasc_gaia_x_match_all(),
        with the additional columns ``idx`` and ``pos_catid``.
    collision_graph : networkx.Graph
        Graph representing the collisions. Nodes are AGASC stars, edges represent collisions.
        The connected components of the graph are stars that should be cross-matched simultaneously.

    Returns
    -------
    astropy.table.Table
        One row per row of ``table`` that belongs to a collision group, with columns
        ``idx``, ``agasc_id``, ``gaia_id``, ``best_match`` (the recomputed flag) and
        ``latest_pos_cat`` (the POS_CATID used for the group). The table is empty if there are
        no collision groups.
    """
    # group table by connected component and keep only the groups with more than one star
    group = get_group(table, collision_graph, column="agasc_id")
    in_collision = group >= 0
    table_copy = table[in_collision].copy()
    table_copy["group"] = group[in_collision]

    result = Table()
    if len(table_copy) == 0:
        # no collisions: return an empty result instead of failing on empty concatenations
        result["idx"] = table_copy["idx"]
        result["agasc_id"] = table_copy["agasc_id"]
        result["gaia_id"] = table_copy["gaia_id"]
        result["best_match"] = np.zeros(0, dtype=bool)
        result["latest_pos_cat"] = np.zeros(0, dtype=int)
        return result

    grouped_table = table_copy.group_by("group")
    starts = grouped_table.groups.indices[:-1]
    stops = grouped_table.groups.indices[1:]

    # for each group, find the latest catalog among the possible matches within 5 arcsec
    # this needs to be a sensible value (see the comment in get_best_matches)
    pos_catid = np.asarray(grouped_table["pos_catid"])
    d2d = np.asarray(grouped_table["d2d"])
    latest_pos_cat = np.array(
        [
            get_latest_pos_cat(pos_catid[i:j][d2d[i:j] < 5])
            for i, j in zip(starts, stops)
        ]
    )

    # process each group separately to find the best matches
    # the following produces a list of indices into the grouped table
    # note that I can't use idx on the original table, because groupby might have sorted it
    idx = np.concatenate(
        [
            start + get_best_matches(grouped_table.groups[i_group], latest_pos_cat[i_group])
            for i_group, start in enumerate(starts)
        ]
    ).astype(int)  # an empty per-group result must not turn the indices into floats
    best_match = np.zeros(len(grouped_table), dtype=bool)
    best_match[idx] = True

    result["idx"] = grouped_table["idx"]
    result["agasc_id"] = grouped_table["agasc_id"]
    result["gaia_id"] = grouped_table["gaia_id"]
    result["best_match"] = best_match
    # every row gets the value of its group
    result["latest_pos_cat"] = np.repeat(latest_pos_cat, stops - starts)
    return result
def get_best_matches(matches, pos_catid):
    """
    Find the best matches from a list of matches, filtering out based on POS_CATID.

    This is called by get_agasc_gaia_x_match_difficult (through fix_collisions) and is intended
    to cleanup collisions, where multiple AGASC stars are matched to the same Gaia star.

    This function recomputes the cross-matches for the given `matches` table, discarding all but
    the AGASC stars with POS_CATID equal to the `pos_catid` argument. Candidates are visited in
    order of decreasing ``p_match`` and a candidate is accepted if neither its AGASC star nor
    its Gaia star has been assigned already.

    Parameters
    ----------
    matches : astropy.Table
        Table of matches, as returned by agasc_gaia.cross_match.get_agasc_gaia_x_match_all()
        Must have columns ``agasc_id``, ``gaia_id``, ``p_match`` and ``pos_catid``.
    pos_catid : int
        POS_CATID to use for filtering. Only matches with this POS_CATID will be considered.

    Returns
    -------
    np.array : indices (integers, into `matches`) of the best matches
    """
    # Discarding all but the latest POS_CATID can cause issues if there are high-PM stars
    # in the table.
    #
    # Before fixing the collisions, I discard all matches but the ones with AGASC stars with
    # POS_CATID from the latest catalog. The rationale being that there could be inconsistencies
    # within AGASC, where the same star was added from two catalogs, creating a duplicate.
    # High-PM star will be among the potential matches for many stars, and I am guessing high-PM
    # stars will be in recent catalogs too.
    # Let's consider the case of star A from an older catalog,
    # and star B a nearby high-PM star. Star B will be in recent catalogs, and its Gaia counterpart
    # will appear as a possible match for both A and B. Since B is in a recent catalog, that is the
    # one that will be kept, and the correct match for A will be discarded.
    # That means that `pos_catid` needs to be selected carefully. To begin with, one can't use all
    # possible matches. In `fix_collisions`, we consider all possible matches within 5 arcsec
    # to determine `pos_catid`.
    matches = matches[["agasc_id", "gaia_id", "p_match", "pos_catid"]]
    # most probable candidates first
    i_sorted = np.argsort(matches[["p_match"]])[::-1]

    result = {}  # agasc_id -> gaia_id of the accepted matches
    i_result = []
    for i in i_sorted:
        row = matches[i]
        if row["pos_catid"] != pos_catid:
            continue
        if (
            row["agasc_id"] not in result.keys()
            and row["gaia_id"] not in result.values()
        ):
            result[row["agasc_id"]] = row["gaia_id"]
            i_result.append(i)
    # explicit integer dtype so that an empty result can still be used as an index
    return np.array(i_result, dtype=int)
def get_latest_pos_cat(cat):
    """
    Return the latest position catalog from a list of POS_CATIDs.

    This is called by get_agasc_gaia_x_match_difficult (through fix_collisions).

    Parameters
    ----------
    cat : int or array-like of int
        POS_CATID values. All of them must be one of 5, 6, 4, 3, 2, 1.

    Returns
    -------
    int
        The first entry of the precedence list ``[5, 6, 4, 3, 2, 1]`` that is present in `cat`.

    Raises
    ------
    ValueError
        If `cat` contains a value that is not a known POS_CATID, or if it is empty.
    """
    # precedence order: the first of these found in `cat` is returned
    cats = [5, 6, 4, 3, 2, 1]
    if not np.all(np.isin(cat, cats)):
        allowed = ", ".join(str(c) for c in cats)
        raise ValueError(f"`cat` argument must be in: {allowed}")
    for c in cats:
        if np.any(np.isin(c, cat)):
            return c
    # only reachable if `cat` is empty
    raise ValueError(
        "Reached end of function without returning a value. "
        "This only happens if an invalid (empty) `cat` argument is passed."
    )
def get_group(table, graph, column="agasc_id"):
    """
    Get group index given by the connected component of a graph.

    The following example can serve as test case:

        >>> t = Table()
        >>> t['agasc_id'] = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        >>> graph = nx.Graph()
        >>> graph.add_nodes_from([3, 4, 6, 7, 8, 9, 10])
        >>> graph.add_edges_from([
        ...     (3, 4), (4, 6), (8, 10)
        ... ])
        >>> print("components:", (list(nx.connected_components(graph))))
        components: [{3, 4, 6}, {7}, {8, 10}, {9}]
        >>> print("groups:", get_group(t, graph))
        groups: [-2  0  0 -2  0 -1  1 -1  1 -2]
        >>> assert np.all(
        ...     get_group(t, graph) == np.array( [-2,  0,  0, -2,  0, -1,  1, -1,  1, -2])
        ... )

    Parameters
    ----------
    table : astropy.table.Table
        Table to group
    graph : networkx.Graph
        Graph to use for grouping.
    column : str
        Column to use for grouping. Default: 'agasc_id'

    Returns
    -------
    np.array : group index for each row in the table
        The groups are the index of its connected component in the graph (if cardinality > 1).
        Rows in connected components with cardinality 1 are given the group ID -1.
        Rows not in the graph are given the group ID -2.
    """
    connected_components = [
        comp for comp in nx.connected_components(graph) if len(comp) > 1
    ]
    # map node -> index of its (multi-node) connected component
    component_of = {
        sid: i for i, comp in enumerate(connected_components) for sid in comp
    }
    sids = np.array(sorted(component_of))
    comps = np.array([component_of[sid] for sid in sids], dtype=int)

    values = np.asarray(table[column])
    group = np.full(len(table), -1, dtype=int)
    if len(sids) > 0:
        # binary search in the sorted node list; the index is clipped so that values larger
        # than all nodes can be safely compared (they simply will not match)
        idx = np.minimum(np.searchsorted(sids, values), len(sids) - 1)
        ok = sids[idx] == values
        group[ok] = comps[idx[ok]]
    # rows that are not nodes of the graph at all
    group[~np.isin(values, list(graph.nodes()))] = -2
    return group
def get_collision_graph(matches):
    """
    Return a graph that represents cross-matching collisions.

    This is called by get_agasc_gaia_x_match_difficult.

    A collision occurs when two AGASC stars are matched to the same Gaia star.
    In the collision graph, nodes are AGASC IDs and edges are collisions.

    The collision graph is constructed as a "bipartite" graph, with AGASC stars on one side and Gaia
    stars on the other, where edges represent cross-matches. Edges are then added between AGASC
    stars that are linked to the same Gaia star, and all Gaia stars are removed.

    Only rows with ``best_match`` set and ``d2d < 5`` are considered.

    Uses the following columns in the input:

    - agasc_id
    - gaia_id
    - best_match
    - d2d
    - d_mag
    - p_match
    - p_value

    Parameters
    ----------
    matches : astropy.table.Table
          Table of matches, as returned by agasc_gaia.cross_match.get_agasc_gaia_x_match_all()

    Returns
    -------
    graph : networkx.Graph
          A graph where nodes are AGASC stars and edges represent collisions. Edge attributes
          combine the attributes of the two AGASC-Gaia links that collide:

          - p_match: the smaller of the two p_match values
          - d2d: the larger of the two d2d values
          - d_mag: the larger of the two d_mag values
          - p_value: the smaller of the two p_value values
    """
    # this is how the final match stands right now
    table = matches[matches["best_match"]]
    # for the purpose of identifying difficult stars, we consider only nearby stars
    # it can happen that a "best_match" for a star is not a true match, and we do not want noise
    # from those.
    table = table[table["d2d"] < 5]

    # filter to select only the entries that have gaia_id collisions
    grouped_table = table.group_by("gaia_id")
    n = np.diff(grouped_table.groups.indices)
    table = grouped_table.groups[n > 1]

    # Convert the table to a graph linking AGASC and Gaia stars:
    graph = nx.Graph()
    graph.add_nodes_from(np.unique(np.asarray(table["agasc_id"])), bipartite=0)
    graph.add_nodes_from(np.unique(np.asarray(table["gaia_id"])), bipartite=1)
    graph.add_edges_from(
        [
            (
                agasc_id,
                gaia_id,
                {"p_match": p_match, "d2d": d2d, "d_mag": d_mag, "p_value": p_value},
            )
            for agasc_id, gaia_id, p_match, d2d, d_mag, p_value in np.asarray(
                table["agasc_id", "gaia_id", "p_match", "d2d", "d_mag", "p_value"]
            )
        ]
    )

    gaia_nodes = [
        node for node, attrs in graph.nodes(data=True) if attrs["bipartite"] == 1
    ]

    # AGASC stars are linked through a Gaia star,
    # The weight of that link is the weakest link between the two AGASC-Gaia links
    # (the edges are collected first and added afterwards, so the graph is not modified while
    # its edges are being iterated)
    collision_edges = []
    for gaia_node in gaia_nodes:
        links = graph.edges(gaia_node, data=True)
        for (_, p1, attrs1), (_, p2, attrs2) in itertools.combinations(links, 2):
            collision_edges.append(
                (
                    p1,
                    p2,
                    {
                        "p_match": np.min([attrs1["p_match"], attrs2["p_match"]]),
                        "d2d": np.max([attrs1["d2d"], attrs2["d2d"]]),
                        "d_mag": np.max([attrs1["d_mag"], attrs2["d_mag"]]),
                        "p_value": np.min([attrs1["p_value"], attrs2["p_value"]]),
                    },
                )
            )
    graph.add_edges_from(collision_edges)

    # now remove the Gaia stars
    graph.remove_nodes_from(gaia_nodes)
    return graph
@utils.cached(name="agasc-gaia-x-match")
def get_agasc_gaia_x_match():
    """
    Return the final AGASC-Gaia cross-match, with one row per matched AGASC star.

    The best matches of the direct cross-match are corrected with the result of
    get_agasc_gaia_x_match_difficult, and only best matches with ``d2d < 5`` and
    ``p_value > 0.02`` are kept. No Gaia ID is repeated in the result (this is asserted).

    Returns
    -------
    astropy.table.Table
    """
    # Some gaia IDs are repeated in the direct cross-match. We call those "difficult" stars.
    # They occur because two AGASC stars are matched to the same Gaia star.
    # That collision is fixed in get_agasc_gaia_x_match_difficult
    difficult = get_agasc_gaia_x_match_difficult()
    data = get_agasc_gaia_x_match_all()
    data["idx"] = np.arange(len(data))

    # np.searchsorted is used instead of a join because it is faster
    rows = np.searchsorted(data["idx"], difficult["idx"])
    # sanity check: the rows found must refer to the same AGASC/Gaia pairs
    assert np.all(data["agasc_id"][rows] == difficult["agasc_id"])
    assert np.all(data["gaia_id"][rows] == difficult["gaia_id"])
    data["best_match"][rows] = difficult["best_match"]

    final_matches = data[data["best_match"]]
    is_close = final_matches["d2d"] < 5
    is_likely = final_matches["p_value"] > 0.02
    final_matches = final_matches[is_close & is_likely]

    # Check that there are no repeated Gaia IDs
    by_gaia_id = final_matches.group_by("gaia_id")
    group_sizes = np.diff(by_gaia_id.groups.indices)
    assert np.all(group_sizes == 1)
    return final_matches
def do_before(data):
    """
    Join a table of Gaia candidates with the AGASC summary and compute match quantities.

    The following columns are added: ``has_mag``, ``ave_gaia_mag``, ``d2d``, ``mag_pred``,
    ``d_mag``, ``p_match`` and ``log_p_match``. ``phot_variable_flag`` is converted to a
    boolean. Columns that exist in both tables get the suffixes ``_agasc`` and ``_gaia``.

    Note that some columns are removed from the input table in place before joining.

    Parameters
    ----------
    data : astropy.table.Table
        Gaia match candidates. Must include ``agasc_id`` and the Gaia columns used below
        (``ra``, ``dec``, ``pm_ra``, ``pm_dec``, ``epoch``, ``g_mag``, ``rp_mag``, ``bp_mag``,
        ``phot_variable_flag``).

    Returns
    -------
    astropy.table.Table
        The joined table (inner join on ``agasc_id``).
    """
    agasc_summary = datasets.get_agasc_summary()
    print("Joining AGASC and Gaia data")
    #  the indirect method have these
    #  (they are removed here because they are taken from the AGASC summary in the join below)
    for col in ["ra_agasc", "dec_agasc", "epoch_agasc", "pm_agasc", "tycho_id"]:
        if col in data.colnames:
            data.remove_column(col)
    cols = [
        "agasc_id",
        # used to get d2d (there will be _gaia and _agasc versions of these columns)
        "ra",
        "dec",
        "pm_ra",
        "pm_dec",
        "epoch",
        # used to get d_mag (only from agasc summary)
        "mag",
        "mag_catid",
        "mag_band",
        "tycho_id",
    ]
    data = join(
        agasc_summary[cols],
        data,
        keys=["agasc_id"],
        join_type="inner",
        table_names=["agasc", "gaia"],
    )
    # has_mag is a bit mask of the available Gaia magnitudes: 1 = G, 2 = RP, 4 = BP
    has_rp_mag = ~data["rp_mag"].mask
    has_bp_mag = ~data["bp_mag"].mask
    has_g_mag = ~data["g_mag"].mask
    data["has_mag"] = 1 * has_g_mag + 2 * has_rp_mag + 4 * has_bp_mag
    data["phot_variable_flag"] = data["phot_variable_flag"] == "VARIABLE"
    # ave_gaia_mag is the mean of the available magnitudes; it starts fully masked and stays
    # masked for rows without any magnitude (has_mag == 0)
    data["ave_gaia_mag"] = MaskedColumn(np.zeros(len(data)), mask=np.ones(len(data)))
    data["ave_gaia_mag"][data["has_mag"] == 1] = data["g_mag"][data["has_mag"] == 1]
    data["ave_gaia_mag"][data["has_mag"] == 2] = data["rp_mag"][data["has_mag"] == 2]
    data["ave_gaia_mag"][data["has_mag"] == 4] = data["bp_mag"][data["has_mag"] == 4]
    data["ave_gaia_mag"][data["has_mag"] == 3] = (
        data["g_mag"][data["has_mag"] == 3] + data["rp_mag"][data["has_mag"] == 3]
    ) / 2
    data["ave_gaia_mag"][data["has_mag"] == 5] = (
        data["g_mag"][data["has_mag"] == 5] + data["bp_mag"][data["has_mag"] == 5]
    ) / 2
    data["ave_gaia_mag"][data["has_mag"] == 6] = (
        data["rp_mag"][data["has_mag"] == 6] + data["bp_mag"][data["has_mag"] == 6]
    ) / 2
    data["ave_gaia_mag"][data["has_mag"] == 7] = (
        data["g_mag"][data["has_mag"] == 7]
        + data["rp_mag"][data["has_mag"] == 7]
        + data["bp_mag"][data["has_mag"] == 7]
    ) / 3
    if "gsc2.3" in data.colnames and "gsc23" not in data.colnames:
        data.rename_column("gsc2.3", "gsc23")
    # AGASC proper motions of -9999 are set to zero
    # (presumably -9999 is a missing-value sentinel -- confirm)
    data["pm_dec_agasc"][data["pm_dec_agasc"] == -9999] = 0
    data["pm_ra_agasc"][data["pm_ra_agasc"] == -9999] = 0
    # rows with a Tycho ID and an AGASC epoch of -9999 or 2000.0 get epoch 1991.5
    if np.any(
        (~data["tycho_id"].mask)
        & ((data["epoch_agasc"] == -9999) | (data["epoch_agasc"] == 2000.0))
    ):
        data["epoch_agasc"][
            (~data["tycho_id"].mask)
            & ((data["epoch_agasc"] == -9999) | (data["epoch_agasc"] == 2000.0))
        ] = 1991.5
    print("Setting Gaia positions at AGASC epoch")
    # positions at the epoch of the agasc catalog
    ra = np.array(data["ra_agasc"])
    dec = np.array(data["dec_agasc"])
    data["coord_agasc"] = SkyCoord(ra=ra, dec=dec, unit="deg")
    ra = np.array(data["ra_gaia"])
    dec = np.array(data["dec_gaia"])
    # only rows with both proper-motion components are propagated to the AGASC epoch
    pm = ~data["pm_ra_gaia"].mask & ~data["pm_dec_gaia"].mask
    # the proper motion times the epoch difference is divided by 1000 and 3600 to get degrees
    # (i.e. pm is taken to be in mas/yr -- TODO confirm); the RA term is divided by cos(dec)
    ra[pm] += (
        data["pm_ra_gaia"][pm]
        * (data["epoch_agasc"][pm] - data["epoch_gaia"][pm])
        / 1000
        / 3600
        / np.cos(np.deg2rad(data["dec_gaia"][pm]))
    )
    dec[pm] += (
        data["pm_dec_gaia"][pm]
        * (data["epoch_agasc"][pm] - data["epoch_gaia"][pm])
        / 1000
        / 3600
    )
    data["coord_gaia"] = SkyCoord(ra=ra, dec=dec, unit="deg")
    # angular separation between the AGASC and the (propagated) Gaia position, in arcsec
    data["d2d"] = data["coord_gaia"].separation(data["coord_agasc"]).to(u.arcsec)
    # estimated v-mag (or whatever band is in AGASC) based on Gaia magnitude
    data["mag_pred"] = gaia_queries.mag_v_gaia_mag(data)
    data["d_mag"] = data["mag_pred"] - data["mag"]
    # the coordinate columns were only needed to compute d2d
    data.remove_columns(
        [
            "coord_agasc",
            "coord_gaia",
        ]
    )
    data["p_match"] = agasc_gaia_match_probability(data["d_mag"], data["d2d"])
    # log_p_match = -log10(p_match), or infinity where p_match is not positive
    data["log_p_match"] = np.inf
    data["log_p_match"][data["p_match"] > 0] = -np.log10(
        data["p_match"][data["p_match"] > 0]
    )
    return data
def do_neighbors(data):
    """
    Add columns describing the neighbors (match candidates) of each AGASC star.

    The table is grouped by ``agasc_id`` and the following columns are added:

    - brightest_neighbor: True for the candidate with the smallest ``ave_gaia_mag`` in the group
    - min_neighbor_mag: the ``ave_gaia_mag`` of that candidate
    - neighbor_mag_weight: ``exp(min_neighbor_mag) / exp(ave_gaia_mag)``
    - n_mag_neighbors: number of candidates in the group that have a magnitude and a
      ``neighbor_mag_weight`` larger than 0.1

    Parameters
    ----------
    data : astropy.table.Table
        Match candidates, with columns ``agasc_id`` and ``ave_gaia_mag`` (masked).

    Returns
    -------
    astropy.table.Table
        The input table grouped by ``agasc_id``, with the new columns.
    """
    data = data.group_by("agasc_id")
    edges = data.groups.indices
    group_numbers = np.arange(len(data.groups))
    spans = np.array(list(zip(edges[:-1], edges[1:])))
    n_rows = len(data)

    # NOTE(review): this is the raw array underlying the masked column, so masked entries take
    # part in the comparison below with whatever value is stored under the mask -- confirm
    # that this is intended.
    mags = data["ave_gaia_mag"].data.data

    # Find the brightest neighbor and how much each neighbor contributes to magnitude
    brightest_neighbor = np.zeros(n_rows, dtype=bool)
    min_neighbor_mag = np.zeros(n_rows)
    for group_number in tqdm.tqdm(group_numbers):
        lo, hi = spans[group_number]
        i_brightest = lo + np.argmin(mags[lo:hi])
        min_neighbor_mag[lo:hi] = mags[i_brightest]
        brightest_neighbor[i_brightest] = True
    data["brightest_neighbor"] = brightest_neighbor
    data["min_neighbor_mag"] = min_neighbor_mag
    data["neighbor_mag_weight"] = np.exp(data["min_neighbor_mag"]) / np.exp(
        data["ave_gaia_mag"]
    )

    # determine the number of neighbors contributing to magnitude
    weights = data["neighbor_mag_weight"].data.data
    has_mag = ~data["ave_gaia_mag"].mask
    n_mag_neighbors = np.zeros(n_rows, dtype=int)
    for group_number in tqdm.tqdm(group_numbers):
        lo, hi = spans[group_number]
        contributes = has_mag[lo:hi] & (weights[lo:hi] > 0.1)
        n_mag_neighbors[lo:hi] = np.count_nonzero(contributes)
    data["n_mag_neighbors"] = n_mag_neighbors
    return data
def get_p_value_function():
    """
    Build a function that maps a match probability to its p-value.

    The match probability is evaluated on a regular grid in (d_mag, d2d), the
    grid is sampled at random according to that probability (with the tails
    up-sampled and re-weighted), and the weighted cumulative distribution of the
    sampled ``p_match`` values is computed on logarithmically spaced bins.

    The sample is drawn with ``np.random.choice`` and no seed is set here, so
    the result varies slightly from call to call.

    :returns: a ``scipy.interpolate.interp1d`` instance that takes ``p_match``
        values and returns the weighted fraction of the sample with a lower
        ``p_match``. Values outside the sampled range are clamped to the
        smallest/largest value of the cumulative distribution.
    """
    print("Getting p-value function")

    print("- Defining sampling space")
    n_grid = 3000  # this drives the time
    max_d_mag = 20
    d_mag_edges = np.linspace(-max_d_mag, max_d_mag, 2 * n_grid + 1)
    d2d_edges = np.linspace(0, 10, n_grid + 1)
    # cell widths and cell centers along each axis
    d_d_mag, d_d2d = np.diff(d_mag_edges), np.diff(d2d_edges)
    d_mag = (d_mag_edges[1:] + d_mag_edges[:-1]) / 2
    d2d = (d2d_edges[1:] + d2d_edges[:-1]) / 2
    d_mag, d2d = np.meshgrid(d_mag, d2d)
    d_d_mag, d_d2d = np.meshgrid(d_d_mag, d_d2d)
    # d2d is the 2d radius, so the area is approx r dr
    d_area = d_d_mag * d_d2d * d2d

    print("- Calculating match probability")
    p_match = agasc_gaia_match_probability(d_mag, d2d)
    norm_p_match = p_match / np.sum(p_match * d_area)

    print("- Generating Random sample")
    # now generate a sample according to that distribution
    # up-sample the tails, so the probability never goes below the given value
    p_sample_min = 1e-3 * norm_p_match.max()
    sample_prob = np.where(norm_p_match > p_sample_min, norm_p_match, p_sample_min)
    # weights undo the up-sampling when the histogram is accumulated
    sample_weights = norm_p_match / sample_prob
    sample_prob = sample_prob / np.sum(sample_prob)  # need to normalize for np.random

    # in general, the number of samples determines how much we explore the tails and how close we
    # get to zero, but thanks to up-sampling, the tails will be sampled, so we do not need millions
    # of samples
    idx = np.random.choice(
        np.arange(len(norm_p_match.flatten())), size=100000, p=(sample_prob).flatten()
    )
    p_match_sample = p_match.flatten()[idx]
    p_match_sample_weights = sample_weights.flatten()[idx]

    eps = np.finfo(p_match_sample.dtype).eps

    print("- Calculating CDF")
    # np.logspace expects base-10 exponents, so the limits must be log10 (not
    # natural log) for the bins to span exactly [min, max] of the sample.
    bins = np.logspace(
        np.log10(p_match_sample[p_match_sample > 0].min() * (1 - eps)),
        np.log10(p_match_sample.max() * (1 + eps)),
        1000000,  # the number of bins directly affects the lowest p_value we will get
    )
    vals, bins = np.histogram(p_match_sample, bins=bins, weights=p_match_sample_weights)
    x = (bins[1:] + bins[:-1]) / 2
    n = np.cumsum(vals) / np.sum(vals)
    return scipy.interpolate.interp1d(
        x, n, fill_value=(np.min(n), np.max(n)), bounds_error=False
    )
def do_after(data):
    """
    Add p-value and relative-probability columns, then re-cast to XMATCH_DTYPE.

    The following columns are added:

    - ``p_value``: the function returned by ``get_p_value_function`` evaluated on
      ``p_match``; rows where ``p_match`` is masked get 0.
    - ``log_p_value``: ``-log10(p_value)``, or ``inf`` where ``p_value`` is not positive.
    - ``p_relative``: ``p_match`` divided by the sum of ``p_match`` over all rows
      with the same ``agasc_id`` (0 if that sum is not positive).
    - ``log_p_relative``: ``-log10(p_relative)``, or ``inf`` where ``p_relative``
      is not positive.

    The Gaia position/proper-motion/epoch columns are then renamed (dropping the
    ``_gaia`` suffix), display formats are set, and the table is reduced to the
    columns listed in ``XMATCH_DTYPE`` and cast to those dtypes.

    :param data: table with (at least) ``agasc_id``, ``p_match``, ``ra_gaia``,
        ``dec_gaia``, ``pm_ra_gaia``, ``pm_dec_gaia`` and ``epoch_gaia`` columns.
    :returns: a new table, grouped-by-``agasc_id`` row order, with XMATCH_DTYPE dtypes.
    """
    get_p_value = get_p_value_function()

    mask = (
        data["p_match"].mask
        if hasattr(data["p_match"], "mask")
        else np.zeros(len(data), dtype=bool)
    )
    # Evaluate only the unmasked rows and scatter them back into a full-length
    # array. (Passing the shorter array straight to np.where fails to broadcast
    # against the mask as soon as any row is masked.)
    p_value = np.zeros(len(data))
    p_value[~mask] = get_p_value(np.asarray(data["p_match"][~mask]))
    data["p_value"] = p_value
    data["log_p_value"] = np.inf
    data["log_p_value"][data["p_value"] > 0] = -np.log10(
        data["p_value"][data["p_value"] > 0]
    )

    data = data.group_by("agasc_id")
    boundaries = data.groups.indices
    # (start, stop) row range of every agasc_id group
    spans = list(zip(boundaries[:-1], boundaries[1:]))

    p_match = np.asarray(data["p_match"])
    p_relative = np.zeros(len(data))
    for i0, i1 in tqdm.tqdm(spans):
        group_total = np.nansum(p_match[i0:i1])
        p_relative[i0:i1] = (p_match[i0:i1] / group_total) if group_total > 0 else 0

    data["p_relative"] = p_relative
    data["log_p_relative"] = np.inf
    data["log_p_relative"][data["p_relative"] > 0] = -np.log10(
        data["p_relative"][data["p_relative"] > 0]
    )

    # display formats, applied only to the columns that are present
    fmts = {
        "min_neighbor_mag": ".2f",
        "neighbor_mag_weight": ".2f",
        "ave_gaia_mag": ".2f",
        "mag_aca": ".2f",
        "pm_ra": ".2f",
        "pm_dec": ".2f",
        "g_mag": ".2f",
        "bp_mag": ".2f",
        "rp_mag": ".2f",
        "angular_distance": ".2f",
        "pm_agasc": ".1f",
        "d2d": ".3f",
        "mag_1p7": ".2f",
        "mag_1p8": ".2f",
        "pm_ra_gaia": ".2f",
        "pm_dec_gaia": ".2f",
        "pm_ra_agasc": ".2f",
        "pm_dec_agasc": ".2f",
        "mag_aca_1p7": ".2f",
        "mag_aca_1p8": ".2f",
        "mag_pred": ".2f",
        "p_value": ".4f",
        "p_match": ".4f",
        "p_relative": ".4f",
        "log_p_value": ".5f",
        "d_mag": ".2f",
    }

    data.rename_columns(
        ["ra_gaia", "dec_gaia", "pm_ra_gaia", "pm_dec_gaia", "epoch_gaia"],
        ["ra", "dec", "pm_ra", "pm_dec", "epoch"],
    )

    for name, fmt in fmts.items():
        if name in data.colnames:
            data[name].format = fmt

    print("re-casting result")
    # keep only the XMATCH_DTYPE columns that exist, in XMATCH_DTYPE order
    dtype = [dt for dt in XMATCH_DTYPE if dt[0] in data.colnames]
    colnames = [dt[0] for dt in dtype]
    data = Table(data[colnames].as_array().astype(dtype))
    return data
def add_summary_cols(data):
    """
    Copy magnitude-related columns from the AGASC summary into ``data``, in place.

    Rows are matched by looking up ``data["agasc_id"]`` in the summary's
    ``agasc_id`` column with ``np.searchsorted``.

    NOTE(review): this lookup is only correct if the summary is sorted by
    ``agasc_id`` and contains every id present in ``data`` -- confirm against
    ``datasets.get_agasc_summary``.

    :param data: table with an ``agasc_id`` column; modified in place.
    :returns: None
    """
    summary = datasets.get_agasc_summary()
    summary_rows = np.searchsorted(summary["agasc_id"], data["agasc_id"])
    for name in (
        "mag_aca",
        "mag_aca_err",
        "mag_aca_obs",
        "mag_aca_err_obs",
        "mag_catid",
        "mag_band",
        "random_index",
    ):
        data[name] = summary[name][summary_rows]
def split_train_test(data):
    """
    Split the rows with an observed ACA magnitude into training and test tables.

    Each row is assigned to one of 15 sub-samples (``random_index`` modulo 15),
    stored in a new ``x_val_sample`` column of ``data``. Rows whose
    ``mag_aca_obs`` is masked are dropped; of the rest, the last three
    sub-samples form the test set and the others the training set.

    :param data: table with ``random_index`` and a masked ``mag_aca_obs`` column.
        The ``x_val_sample`` column is added in place.
    :returns: tuple ``(train, test)``
    """
    # I am creating n_subsamples subsamples that will be used for cross-validation and testing
    n_subsamples = 15
    n_test_subsamples = 3
    first_test_subsample = n_subsamples - n_test_subsamples

    data["x_val_sample"] = data["random_index"] % n_subsamples

    observed = ~data["mag_aca_obs"].mask
    test = data[observed & (data["x_val_sample"] >= first_test_subsample)]
    train = data[observed & (data["x_val_sample"] < first_test_subsample)]
    return train, test
@utils.cached(name="agasc-gaia-x-match-train-test")
def get_agasc_gaia_x_match_train_test():
    """
    Return the (train, test) split of the AGASC-Gaia cross-match.

    The cross-match from ``get_agasc_gaia_x_match`` gets the AGASC summary
    columns added and is then split with ``split_train_test``.
    """
    x_match = get_agasc_gaia_x_match()
    add_summary_cols(x_match)  # adds the summary columns in place
    train_test = split_train_test(x_match)
    return train_test
@utils.cached(name="agasc-gsc-tycho-gaia-x-match-train-test")
def get_agasc_tycho_gsc_gaia_x_match_train_test():
    """
    Return the (train, test) split of the AGASC-Tycho/GSC-Gaia cross-match.

    The cross-match from ``get_agasc_tycho_gsc_gaia_x_match`` gets the AGASC
    summary columns added and is then split with ``split_train_test``.
    """
    x_match = get_agasc_tycho_gsc_gaia_x_match()
    add_summary_cols(x_match)  # adds the summary columns in place
    train_test = split_train_test(x_match)
    return train_test