# SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import numpy as np
import pandas as pd

from nsys_recipe.lib import data_utils, overlap

# Presumably the 'domainId' value of the default NVTX domain; it is not used
# by the functions visible here - TODO confirm against the export schema.
DEFAULT_DOMAIN_ID = 0
# Values of the 'eventType' column that the filter functions below match
# against: domain creation, push/pop ranges and start/end ranges.
EVENT_TYPE_NVTX_DOMAIN_CREATE = 75
EVENT_TYPE_NVTX_PUSHPOP_RANGE = 59
EVENT_TYPE_NVTX_STARTEND_RANGE = 60


def filter_by_domain_id(nvtx_df, domain_id):
    """Select the range events that belong to the domain 'domain_id'.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        NVTX events with at least the 'domainId' and 'eventType' columns.
    domain_id : scalar
        Value compared for equality with the 'domainId' column.

    Returns
    -------
    pd.DataFrame
        Rows of 'nvtx_df' whose 'domainId' equals 'domain_id' and whose
        'eventType' is a push/pop or a start/end range.
    """
    range_event_types = [
        EVENT_TYPE_NVTX_PUSHPOP_RANGE,
        EVENT_TYPE_NVTX_STARTEND_RANGE,
    ]

    in_domain = nvtx_df["domainId"] == domain_id
    is_range = nvtx_df["eventType"].isin(range_event_types)

    return nvtx_df[in_domain & is_range]


def filter_by_domain_name(nvtx_df, domain_name):
    """Select the range events that belong to the domain named 'domain_name'.

    The domain is located through its domain-creation event, i.e. a row
    whose 'eventType' is EVENT_TYPE_NVTX_DOMAIN_CREATE and whose 'text'
    equals 'domain_name'.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        NVTX events with at least the 'eventType', 'text' and 'domainId'
        columns.
    domain_name : str
        Name compared for equality with the 'text' column.

    Returns
    -------
    pd.DataFrame
        The push/pop and start/end ranges of the matching domain, or an
        empty DataFrame if no domain-creation event carries that name.
    """
    is_domain_create = nvtx_df["eventType"] == EVENT_TYPE_NVTX_DOMAIN_CREATE
    has_requested_name = nvtx_df["text"] == domain_name
    matching_domains = nvtx_df[is_domain_create & has_requested_name]

    if matching_domains.empty:
        return matching_domains

    # When several creation events share the name, the first one is used.
    return filter_by_domain_id(nvtx_df, matching_domains["domainId"].iloc[0])


def combine_text_fields(nvtx_df, str_df):
    """Merge the 'textId' information of the NVTX dataframe into 'text'.

    After this call the event message can be read from the 'text' column
    alone: wherever a string was resolved for the row's 'textId', that
    string replaces the value of 'text'; other rows keep their 'text'.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        NVTX events with at least the 'text' and 'textId' columns.
    str_df : pd.DataFrame
        String table passed to 'data_utils.replace_id_with_value' to resolve
        'textId'. It is not accessed when every 'textId' is null.

    Returns
    -------
    pd.DataFrame
        A new DataFrame; 'nvtx_df' is returned as a copy when there is
        nothing to resolve.
    """
    # Nothing to look up: every 'textId' is null (or there are no rows).
    if nvtx_df["textId"].isnull().all():
        return nvtx_df.copy()

    # The helper is expected to provide the resolved string of each 'textId'
    # in a temporary 'textStr' column.
    resolved_df = data_utils.replace_id_with_value(
        nvtx_df, str_df, "textId", "textStr"
    )

    has_resolved_text = resolved_df["textStr"].notna()
    resolved_df.loc[has_resolved_text, "text"] = resolved_df.loc[
        has_resolved_text, "textStr"
    ]

    return resolved_df.drop(columns=["textStr"])


