# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "opencv-python",
#     "pl-neon-recording[examples]",
#     "av",
#     "rich",
#     "pandas",
# ]
# ///

import argparse
import logging
from fractions import Fraction
from pathlib import Path
from typing import Optional

import av
import cv2
import numpy as np
import pandas as pd
import pupil_labs.neon_recording as nr
from pupil_labs.neon_recording.stream import Stream
from pupil_labs.neon_recording.stream.av_stream.video_stream import GrayFrame
from rich import pretty
from rich.console import Console
from rich.logging import RichHandler
from rich.progress import track

# Send every log record through rich's RichHandler, formatting only the
# message text here, at INFO level and above.
logging.basicConfig(
    format="%(message)s",
    handlers=[RichHandler()],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Enable rich pretty-printing with all containers fully expanded.
pretty.install(expand_all=True)


def overlay_image(
    img: np.ndarray, img_overlay: np.ndarray, x: int, y: int, corner_radius: int = 15
) -> None:
    """
    Blends ``img_overlay`` into ``img`` in place, top-left corner at (x, y).

    Any part of the overlay that falls outside ``img`` is clipped away; if
    nothing remains visible the function returns without touching ``img``.
    The visible region is weighted by a rounded-rectangle mask (softened with
    a Gaussian blur) and, when the overlay has a fourth channel, by that
    channel interpreted as alpha in [0, 255].

    Parameters:
    - img (np.ndarray): The base image; modified in place.
    - img_overlay (np.ndarray): Overlay image, 3 channels or 4 (last = alpha).
    - x (int): Column of the overlay's top-left corner within ``img``.
    - y (int): Row of the overlay's top-left corner within ``img``.
    - corner_radius (int): Radius of the rounded mask corners in pixels.
    """
    base_h, base_w = img.shape[0], img.shape[1]
    ov_h, ov_w = img_overlay.shape[0], img_overlay.shape[1]

    # Destination window in img and the matching source window in the overlay.
    top, bottom = max(0, y), min(base_h, y + ov_h)
    left, right = max(0, x), min(base_w, x + ov_w)
    src_top, src_bottom = max(0, -y), min(ov_h, base_h - y)
    src_left, src_right = max(0, -x), min(ov_w, base_w - x)
    nothing_visible = (
        top >= bottom
        or left >= right
        or src_top >= src_bottom
        or src_left >= src_right
    )
    if nothing_visible:
        return

    visible = img_overlay[src_top:src_bottom, src_left:src_right]
    vis_h, vis_w = visible.shape[0], visible.shape[1]
    has_alpha = visible.shape[2] == 4
    colors = visible[:, :, :3] if has_alpha else visible
    alpha = (
        visible[:, :, 3] / 255.0
        if has_alpha
        else np.ones((vis_h, vis_w), dtype=float)
    )

    # Rounded rectangle: two overlapping rectangles form a "plus" shape and
    # one filled circle per corner rounds it off.
    r = corner_radius
    mask = np.zeros((vis_h, vis_w), dtype=np.uint8)
    cv2.rectangle(mask, (r, 0), (vis_w - r, vis_h), 255, -1)
    cv2.rectangle(mask, (0, r), (vis_w, vis_h - r), 255, -1)
    corner_centers = (
        (r, r),
        (vis_w - r - 1, r),
        (r, vis_h - r - 1),
        (vis_w - r - 1, vis_h - r - 1),
    )
    for center in corner_centers:
        cv2.circle(mask, center, r, 255, -1)
    kernel = r * 2 + 1
    soft_mask = cv2.GaussianBlur(mask, (kernel, kernel), 0).astype(float) / 255.0

    # Per-pixel weight of the overlay, replicated over the three colour channels.
    weight = np.dstack([alpha * soft_mask] * 3)
    background = img[top:bottom, left:right].astype(float)
    blended = (colors.astype(float) * weight) + (background * (1 - weight))
    img[top:bottom, left:right] = blended.astype(np.uint8)


def plot(
    img: np.ndarray,
    data: np.ndarray,
    value_range: tuple,
    x_width: float,
    color: tuple,
    line_width: int = 2,
    x_offset: int = 1200,
    y_offset: int = 50,
) -> None:
    """
    Draws ``data`` as a polyline onto ``img`` (in place).

    Values are mapped linearly from ``value_range`` onto a band one eighth of
    the image height tall, with ``value_range[0]`` at the bottom of the band.
    The band's top-left corner is at ``(x_offset, y_offset)`` and consecutive
    samples are ``x_width`` pixels apart. Non-finite samples (NaN/inf, e.g.
    missing values) leave a gap instead of producing garbage coordinates.

    Parameters:
    - img (np.ndarray): BGR image to draw on; modified in place.
    - data (np.ndarray): A 1-D series, or a 2-D array with one series per
      row; every row is drawn in ``color``.
    - value_range (tuple): (min, max) values mapped to the band's bottom/top.
    - x_width (float): Horizontal distance between samples in pixels.
    - color (tuple): Line colour as (R, G, B); reversed to BGR for OpenCV.
    - line_width (int): Line thickness in pixels.
    - x_offset (int): Left edge of the plot in pixels.
    - y_offset (int): Top edge of the plot band in pixels.
    """
    # float dtype turns None entries into NaN; atleast_2d accepts 1-D input.
    series = np.atleast_2d(np.asarray(data, dtype=float))
    height = img.shape[0]
    color_bgr = (color[2], color[1], color[0])
    normalized = (series - value_range[0]) / (value_range[1] - value_range[0] + 1e-5)
    ys = (1 - normalized) * height / 8 + y_offset
    xs = np.arange(series.shape[1]) * x_width + x_offset
    for row in ys:
        points = list(zip(xs, row))
        for (x1, y1), (x2, y2) in zip(points, points[1:]):
            # Casting NaN/inf to int yields meaningless coordinates; skip
            # any segment touching a non-finite sample.
            if not (np.isfinite(y1) and np.isfinite(y2)):
                continue
            cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), color_bgr, line_width)


