import subprocess
import matplotlib.pyplot as plt
import os
import json
import tkinter as tk
from tkinter import filedialog
import statistics

# ===== CONFIGURATION =====
# An interval between two consecutive frames that is larger than
# expected_interval * DROP_THRESHOLD_MULTIPLIER is reported as a dropped frame.
DROP_THRESHOLD_MULTIPLIER = 1.5     # Threshold multiplier for detecting dropped frames
# TOLERANCE (the duplicate-PTS threshold) is not configured here: it is chosen
# per video from the nominal FPS by auto_tolerance().
# ==========================

def select_file():
    """Show a native "open file" dialog and return the chosen path.

    A hidden Tk root window is created only to host the dialog. It is
    destroyed afterwards so no stray Tk interpreter stays alive while the rest
    of the analysis (including the matplotlib window) runs.

    Returns:
        The selected file path as a string, or ``""`` if the user cancelled.
    """
    root = tk.Tk()
    root.withdraw()  # hide the empty main window; only the dialog is wanted
    try:
        file_path = filedialog.askopenfilename(
            title="Select a video file",
            filetypes=[("Video files", "*.mp4 *.mkv *.avi *.mov *.flv *.wmv *.webm"), ("All files", "*.*")]
        )
    finally:
        # Always release the Tk root, even if the dialog raises.
        root.destroy()
    # Depending on platform/Tk build, cancelling may yield an empty tuple
    # instead of "": normalise so callers always get a string.
    return file_path or ""

def ffprobe_timestamps(video_file):
    """Return the presentation timestamps (seconds) of the frames of video stream 0.

    Runs ``ffprobe`` on the first video stream and parses one ``pts_time``
    value per output line.

    Frames whose ``pts_time`` is ``N/A`` (or otherwise non-numeric) are
    skipped instead of aborting the whole run. When such frames exist, list
    positions may not match decoded frame numbers exactly.

    Args:
        video_file: path to the media file.

    Returns:
        List of float timestamps in the order ffprobe printed them, or ``[]``
        if ffprobe cannot be executed or exits with an error.
    """
    try:
        result = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-select_streams", "v:0",
                "-show_entries", "frame=pts_time",
                "-of", "csv=p=0",
                video_file
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
    except OSError as exc:
        # e.g. FileNotFoundError when the ffprobe executable is not on PATH.
        print(f"❌ Could not run ffprobe: {exc}")
        return []
    if result.returncode != 0:
        print("❌ ffprobe failed:", result.stderr)
        return []

    timestamps = []
    for line in result.stdout.splitlines():
        # Defensive: use only the first CSV field in case ffprobe appends
        # extra (possibly empty) fields after pts_time.
        field = line.strip().split(",", 1)[0].strip()
        if not field or field == "N/A":
            continue
        try:
            timestamps.append(float(field))
        except ValueError:
            # Skip any non-numeric value rather than crash the analysis.
            continue
    return timestamps

def ffprobe_nominal(video_file):
    """Query the nominal frame rate of the first video stream via ffprobe.

    FPS candidates are collected from ``avg_frame_rate``, ``r_frame_rate``
    and ``nb_frames / format.duration``. They are chosen in that priority
    order: the first usable (non-zero) value wins.

    Args:
        video_file: path to the media file.

    Returns:
        Tuple ``(fps, expected_interval, candidates)``. ``expected_interval``
        is ``1 / fps``. ``candidates`` maps each source name to its parsed
        value, or None when unavailable. On any failure (ffprobe missing,
        ffprobe error, unparseable output, no video stream) the result is
        ``(None, None, {})``.
    """
    try:
        result = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-select_streams", "v:0",
                # Stream and format fields requested in one canonical spec.
                "-show_entries", "stream=avg_frame_rate,r_frame_rate,nb_frames:format=duration",
                "-of", "json",
                video_file
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
    except OSError as exc:
        # e.g. FileNotFoundError when the ffprobe executable is not on PATH.
        print(f"❌ Could not run ffprobe: {exc}")
        return None, None, {}
    if result.returncode != 0:
        print("❌ ffprobe (fps) failed:", result.stderr)
        return None, None, {}

    try:
        data = json.loads(result.stdout)
    except ValueError:
        print("❌ ffprobe (fps) returned unparseable output.")
        return None, None, {}

    streams = data.get("streams") or []
    if not streams:
        # File has no video stream (e.g. audio-only): nothing to analyse.
        print("❌ No video stream found.")
        return None, None, {}
    stream = streams[0]
    fmt = data.get("format") or {}

    def to_number(value, cast=float):
        """Convert an ffprobe field to a number; None if absent or e.g. "N/A"."""
        try:
            return cast(value)
        except (TypeError, ValueError):
            return None

    def fps_from_fraction(frac):
        """Turn a rate string such as "30000/1001" into a float; None if unusable."""
        if not frac or frac == "0/0":
            return None
        if "/" not in frac:
            return to_number(frac)
        num_text, den_text = frac.split("/", 1)
        num, den = to_number(num_text), to_number(den_text)
        if num is None or not den:
            # Unparseable parts, or a zero denominator that would divide by zero.
            return None
        return num / den

    candidates = {
        "avg_frame_rate": fps_from_fraction(stream.get("avg_frame_rate")),
        "r_frame_rate": fps_from_fraction(stream.get("r_frame_rate")),
        "nb_frames": to_number(stream.get("nb_frames"), int),
        "duration": to_number(fmt.get("duration")),
    }

    chosen = None
    if candidates["avg_frame_rate"]:
        chosen = candidates["avg_frame_rate"]
    elif candidates["r_frame_rate"]:
        chosen = candidates["r_frame_rate"]
    elif candidates["nb_frames"] and candidates["duration"]:
        chosen = candidates["nb_frames"] / candidates["duration"]

    return chosen, (1.0 / chosen if chosen else None), candidates