def compute_hierarchy_info(nvtx_df):
    """Annotate each NVTX range with its position in the range hierarchy.

    The rows are visited in order and are assumed to be sorted by times. A
    range is considered nested in the most recent earlier range that has
    not ended yet when it starts.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        Ranges with at least the 'start' and 'end' columns. It is left
        unmodified.

    Returns
    -------
    pd.DataFrame
        Copy of 'nvtx_df' with the following additional columns:
        - rangeId: original index label of the range (the index is expected
          to be unnamed).
        - parentId: rangeId of the enclosing range, NaN for top-level ranges.
        - stackLevel: depth of the range in the stack, starting at 0.
        - rangeStack: rangeIds of the enclosing ranges followed by the range
          itself, outermost first.
        - childrenCount: number of ranges whose parentId is this range.
    """
    hierarchy_df = nvtx_df.copy()

    hierarchy_df["parentId"] = None
    hierarchy_df["stackLevel"] = 0
    hierarchy_df["rangeStack"] = None

    # Ranges still open when the current range starts, outermost first.
    open_ranges = []

    for current in hierarchy_df.itertuples():
        # Close every range that ended before, or exactly when, this one
        # starts.
        while open_ranges and open_ranges[-1].end <= current.start:
            open_ranges.pop()

        if open_ranges:
            parent_id = open_ranges[-1].Index
        else:
            parent_id = np.nan
        open_ranges.append(current)

        label = current.Index
        hierarchy_df.at[label, "parentId"] = parent_id
        # The current range is on the stack, hence the '- 1' to start at 0.
        hierarchy_df.at[label, "stackLevel"] = len(open_ranges) - 1
        hierarchy_df.at[label, "rangeStack"] = [r.Index for r in open_ranges]

    hierarchy_df = hierarchy_df.reset_index()
    hierarchy_df = hierarchy_df.rename(columns={"index": "rangeId"})

    # value_counts() ignores the NaN parent of the top-level ranges.
    child_counts = hierarchy_df["parentId"].value_counts()
    counts_per_range = hierarchy_df["rangeId"].map(child_counts)
    hierarchy_df["childrenCount"] = counts_per_range.fillna(0).astype(int)

    return hierarchy_df


def _add_individual_events(nvtx_gpu_start_dict, nvtx_gpu_end_dict, list_of_individuals):
    """Merge the individual GPU intervals with the aggregated ones.

    Each entry of 'list_of_individuals' is an (index, start, end) tuple. When
    an aggregated interval exists for the same index, the entry is:
    - dropped if it lies within the aggregated interval,
    - absorbed by extending the aggregated end (resp. start) if only its
      start (resp. end) lies within the aggregated interval.
    In every other case the entry is kept as an interval of its own.

    The two dictionaries are updated in place when an interval is extended.

    Returns
    -------
    indices, starts, ends : tuple of lists
        Parallel lists holding the standalone entries first, followed by the
        aggregated intervals in dictionary order.
    """
    standalone = []

    for nvtx_index, gpu_start, gpu_end in list_of_individuals:
        if nvtx_index in nvtx_gpu_start_dict:
            group_start = nvtx_gpu_start_dict[nvtx_index]
            group_end = nvtx_gpu_end_dict[nvtx_index]

            # Already covered by the aggregated interval: nothing to add.
            if gpu_start >= group_start and gpu_end <= group_end:
                continue
            # Starts inside the aggregated interval and runs past its end:
            # stretch the end.
            if gpu_start >= group_start and gpu_start <= group_end:
                nvtx_gpu_end_dict[nvtx_index] = gpu_end
                continue
            # Ends inside the aggregated interval and starts before it:
            # stretch the start.
            if gpu_end >= group_start and gpu_end <= group_end:
                nvtx_gpu_start_dict[nvtx_index] = gpu_start
                continue

        # No aggregated interval absorbs this entry: keep it on its own.
        standalone.append((nvtx_index, gpu_start, gpu_end))

    indices = [entry[0] for entry in standalone] + list(nvtx_gpu_start_dict)
    starts = [entry[1] for entry in standalone] + list(nvtx_gpu_start_dict.values())
    ends = [entry[2] for entry in standalone] + list(nvtx_gpu_end_dict.values())

    return indices, starts, ends