def draw_circle_on_image(
    img: np.ndarray,
    center: tuple,
    radius: int,
    color: tuple,
    width: int = 1,
) -> None:
    """
    Draws a circle on ``img`` in place via ``cv2.circle``.

    ``color`` is given in (R, G, B) order and reversed into OpenCV's (B, G, R)
    order; ``width`` is passed through unchanged as the line thickness.
    """
    red, green, blue = color[0], color[1], color[2]
    cv2.circle(img, center, radius, (blue, green, red), width)


def _read_blink_records(
    blinks_csv_path: Path, gaze_csv_path: Path, recording_id
) -> np.recarray:
    """
    Build one blink record per gaze sample of ``recording_id``.

    Returns a record array whose float64 fields are ``ts`` (gaze timestamp in
    seconds), ``alpha`` (1.0 for samples without a blink id), ``id`` (blink
    id or NaN), ``start``/``end`` (blink bounds in ns, NaN outside blinks) and
    ``duration_ms``.
    """
    blinks_df = pd.read_csv(blinks_csv_path)
    gaze_df = pd.read_csv(gaze_csv_path)
    # .copy(): blinks_df is modified below; working on an independent frame
    # avoids pandas' SettingWithCopy chained-assignment pitfall.
    blinks_df = blinks_df[blinks_df["recording id"] == recording_id].copy()
    gaze_df = gaze_df[gaze_df["recording id"] == recording_id]

    # Float ids so they match gaze.csv's "blink id" column, which is NaN for
    # samples outside blinks (see the isna() reset below).
    blinks_df["blink id"] = blinks_df["blink id"].astype(float)
    blinks_by_id = blinks_df.set_index("blink id")
    blink_ids = gaze_df["blink id"]
    start_ns = blink_ids.map(blinks_by_id["start timestamp [ns]"]).astype(np.float64)
    end_ns = blink_ids.map(blinks_by_id["end timestamp [ns]"]).astype(np.float64)
    duration_ms = blink_ids.map(blinks_by_id["duration [ms]"]).astype(np.float64)

    # Position of each sample inside its blink: 0 at blink start, 1 at its end.
    t_norm = np.where(
        start_ns.notna(),
        (gaze_df["timestamp [ns]"] - start_ns) / (end_ns - start_ns),
        np.nan,
    )
    # Cosine ramp from 1 (blink start) down to 0 (blink end); 1.0 elsewhere.
    alpha = np.where(np.isnan(t_norm), 1.0, (np.cos(np.pi * t_norm) + 1) / 2)

    # Blinks longer than the mean are scaled linearly from MAX_ALPHA down to
    # MIN_ALPHA (reached by the longest); all others use MAX_ALPHA. The
    # statistics are taken over gaze samples, so long blinks weigh more.
    min_alpha, max_alpha = 0.5, 0.99
    mean_duration = duration_ms.mean()
    max_duration = duration_ms.max()
    if max_duration > mean_duration:
        scale = np.where(
            duration_ms >= mean_duration,
            min_alpha
            + (max_alpha - min_alpha)
            * (max_duration - duration_ms)
            / (max_duration - mean_duration),
            max_alpha,
        )
    else:
        # No blinks, or all blinks equally long: the formula above would
        # divide 0 by 0 and turn every alpha into NaN.
        scale = max_alpha
    alpha = np.clip(alpha * scale, 0.5, 1.0)
    alpha[blink_ids.isna().to_numpy()] = 1.0

    return np.rec.fromarrays(
        [
            gaze_df["timestamp [ns]"].to_numpy(dtype=np.float64) / 1e9,
            alpha,
            # Integral blink ids stored as float; NaN is preserved by trunc.
            np.trunc(blink_ids.to_numpy(dtype=np.float64)),
            start_ns.to_numpy(),
            end_ns.to_numpy(),
            duration_ms.to_numpy(),
        ],
        dtype=[
            ("ts", "f8"),
            ("alpha", "f8"),
            ("id", "f8"),
            ("start", "f8"),
            ("end", "f8"),
            ("duration_ms", "f8"),
        ],
    )


