# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "pandas",
#     "rich",
# ]
# ///
import os
import sys
import tempfile
import zipfile

import pandas as pd
from rich import print as rprint


def avg_fix_duration(zip_path: str) -> pd.DataFrame:
    """Summarise fixation durations between consecutive events of a recording.

    The archive at ``zip_path`` is extracted to a temporary directory and
    searched recursively for ``events.csv`` and ``fixations.csv``. Events are
    ordered by ``timestamp [ns]``; every pair of consecutive events defines an
    interval labelled ``"<first name> - <second name>"``. A fixation belongs to
    an interval when it starts at or after the interval start and ends at or
    before the interval end.

    Args:
        zip_path: Path to a zip archive containing ``events.csv`` (columns
            ``timestamp [ns]`` and ``name``) and ``fixations.csv`` (columns
            ``start timestamp [ns]``, ``end timestamp [ns]`` and
            ``duration [ms]``).

    Returns:
        One row per interval label with the columns ``interval_label``,
        ``count``, ``mean``, ``std``, ``min`` and ``max`` computed over
        ``duration [ms]``. Intervals that share the same label are pooled into
        a single row. When no fixation falls inside any interval (including
        the case of fewer than two events) an empty frame with those columns
        is returned.

    Raises:
        FileNotFoundError: If either CSV file is missing from the archive.
    """
    stat_columns = ["interval_label", "count", "mean", "std", "min", "max"]

    with (
        tempfile.TemporaryDirectory() as extract_dir,
        zipfile.ZipFile(zip_path, "r") as zip_ref,
    ):
        zip_ref.extractall(extract_dir)

        # Locate the two CSVs anywhere in the extracted tree. If a file name
        # occurs more than once, the last one visited by os.walk wins.
        found: dict[str, str | None] = {"events.csv": None, "fixations.csv": None}
        for root, _, files in os.walk(extract_dir):
            for f in files:
                if f in found:
                    found[f] = os.path.join(root, f)
        events_path = found["events.csv"]
        fixations_path = found["fixations.csv"]

        if not events_path or not fixations_path:
            raise FileNotFoundError("Required CSV files not found in unzipped data.")

        # Load data
        events_df = pd.read_csv(events_path)
        rprint("Events DataFrame:")
        rprint(events_df.head())
        fixations_df = pd.read_csv(fixations_path)

        # Force integer nanosecond timestamps so comparisons are exact.
        events_df["timestamp [ns]"] = events_df["timestamp [ns]"].astype("int64")
        for column in ("start timestamp [ns]", "end timestamp [ns]"):
            fixations_df[column] = fixations_df[column].astype("int64")

        # A stable sort keeps events that share a timestamp in file order, so
        # the interval labels are deterministic.
        events_sorted = events_df.sort_values("timestamp [ns]", kind="stable")
        boundaries = list(zip(events_sorted["timestamp [ns]"], events_sorted["name"]))

        fix_start = fixations_df["start timestamp [ns]"]
        fix_end = fixations_df["end timestamp [ns]"]

        interval_fixations = []
        # Consecutive events form the (start, end) boundaries of each interval.
        for (start_time, start_name), (end_time, end_name) in zip(
            boundaries, boundaries[1:]
        ):
            inside = (fix_start >= start_time) & (fix_end <= end_time)
            if not inside.any():
                # Keep empty frames out of the concat so the emptiness check
                # below reflects whether any fixation was actually matched.
                continue
            label = f"{start_name} - {end_name}"
            interval_fixations.append(
                fixations_df[inside].assign(interval_label=label)
            )

        if not interval_fixations:
            rprint("No fixations found within event intervals.")
            return pd.DataFrame(columns=stat_columns)

        all_fixations = pd.concat(interval_fixations, ignore_index=True)

        # Per-label descriptive statistics of the fixation durations.
        stats = (
            all_fixations.groupby("interval_label")["duration [ms]"]
            .agg(count="count", mean="mean", std="std", min="min", max="max")
            .reset_index()
        )

        return stats


if __name__ == "__main__":
    # The script takes exactly one positional argument: the archive path.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        rprint("Usage: python script.py <zip_path>")
        sys.exit(1)

    stats = avg_fix_duration(cli_args[0])

    # Report the per-interval table, or a plain notice when there is none.
    if stats.empty:
        print("No fixation data available.")
    else:
        rprint("Average Fixation Duration Statistics:")
        rprint(stats)
