# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "matplotlib",
#     "pandas",
# ]
# ///
import argparse
import json
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def _exit_with_error(message):
    """Print *message* to stderr and terminate the script with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def plot_pupil_data(eye_states_path, blinks_path, info_path):
    """Plot left-pupil diameter over time, with and without blink intervals.

    Two side-by-side panels sharing both axes are drawn:

    * left: the raw left-pupil diameter trace with every blink interval
      shaded in light gray;
    * right: the same trace with all samples that fall inside a blink
      interval (bounds inclusive) replaced by NaN, so the line shows gaps.

    Time on the x axis is expressed in seconds relative to the
    ``start_time`` value read from the info file.

    Args:
        eye_states_path: Path to a CSV containing the columns
            ``timestamp [ns]`` and ``pupil diameter left [mm]``.
        blinks_path: Path to a CSV containing the columns
            ``start timestamp [ns]`` and ``end timestamp [ns]``.
        info_path: Path to a JSON file whose top-level object holds a numeric
            ``start_time``. It is subtracted from the ``[ns]`` columns, so it
            must be expressed in the same unit.

    Returns:
        None. The figure is displayed through ``plt.show()``.

    Raises:
        SystemExit: With code 1, after printing a message to stderr, when a
            file is missing or unreadable, a required column or key is
            absent, or ``start_time`` is not a number.
    """
    try:
        eye_df = pd.read_csv(eye_states_path)
        blinks_df = pd.read_csv(blinks_path)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(info_path, encoding="utf-8") as f:
            info_json = json.load(f)
    except FileNotFoundError as e:
        _exit_with_error(
            f"Error: {e}. Please ensure the files exist in the specified folder."
        )
    except Exception as e:
        # Top-level boundary of a CLI tool: report any other load/parse
        # failure as a short message instead of a traceback.
        _exit_with_error(f"An error occurred: {e}")

    required_eye_cols = ["timestamp [ns]", "pupil diameter left [mm]"]
    required_blink_cols = ["start timestamp [ns]", "end timestamp [ns]"]
    if not all(col in eye_df.columns for col in required_eye_cols):
        _exit_with_error(f"Error: Eye states CSV must contain: {required_eye_cols}")
    if not all(col in blinks_df.columns for col in required_blink_cols):
        _exit_with_error(f"Error: Blinks CSV must contain: {required_blink_cols}")
    if not isinstance(info_json, dict) or "start_time" not in info_json:
        _exit_with_error("Error: info.json must contain the key 'start_time'")

    start_timestamp = info_json["start_time"]
    # A non-numeric value would otherwise surface as an opaque TypeError from
    # the column arithmetic below (bool is excluded although it subclasses int).
    if isinstance(start_timestamp, bool) or not isinstance(
        start_timestamp, (int, float)
    ):
        _exit_with_error(
            "Error: 'start_time' in info.json must be a numeric timestamp "
            "in nanoseconds"
        )

    # Convert absolute nanosecond timestamps to seconds since start_time.
    eye_df["timestamp_relative_s"] = (eye_df["timestamp [ns]"] - start_timestamp) / 1e9
    blinks_df["start_relative_s"] = (
        blinks_df["start timestamp [ns]"] - start_timestamp
    ) / 1e9
    blinks_df["end_relative_s"] = (
        blinks_df["end timestamp [ns]"] - start_timestamp
    ) / 1e9

    # Pull the columns out once as arrays; both the shading loop and the
    # masking loop iterate over the same (start, end) pairs.
    sample_times = eye_df["timestamp_relative_s"].to_numpy()
    blink_starts = blinks_df["start_relative_s"].to_numpy()
    blink_ends = blinks_df["end_relative_s"].to_numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 7), sharey=True, sharex=True)
    fig.suptitle("Pupil Diameter", fontsize=18)

    ax1.plot(
        eye_df["timestamp_relative_s"],
        eye_df["pupil diameter left [mm]"],
        label="Left Pupil Diameter",
        color="cornflowerblue",
    )

    for blink_start, blink_end in zip(blink_starts, blink_ends):
        ax1.axvspan(
            blink_start,
            blink_end,
            color="lightgray",
            alpha=0.6,
            zorder=0,  # keep the shading behind the trace
            label="_nolegend_",  # leading underscore hides it from the legend
        )

    ax1.set_title("Pupil Diameter with Blink Events")
    ax1.set_xlabel("Time (seconds)")
    ax1.set_ylabel("Pupil Diameter (mm)")
    ax1.legend()
    ax1.grid(True, linestyle="--", alpha=0.6)

    # Accumulate one boolean mask covering every blink interval (inclusive on
    # both ends), then blank those samples in a single step. No ordering or
    # non-overlap of the intervals is assumed.
    in_blink = np.zeros(len(eye_df), dtype=bool)
    for blink_start, blink_end in zip(blink_starts, blink_ends):
        in_blink |= (sample_times >= blink_start) & (sample_times <= blink_end)
    pupil_without_blinks = eye_df["pupil diameter left [mm]"].mask(in_blink)

    ax2.plot(
        eye_df["timestamp_relative_s"],
        pupil_without_blinks,
        label="Left Pupil Diameter (Blinks Removed)",
        color="mediumseagreen",
    )

    ax2.set_title("Pupil Diameter (Blinks Filtered Out)")
    ax2.set_xlabel("Time (seconds)")
    ax2.legend()
    ax2.grid(True, linestyle="--", alpha=0.6)

    # Leave the top 4% of the figure free for the suptitle.
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()


if __name__ == "__main__":
    # Command-line entry point: one positional argument naming the folder
    # that holds the three input files.
    cli = argparse.ArgumentParser(
        description="Plot pupil diameter data from a specified folder."
    )
    cli.add_argument(
        "folder", type=str, help="The path to the folder containing the data files."
    )
    recording_dir = cli.parse_args().folder

    # The input files are looked up under fixed names inside that folder.
    eye_states_csv, blinks_csv, info_json_file = (
        os.path.join(recording_dir, filename)
        for filename in ("3d_eye_states.csv", "blinks.csv", "info.json")
    )

    plot_pupil_data(eye_states_csv, blinks_csv, info_json_file)