class BlinksStream(Stream):
    """
    Stream of per-gaze-sample blink information read from CSV exports.

    Reads ``blinks.csv`` and ``gaze.csv`` from the recording directory, keeps
    the rows belonging to this recording and attaches to each gaze sample the
    blink it falls in (if any) plus an ``alpha`` value used by the renderer
    to dim the scene video during blinks.
    """

    def __init__(self, name, recording):
        """
        Parameters:
        - name: Stream name forwarded to ``Stream``.
        - recording: Loaded recording; ``_rec_dir`` and
          ``info["recording_id"]`` are read.

        Raises:
        - FileNotFoundError: if ``blinks.csv`` or ``gaze.csv`` is missing.
        """
        blinks_csv_path = recording._rec_dir / "blinks.csv"
        gaze_csv_path = recording._rec_dir / "gaze.csv"
        missing = [
            str(path) for path in (blinks_csv_path, gaze_csv_path) if not path.exists()
        ]
        if missing:
            # Fail loudly: returning early would leave a Stream whose base
            # __init__ never ran and which has no data, breaking on first use.
            raise FileNotFoundError(
                "Required CSV files not found: " + ", ".join(missing)
            )

        with Console().status("[bold green]Reading blinks...", spinner="dots3"):
            self._data = _read_blink_records(
                blinks_csv_path, gaze_csv_path, recording.info["recording_id"]
            )

        super().__init__(name, recording, self._data)


# Pupil plot colours as (R, G, B). ``plot`` converts to BGR itself; the legend
# is drawn directly with OpenCV and therefore needs the reversed tuple.
_LEFT_PUPIL_RGB = (255, 140, 0)
_RIGHT_PUPIL_RGB = (30, 144, 255)


def _load_logo(path: Path, scale: float = 0.2) -> Optional[np.ndarray]:
    """Read ``path`` as a BGRA image resized by ``scale``; None if unreadable."""
    # cv2.imread returns None on failure instead of raising; str() keeps this
    # working on OpenCV builds that do not accept pathlib.Path.
    logo = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if logo is None:
        logger.warning(f"Logo not found or unreadable at {path}; skipping logo overlay")
        return None
    logo = cv2.resize(logo, None, fx=scale, fy=scale)
    # Give logos without transparency an opaque alpha channel.
    if logo.ndim == 2:
        logo = cv2.cvtColor(logo, cv2.COLOR_GRAY2BGRA)
    elif logo.shape[2] == 3:
        logo = cv2.cvtColor(logo, cv2.COLOR_BGR2BGRA)
    return logo


def _draw_logo(frame_pixels: np.ndarray, logo: np.ndarray, right: int, bottom: int) -> None:
    """Alpha-blend a BGRA ``logo`` so its bottom-right corner is at (right, bottom)."""
    top, left = bottom - logo.shape[0], right - logo.shape[1]
    if top < 0 or left < 0:
        return  # logo does not fit inside the frame margins
    alpha = logo[:, :, 3:4] / 255
    region = frame_pixels[top:bottom, left:right]
    # The float blend is cast back to the frame's dtype on assignment.
    frame_pixels[top:bottom, left:right] = region * (1 - alpha) + logo[:, :, :3] * alpha


