import glob
import os
from fractions import Fraction

import av
import cv2
import numpy as np
import pandas as pd
from pupil_labs.dynamic_content_on_rim.uitools.ui_tools import get_savedir
from pupil_labs.dynamic_content_on_rim.video.read import get_frame, read_video_ts

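# Folder containing world_timestamps.csv, gaze.csv and the scene video, and the gaze circle radius in pixels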
input_path = ""
circle_size = 10

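# Load the scene-camera and gaze timestamps, keeping the nanosecond timestamps as uint64 to avoid float precision loss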
oftype = {"timestamp [ns]": np.uint64}
world_timestamps_df = pd.read_csv(
    os.path.join(input_path, "world_timestamps.csv"), dtype=oftype
)
gaze_df = pd.read_csv(os.path.join(input_path, "gaze.csv"), dtype=oftype)
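# Locate the scene video and read its frame count, pts, and timestamps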
files = glob.glob(os.path.join(input_path, "*.mp4"))
video_path = files[0]
_, frames, pts, ts = read_video_ts(video_path)

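# Check whether the scene video contains an audio stream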
with av.open(video_path) as v:
    if not v.streams.audio:
        print("No audio stream found!")
        audio_stream_available = False
    else:
        audio_stream_available = True

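# If there is audio, also read the audio frames' count, pts, and timestamps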
if audio_stream_available:
    _, audio_frames, audio_pts, audio_ts = read_video_ts(
        video_path, audio=True, auto_thread_type=False
    )
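# Use the timestamps from world_timestamps.csv as the per-frame video timestamps (same clock as the gaze data)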
ts = world_timestamps_df["timestamp [ns]"]

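# One row per video frame: frame index, pts, and timestamp in nanoseconds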
video_df = pd.DataFrame(
    {
        "frames": np.arange(frames),
        "pts": [int(pt) for pt in pts],
        "timestamp [ns]": ts,
    }
)
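# Build the same table for audio frames, offsetting their timestamps by the first world timestamp so both tables share one clock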
if audio_stream_available:
    audio_ts = audio_ts + ts[0]
    audio_df = pd.DataFrame(
        {
            "frames": np.arange(audio_frames),
            "pts": [int(pt) for pt in audio_pts],
            "timestamp [ns]": audio_ts,
        }
    )

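# Match every video frame with the gaze sample closest to it in time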
merged_video = pd.merge_asof(
    video_df,
    gaze_df,
    on="timestamp [ns]",
    direction="nearest",
    suffixes=["video", "gaze"],
)

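# Likewise, match every audio frame to the nearest video frame in time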
if audio_stream_available:
    merged_audio = pd.merge_asof(
        audio_df,
        video_df,
        on="timestamp [ns]",
        direction="nearest",
        suffixes=["audio", "video"],
    )

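# Decode the first video (and audio) frame; get_frame is seeded with the previously decoded frame on each call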
with av.open(video_path) as vid_container:
    print("Reading first frame")
    vid_frame = next(vid_container.decode(video=0))
    if audio_stream_available:
        aud_frame = next(vid_container.decode(audio=0))

num_processed_frames = 0

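# Ask where to save the output video; the merged gaze table is written next to it as merged.csv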
output_file = get_savedir(None, type="video")
output_path = os.path.split(output_file)[0]
out_csv = os.path.join(output_path, "merged.csv")

with av.open(video_path) as video, av.open(video_path) as audio, av.open(
    output_file, "w"
) as out_container:
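    # Configure the output: H.264 video at 30 fps (CRF 18) and, if available, AAC stereo audio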
    out_video = out_container.add_stream("libx264", rate=30, options={"crf": "18"})
    out_video.width = video.streams.video[0].width
    out_video.height = video.streams.video[0].height
    out_video.pix_fmt = "yuv420p"
    out_video.codec_context.time_base = Fraction(1, 30)
    if audio_stream_available:
        out_audio = out_container.add_stream("aac", layout="stereo")
        out_audio.rate = audio.streams.audio[0].rate
        out_audio.time_base = out_audio.codec_context.time_base
    lpts = -1
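    # Step through the merged table: seek to each frame's pts, overlay the gaze point, and re-encode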
    while num_processed_frames < merged_video.shape[0]:
        row = merged_video.iloc[num_processed_frames]
        # Get the frame
        vid_frame, lpts = get_frame(video, int(row["pts"]), lpts, vid_frame)
        if vid_frame is None:
            break
        img_original = vid_frame.to_ndarray(format="rgb24")
        # Prepare the frame
        frame = cv2.cvtColor(img_original, cv2.COLOR_RGB2BGR)
        frame = np.asarray(frame, dtype=np.float32)
        # Keep the gaze coordinates as floats first so missing samples (NaN) can be detected
        xy = row[["gaze x [px]", "gaze y [px]"]].to_numpy(dtype=np.float64)

        if not np.isnan(xy).any():
            # Draw the gaze circle (red, since the frame is BGR at this point)
            cv2.circle(frame, (int(xy[0]), int(xy[1])), circle_size, (0, 0, 255), 10)

        # Convert back to 8-bit and to RGB before handing the frame to PyAV
        out_ = cv2.normalize(frame, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        cv2.cvtColor(out_, cv2.COLOR_BGR2RGB, out_)
        out_frame = av.VideoFrame.from_ndarray(out_, format="rgb24")
        for packet in out_video.encode(out_frame):
            out_container.mux(packet)
        if num_processed_frames % 100 == 0:
            print(
                f"Processed {num_processed_frames} frames out of {merged_video.shape[0]}"
            )
        num_processed_frames += 1
    for packet in out_video.encode(None):
        out_container.mux(packet)
    # Second pass: decode the matched audio frames and re-encode them as AAC
    if audio_stream_available:
        num_processed_frames = 0
        lpts = -1
        while num_processed_frames < merged_audio.shape[0]:
            row = merged_audio.iloc[num_processed_frames]
            aud_frame, lpts = get_frame(
                audio, int(row["ptsaudio"]), lpts, aud_frame, audio=True
            )
            if aud_frame is None:
                break
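            # Clear the original pts so the AAC encoder assigns new ones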
            aud_frame.pts = None
            af = out_audio.encode(aud_frame)
            out_container.mux(af)
            num_processed_frames += 1
        for packet in out_audio.encode(None):
            out_container.mux(packet)
    # Save the merged video/gaze table next to the output video
    merged_video.to_csv(out_csv, index=False)
    print(f"CSV file saved at: {out_csv}")