def auto_tolerance(fps):
    """Pick the duplicate-PTS tolerance (in seconds) for a nominal frame rate.

    Faster frame rates have shorter frame intervals, so a tighter tolerance is
    used to decide that two consecutive timestamps are effectively equal.
    """
    # (inclusive upper FPS bound, tolerance), checked from slowest to fastest.
    tiers = ((30, 1e-5), (60, 5e-6), (120, 1e-6))
    for fps_ceiling, tolerance in tiers:
        if fps <= fps_ceiling:
            return tolerance
    # Anything above 120 fps (or a non-comparable value) uses the tightest tier.
    return 1e-7

def analyze(pts_list, expected_interval, TOLERANCE, drop_multiplier=None, early_multiplier=0.5):
    """Classify consecutive frame intervals against the nominal frame interval.

    Interval ``i`` is ``pts_list[i + 1] - pts_list[i]``. Each interval is
    checked independently against these categories:

    * paused: ``abs(interval) < TOLERANCE`` (duplicate timestamps);
    * dropped: ``interval > expected_interval * drop_multiplier``;
    * early: ``0 < interval < expected_interval * early_multiplier`` and not
      paused, so a near-duplicate is never counted as both paused and early;
    * out_of_order: ``interval < 0``.

    Args:
        pts_list: frame timestamps in seconds, in the order they were reported.
        expected_interval: nominal seconds per frame (1 / fps).
        TOLERANCE: threshold below which an interval counts as a duplicate.
        drop_multiplier: multiplier for the drop threshold; defaults to the
            module-level DROP_THRESHOLD_MULTIPLIER when None.
        early_multiplier: multiplier for the early threshold (default 0.5).

    Returns:
        ``None`` if ``expected_interval`` is None. Otherwise a dict with these keys:

        * ``intervals``: the raw interval list.
        * ``paused``: ``(i, pts_list[i])`` pairs.
        * ``dropped``, ``early``: ``(i, interval)`` pairs.
        * ``out_of_order``: ``(i, pts_list[i], pts_list[i+1])`` triples.
        * ``observed_median``, ``min_interval``, ``max_interval``: statistics
          over positive, non-paused intervals, or None if there are none.
    """
    if expected_interval is None:
        print("❌ Cannot analyze intervals without expected interval.")
        return None

    if drop_multiplier is None:
        drop_multiplier = DROP_THRESHOLD_MULTIPLIER
    drop_threshold = expected_interval * drop_multiplier
    early_threshold = expected_interval * early_multiplier

    intervals = [nxt - cur for cur, nxt in zip(pts_list, pts_list[1:])]

    paused, dropped, early, out_of_order, valid_intervals = [], [], [], [], []
    for i, gap in enumerate(intervals):
        is_paused = abs(gap) < TOLERANCE
        if is_paused:
            paused.append((i, pts_list[i]))
        if gap > drop_threshold:
            dropped.append((i, gap))
        if 0 < gap < early_threshold and not is_paused:
            early.append((i, gap))
        if gap < 0:
            out_of_order.append((i, pts_list[i], pts_list[i + 1]))
        if gap > 0 and not is_paused:
            # Near-duplicates are already reported as paused; keep them out of
            # the statistics so they don't masquerade as the minimum interval.
            valid_intervals.append(gap)

    observed_median = statistics.median(valid_intervals) if valid_intervals else None
    min_interval = min(valid_intervals) if valid_intervals else None
    max_interval = max(valid_intervals) if valid_intervals else None

    return {
        "intervals": intervals,
        "paused": paused,
        "dropped": dropped,
        "early": early,
        "out_of_order": out_of_order,
        "observed_median": observed_median,
        "min_interval": min_interval,
        "max_interval": max_interval
    }