def _draw_pupil_plots(frame_pixels: np.ndarray, left_history: list, right_history: list) -> None:
    """Draw the rolling left/right pupil-diameter plots and their legend."""
    for history, rgb in ((left_history, _LEFT_PUPIL_RGB), (right_history, _RIGHT_PUPIL_RGB)):
        plot(
            frame_pixels,
            data=np.array([history]),
            value_range=(0, 6),
            x_width=2.5,
            color=rgb,
            x_offset=1300,
            y_offset=80,
        )
    legend = (
        ("Right Pupil Diameter", _RIGHT_PUPIL_RGB, 35),
        ("Left Pupil Diameter", _LEFT_PUPIL_RGB, 55),
    )
    for label, rgb, y in legend:
        # OpenCV expects BGR: reverse the RGB tuple so the legend colour
        # matches the plotted line colour.
        bgr = rgb[::-1]
        cv2.circle(frame_pixels, (1300, y), 3, bgr, -1)
        cv2.putText(
            frame_pixels, label, (1310, y + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, bgr, 1
        )


def _mux_audio(output_container, out_audio, audio_container, audio_stream, max_seconds: float) -> None:
    """Copy audio packets starting at or before ``max_seconds`` into the output."""
    for packet in audio_container.demux(audio_stream):
        # demux() can yield packets without pts (e.g. the final flush packet).
        if packet.stream.type != "audio" or packet.pts is None:
            continue
        if packet.pts * packet.time_base > max_seconds:
            break
        packet.stream = out_audio
        output_container.mux(packet)


def main(
    recording_dir: Path,
    fps: Optional[int] = 30,
    gaze_overlay: bool = True,
    eye_overlay: bool = True,
    plots_overlay: bool = True,
    preview: bool = False,
    offset: Optional[int] = -36,
) -> None:
    """
    Render an annotated copy of a recording's scene video.

    Writes ``<recording_dir>/output_video.mp4``: the scene video resampled at
    ``fps``, dimmed during blinks (when ``blinks.csv`` and ``gaze.csv`` exist),
    with optional gaze, eye-video and pupil-diameter overlays, a logo (when
    ``logo.png`` exists in the recording directory) and the recording's audio
    trimmed to the length of the encoded video.

    Parameters:
    - recording_dir (Path): Recording directory; also the output directory.
    - fps (int): Output frame rate; None means 30.
    - gaze_overlay (bool): Draw a circle at the gaze position.
    - eye_overlay (bool): Draw the eye-camera image near the top-left corner.
    - plots_overlay (bool): Draw rolling pupil-diameter plots with a legend.
    - preview (bool): Show frames in an OpenCV window; "q" stops early.
    - offset (int): Pixels subtracted from the gaze y coordinate; None means 0.

    Raises:
    - ValueError: if ``fps`` is not positive.
    """
    if fps is None:
        fps = 30
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    if offset is None:
        offset = 0

    logger.info("Starting processing")
    recording = nr.load(recording_dir)
    logger.info(pretty.pretty_repr(recording.info))

    output_dir = Path(recording_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_video_path = output_dir / "output_video.mp4"
    output_container = av.open(str(output_video_path), mode="w")
    try:
        output_stream = output_container.add_stream(
            "libx264", rate=fps, options={"crf": "18"}
        )
        output_stream.width = recording.scene.width
        output_stream.height = recording.scene.height
        output_stream.pix_fmt = "yuv420p"
        output_stream.codec_context.time_base = Fraction(1, fps)

        # Output frame times on the scene clock (seconds), spaced 1/fps apart.
        output_timestamps = np.arange(
            recording.scene.ts[0], recording.scene.ts[-1], 1 / fps
        )

        audio_container = recording.streams["audio"].video_parts[0].container
        audio_stream = audio_container.streams.audio[0]
        out_audio = output_container.add_stream(template=audio_stream)

        try:
            blinks = BlinksStream("blinks", recording)
        except FileNotFoundError as err:
            logger.warning(f"Blink dimming disabled: {err}")
            blinks = None
        recording.blinks = blinks

        with Console().status("[bold green]Combining the data...", spinner="dots3"):
            blink_samples = (
                blinks.sample(output_timestamps)
                if blinks is not None
                else [None] * len(output_timestamps)
            )
            combined_data = zip(
                output_timestamps,
                recording.scene.sample(output_timestamps),
                recording.eye.sample(output_timestamps),
                recording.gaze.sample(output_timestamps),
                recording.eye_state.sample(output_timestamps),
                blink_samples,
            )

        margin_right = recording.scene.width - 25
        margin_bot = recording.scene.height - 25
        # Loaded once: re-reading and resizing the PNG per frame is wasted work.
        logo = _load_logo(recording._rec_dir / "logo.png")

        # A sample only counts for a frame if it lies within this many seconds.
        tolerance = 2 / fps
        pupil_os_buffer_data: list = []
        pupil_od_buffer_data: list = []
        frames_written = 0

        for ts, scene_frame, eye_frame, gaze_datum, eye_state, blink in track(
            combined_data,
            total=len(output_timestamps),
            description="Processing frames",
        ):
            if abs(scene_frame.ts - ts) < tolerance:
                frame_pixels = scene_frame.bgr
            else:
                frame_pixels = GrayFrame(recording.scene.width, recording.scene.height).bgr

            blink_alpha = 1.0 if blink is None else blink.alpha
            if blink_alpha < 1.0 and abs(blink.ts - ts) < tolerance:
                frame_pixels = (
                    frame_pixels.astype(np.float32) * (1 - blink_alpha)
                ).astype(np.uint8)

            if gaze_overlay and abs(gaze_datum.ts - ts) < tolerance:
                draw_circle_on_image(
                    frame_pixels,
                    (int(gaze_datum.x), int(gaze_datum.y) - offset),
                    radius=15,
                    color=(255, 0, 0) if blink_alpha == 1.0 else (105, 105, 105),
                    width=3,
                )

            if eye_overlay:
                if abs(eye_frame.ts - ts) < tolerance:
                    eye_pixels = cv2.cvtColor(eye_frame.gray, cv2.COLOR_GRAY2BGR)
                else:
                    eye_pixels = GrayFrame(recording.eye.width, recording.eye.height).bgr
                overlay_image(frame_pixels, eye_pixels, 50, 50)

            if logo is not None:
                _draw_logo(frame_pixels, logo, margin_right, margin_bot)

            if plots_overlay:
                if abs(eye_state.ts - ts) < tolerance:
                    pupil_os_buffer_data.append(eye_state.pupil_diameter_left)
                    pupil_od_buffer_data.append(eye_state.pupil_diameter_right)
                    _draw_pupil_plots(
                        frame_pixels, pupil_os_buffer_data, pupil_od_buffer_data
                    )
                # Keep a rolling window of the most recent 100 samples.
                pupil_od_buffer_data = pupil_od_buffer_data[-100:]
                pupil_os_buffer_data = pupil_os_buffer_data[-100:]

            if preview:
                cv2.imshow("Preview", frame_pixels)
                if cv2.waitKey(1) & 0xFF == ord("q"):
                    break

            frame = av.VideoFrame.from_ndarray(frame_pixels, format="bgr24")
            for packet in output_stream.encode(frame):
                output_container.mux(packet)
            frames_written += 1

        for packet in output_stream.encode(None):
            output_container.mux(packet)

        # Audio packet times are relative to the audio stream start, whereas
        # ``ts`` above is on the (epoch) scene clock, so comparing the two never
        # trims anything. Trim to the encoded video length instead, assuming
        # the audio stream starts together with the scene video.
        _mux_audio(
            output_container,
            out_audio,
            audio_container,
            audio_stream,
            frames_written / fps,
        )
    finally:
        output_container.close()
        if preview:
            cv2.destroyAllWindows()

    logger.info("Processing completed successfully")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Process Pupil Labs recordings.")
    parser.add_argument(
        "recording_dir", type=Path, help="Path to the recording directory"
    )
    parser.add_argument(
        "--fps", type=int, default=30, help="Frames per second for output video"
    )
    parser.add_argument(
        "--no-gaze-overlay",
        action="store_false",
        dest="gaze_overlay",
        help="Disable gaze overlay",
    )
    parser.add_argument(
        "--no-eye-overlay",
        action="store_false",
        dest="eye_overlay",
        help="Disable eye overlay",
    )
    parser.add_argument(
        "--no-plots-overlay",
        action="store_false",
        dest="plots_overlay",
        help="Disable plots overlay",
    )
    parser.add_argument("--preview", action="store_true", help="Enable preview mode")
    # Exposes main()'s gaze offset, which was previously only settable in code.
    parser.add_argument(
        "--offset",
        type=int,
        default=-36,
        help="Pixels subtracted from the gaze y coordinate (default: -36)",
    )
    args = parser.parse_args()
    # Reject non-positive frame rates up front rather than failing deep inside
    # the encoder setup (Fraction(1, 0), a 1/0 time step, ...).
    if args.fps <= 0:
        parser.error("--fps must be a positive integer")
    main(
        recording_dir=args.recording_dir,
        fps=args.fps,
        gaze_overlay=args.gaze_overlay,
        eye_overlay=args.eye_overlay,
        plots_overlay=args.plots_overlay,
        preview=args.preview,
        offset=args.offset,
    )