def _compute_gpu_projection_df(nvtx_df, cuda_df, cuda_nvtx_index_map):
    """Build the GPU projection of the NVTX ranges for one set of CUDA rows.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        NVTX ranges; the 'text' column is looked up by index label.
    cuda_df : pd.DataFrame
        CUDA operations with the 'gpu_start' and 'gpu_end' columns and,
        optionally, a 'groupId' column.
    cuda_nvtx_index_map : dict
        Maps the index label of a CUDA row to the index labels of the NVTX
        ranges associated with it. Rows missing from the map are ignored.

    Returns
    -------
    pd.DataFrame
        One row per projected range with the 'text', 'start' and 'end'
        columns, sorted by ascending start, then descending end.
    """
    # For each NVTX index, the earliest GPU start and the latest GPU end
    # among the CUDA operations mapped to it.
    gpu_start_by_nvtx = {}
    gpu_end_by_nvtx = {}
    # (NVTX index, GPU start, GPU end) of the operations that must keep
    # their own interval instead of being folded into the aggregate.
    ungrouped = []

    for cuda_row in cuda_df.itertuples():
        cuda_index = cuda_row.Index
        if cuda_index not in cuda_nvtx_index_map:
            continue

        for nvtx_index in cuda_nvtx_index_map[cuda_index]:
            gpu_start, gpu_end = cuda_row.gpu_start, cuda_row.gpu_end

            # Operations carrying a non-null 'groupId' are not aggregated.
            if not pd.isna(getattr(cuda_row, "groupId", None)):
                ungrouped.append((nvtx_index, gpu_start, gpu_end))
            elif nvtx_index not in gpu_start_by_nvtx:
                gpu_start_by_nvtx[nvtx_index] = gpu_start
                gpu_end_by_nvtx[nvtx_index] = gpu_end
            else:
                gpu_start_by_nvtx[nvtx_index] = min(
                    gpu_start_by_nvtx[nvtx_index], gpu_start
                )
                gpu_end_by_nvtx[nvtx_index] = max(
                    gpu_end_by_nvtx[nvtx_index], gpu_end
                )

    indices, starts, ends = _add_individual_events(
        gpu_start_by_nvtx, gpu_end_by_nvtx, ungrouped
    )

    projection_df = pd.DataFrame(
        {"text": nvtx_df.loc[indices, "text"], "start": starts, "end": ends}
    )
    # Expose the NVTX index as a column so that it can serve as the last
    # sort key: rows with identical "start" and "end" values keep their
    # original relative order.
    projection_df = projection_df.reset_index()
    projection_df = projection_df.sort_values(
        by=["start", "end", "index"], ascending=[True, False, True]
    )

    return projection_df.drop(columns=["index"]).reset_index(drop=True)


def _compute_grouped_gpu_projection_df(
    nvtx_df, cuda_df, cuda_nvtx_index_map, per_gpu=False, per_stream=False
):
    """Compute the GPU projection, optionally split per stream and/or GPU.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        NVTX ranges, forwarded to '_compute_gpu_projection_df'.
    cuda_df : pd.DataFrame
        CUDA operations. Must contain 'streamId' when 'per_stream' is set
        and 'deviceId' when 'per_gpu' is set.
    cuda_nvtx_index_map : dict
        Mapping forwarded to '_compute_gpu_projection_df'.
    per_gpu : bool
        Compute one projection per 'deviceId' value.
    per_stream : bool
        Compute one projection per 'streamId' value.

    Returns
    -------
    pd.DataFrame or None
        Projected ranges, with the 'streamId' and/or 'deviceId' columns
        added when the corresponding option is set, or None if no range
        could be projected.
    """
    group_by_elements = []
    if per_stream:
        group_by_elements.append("streamId")
    if per_gpu:
        group_by_elements.append("deviceId")

    if not group_by_elements:
        df = _compute_gpu_projection_df(nvtx_df, cuda_df, cuda_nvtx_index_map)
        return df if not df.empty else None

    dfs = []

    for group_keys, cuda_group_df in cuda_df.groupby(group_by_elements):
        df = _compute_gpu_projection_df(nvtx_df, cuda_group_df, cuda_nvtx_index_map)
        if df.empty:
            continue

        # When grouping by a one-element list, pandas < 2.0 yields the bare
        # key instead of a 1-tuple. Normalize so that the keys can always be
        # paired with their column names.
        if not isinstance(group_keys, tuple):
            group_keys = (group_keys,)

        for column, key in zip(group_by_elements, group_keys):
            df[column] = key
        dfs.append(df)

    if not dfs:
        return None

    return pd.concat(dfs, ignore_index=True)