def plot_intervals(intervals, expected_interval, drop_threshold):
    """Plot every frame-to-frame interval together with its reference lines.

    Args:
        intervals: interval lengths in seconds; the x axis is the list index.
        expected_interval: nominal interval, drawn as a green dashed line.
        drop_threshold: drop threshold, drawn as a red dashed line.

    The early threshold (half of ``expected_interval``) is drawn in blue, and
    the figure is displayed with ``plt.show()``.
    """
    plt.figure(figsize=(12, 5))
    plt.plot(intervals, label='Frame Interval (s)', marker='o', linestyle='-')

    # Horizontal dashed guides, drawn in legend order: (y, colour, label).
    guides = (
        (expected_interval, 'green', 'Expected Interval'),
        (drop_threshold, 'red', 'Drop Threshold'),
        (expected_interval * 0.5, 'blue', 'Early Threshold'),
    )
    for y_value, line_color, line_label in guides:
        plt.axhline(y_value, color=line_color, linestyle='--', label=line_label)

    plt.xlabel("Frame Number")
    plt.ylabel("Interval (s)")
    plt.title("Frame Interval Analysis")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def main():
    """Interactive entry point: pick a video, analyse its frame timing, report.

    Steps:

    1. Ask for a file.
    2. Read per-frame PTS and the nominal FPS via ffprobe.
    3. Classify intervals with analyze() and print a summary.
    4. Export the results to ``<video path without extension>_frame_analysis.json``.
    5. Show the interval plot last.
    """
    video_file = select_file()
    if not video_file or not os.path.isfile(video_file):
        print("❌ No valid file selected.")
        return

    print(f"[✓] Analyzing file: {video_file}")

    pts_list = ffprobe_timestamps(video_file)
    if not pts_list:
        print("❌ No timestamps extracted.")
        return

    fps, expected_interval, candidates = ffprobe_nominal(video_file)
    if not fps:
        print("❌ Could not determine nominal FPS.")
        return

    TOLERANCE = auto_tolerance(fps)
    print(f"[✓] Auto TOLERANCE set to {TOLERANCE} based on FPS {fps:.2f}")

    print("\n📊 ffprobe FPS candidates:")
    for k, v in candidates.items():
        print(f"  {k}: {v}")
    print(f"Chosen nominal FPS: {fps:.6f}, interval: {expected_interval:.9f} s")

    result = analyze(pts_list, expected_interval, TOLERANCE)
    if not result:
        return

    total_intervals = len(result["intervals"])
    counts = {key: len(result[key]) for key in ("paused", "dropped", "early", "out_of_order")}

    def pct(count):
        # A single-frame file has no intervals; avoid dividing by zero.
        return count / total_intervals * 100 if total_intervals else 0.0

    def fmt_seconds(value):
        # analyze() returns None statistics when no positive, non-paused
        # intervals exist; ":.9f" on None would raise TypeError.
        return "N/A" if value is None else f"{value:.9f}"

    print("\n📋 SUMMARY")
    print(f"Total frames: {len(pts_list)}")
    print(f"Expected interval (chosen): {expected_interval:.9f} s")
    print(f"Observed median interval: {fmt_seconds(result['observed_median'])} s")
    print(f"Paused (duplicate PTS) frames: {counts['paused']} ({pct(counts['paused']):.2f}%)")
    print(f"Dropped/irregular frames: {counts['dropped']} ({pct(counts['dropped']):.2f}%)")
    print(f"Early (fast) frames: {counts['early']} ({pct(counts['early']):.2f}%)")
    print(f"Out-of-order PTS frames: {counts['out_of_order']} ({pct(counts['out_of_order']):.2f}%)")
    print(f"Min interval: {fmt_seconds(result['min_interval'])} s, Max interval: {fmt_seconds(result['max_interval'])} s")

    if result["early"]:
        print("\n⚡ Early frames (first 10):")
        for idx, gap in result["early"][:10]:
            print(f"  Frame {idx} → {idx+1} gap: {gap:.9f} s")
        if len(result["early"]) > 10:
            print("  ...")

    def category(key):
        # Shared JSON shape for every per-category report.
        return {"count": counts[key], "percentage": pct(counts[key]), "frames": result[key]}

    # Export JSON before plotting so the report is written regardless of how
    # the (typically blocking) plot window is closed.
    export = {
        "file": video_file,
        "fps_candidates": candidates,
        "fps_chosen": fps,
        "expected_interval": expected_interval,
        "observed_median_interval": result["observed_median"],
        "paused_frames": category("paused"),
        "dropped_frames": category("dropped"),
        "early_frames": category("early"),
        "out_of_order_frames": category("out_of_order"),
        "min_interval": result["min_interval"],
        "max_interval": result["max_interval"],
        "total_intervals": total_intervals
    }
    out_json = os.path.splitext(video_file)[0] + "_frame_analysis.json"
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(export, f, indent=2)
    print(f"\n[✓] Exported analysis to {out_json}")

    plot_intervals(result["intervals"], expected_interval, expected_interval * DROP_THRESHOLD_MULTIPLIER)

# Run the interactive analysis only when executed as a script, not on import.
if __name__ == "__main__":
    main()