def project_nvtx_onto_gpu(nvtx_df, cuda_df, per_gpu=False, per_stream=False):
    """Project the NVTX ranges from the CPU onto the GPU.

    A projected range spans from the start of the first GPU operation it
    encloses to the end of the last one. Ranges and CUDA operations are
    matched thread by thread, using the 'globalTid' column of both inputs.

    Parameters
    ----------
    nvtx_df : pd.DataFrame
        NVTX ranges with at least the 'text', 'start', 'end', 'globalTid',
        'endGlobalTid', 'pid' and 'tid' columns.
    cuda_df : pd.DataFrame
        CUDA operations with at least the 'globalTid', 'gpu_start' and
        'gpu_end' columns.
    per_gpu : bool
        Produce a separate projection for each 'deviceId'.
    per_stream : bool
        Produce a separate projection for each 'streamId'.

    Returns
    -------
    proj_nvtx_df : pd.DataFrame or None
        DataFrame with projected NVTX ranges, or None if none are found.
    """
    # Keep only the ranges that are complete and end on the thread they
    # started on.
    is_complete = nvtx_df["start"].notna() & nvtx_df["end"].notna()
    ends_on_same_thread = nvtx_df["endGlobalTid"].isna()
    filtered_nvtx_df = nvtx_df[is_complete & ends_on_same_thread]

    cuda_by_tid = cuda_df.groupby("globalTid")
    projections = []

    for global_tid, nvtx_tid_df in filtered_nvtx_df.groupby("globalTid"):
        # Threads without any CUDA operation have nothing to project onto.
        if global_tid not in cuda_by_tid.groups:
            continue

        cuda_tid_df = cuda_by_tid.get_group(global_tid)
        cuda_nvtx_index_map = overlap.map_overlapping_ranges(
            nvtx_tid_df, cuda_tid_df, fully_contained=True
        )

        proj_df = _compute_grouped_gpu_projection_df(
            filtered_nvtx_df, cuda_tid_df, cuda_nvtx_index_map, per_gpu, per_stream
        )
        if proj_df is None:
            continue

        # All rows of a 'globalTid' group share the same pid and tid, so the
        # first row is representative.
        proj_df["pid"] = nvtx_tid_df["pid"].iat[0]
        proj_df["tid"] = nvtx_tid_df["tid"].iat[0]

        projections.append(proj_df)

    return pd.concat(projections, ignore_index=True) if projections else None


def classify_cuda_kernel(nccl_df, cuda_df):
    """Classify CUDA kernels.

    A kernel is labelled 'nccl' when 'overlap.map_overlapping_ranges'
    associates it with an NCCL range of the same thread ('globalTid'), and
    'compute' otherwise.

    Parameters
    ----------
    nccl_df : pd.DataFrame
        NCCL NVTX ranges with at least the 'start', 'end', 'globalTid' and
        'endGlobalTid' columns.
    cuda_df : pd.DataFrame
        CUDA kernels with at least the 'start', 'end', 'gpu_start',
        'gpu_end' and 'globalTid' columns. It is left unmodified.

    Returns
    -------
    pd.DataFrame
        The resulting DataFrame will contain GPU information ('gpu_start'
        and 'gpu_end' renamed to 'start' and 'end', replacing the original
        'start' and 'end' columns), with an extra column called 'type' that
        will either be 'nccl' or 'compute'.

    Notes
    -----
    NOTE(review): as soon as one thread has both NCCL ranges and kernels,
    the result only contains the kernels of such threads (sorted by GPU
    times within each thread); kernels of the other threads are left out.
    When no thread qualifies, every kernel is returned as 'compute'.
    Confirm that this asymmetry is intended.
    """
    # Filter ranges that are incomplete or end on a different thread.
    filtered_nccl_df = nccl_df[
        nccl_df["start"].notnull()
        & nccl_df["end"].notnull()
        & nccl_df["endGlobalTid"].isnull()
    ]

    nccl_gdf = filtered_nccl_df.groupby("globalTid")

    type_df = cuda_df.assign(type="compute")
    cuda_gdf = type_df.groupby("globalTid")

    dfs = []

    for global_tid, nccl_tid_df in nccl_gdf:
        if global_tid not in cuda_gdf.groups:
            continue

        # Work on an explicit copy: assigning through the object returned by
        # get_group() is a write to a slice of 'type_df' and triggers
        # SettingWithCopyWarning.
        cuda_tid_df = cuda_gdf.get_group(global_tid).copy()
        cuda_nvtx_index_map = overlap.map_overlapping_ranges(
            nccl_tid_df, cuda_tid_df, fully_contained=True
        )

        # The keys of the map are the index labels of the kernels that were
        # matched with at least one NCCL range.
        nccl_indices = cuda_nvtx_index_map.keys()

        cuda_tid_df.loc[cuda_tid_df.index.isin(nccl_indices), "type"] = "nccl"

        dfs.append(cuda_tid_df.sort_values(["gpu_start", "gpu_end"]))

    if dfs:
        type_df = pd.concat(dfs, ignore_index=True)

    # Replace the CPU-side times with the GPU-side ones.
    type_df = type_df.drop(columns=["start", "end"])
    return type_df.rename(columns={"gpu_start": "start", "gpu_end": "end"})
