import subprocess
import matplotlib.pyplot as plt
import os
import statistics
import json
import csv
try:
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False
from tkinter import Tk, filedialog

# ===== CONFIGURATION =====
# Path of the video being analysed; filled in at runtime by the file picker.
VIDEO_FILE = None
# Consecutive timestamps closer than this many seconds are reported as a
# duplicate ("paused") frame.
TOLERANCE = 1e-6
# An interval larger than (median interval * this factor) is reported as a
# dropped / irregular frame gap.
DROP_THRESHOLD_MULTIPLIER = 1.5
# True: write the interval plot to a PNG file. False: show it interactively.
SAVE_PLOT = True
# Write the full analysis summary to analysis.json.
EXPORT_JSON = True
# Write one row per detected event to analysis.csv.
EXPORT_CSV = True
# Show a tqdm progress bar (when tqdm is installed) for long frame lists.
SHOW_PROGRESS = True
# ==========================

def select_video_file():
    """Open a native file-picker dialog and return the chosen video path.

    A hidden Tk root window exists only to host the dialog and is destroyed
    afterwards, even if the dialog raises.

    Returns:
        The selected path as a string, or an empty string if the user
        cancelled the dialog.
    """
    root = Tk()
    root.withdraw()  # Hide the empty main window; only the dialog is wanted.
    try:
        file_path = filedialog.askopenfilename(
            title="Select Video File",
            filetypes=[("Video files", "*.mp4 *.mkv *.avi *.mov"), ("All files", "*.*")]
        )
    finally:
        # Always tear down the Tk root so it does not leak if the dialog fails.
        root.destroy()
    # NOTE(review): some Tk builds return an empty tuple rather than "" on
    # cancel; normalise so callers always receive a string.
    return file_path if isinstance(file_path, str) else ""

def get_video_info(video_file):
    """Probe the first video stream for its frame rates and duration.

    ffprobe is run with JSON output, so each field is looked up by name rather
    than by its line position. The stream-level duration is preferred. When
    the stream does not report one (some containers, e.g. Matroska, may leave
    it as "N/A"), the container-level (format) duration is used instead.

    Args:
        video_file: Path to the media file.

    Returns:
        A tuple (r_fps, avg_fps, duration, mode):
        - r_fps, avg_fps: floats, or None when unavailable.
        - duration: seconds as a float, or None.
        - mode: "CFR" when both rates are known and differ by less than
          0.01 fps, "VFR" when they differ by more, and "Unknown" otherwise.
          "Unknown" is also used when ffprobe is missing or fails.
    """
    unknown = (None, None, None, "Unknown")

    try:
        result = subprocess.run(
            [
                "ffprobe", "-v", "0",
                "-select_streams", "v:0",
                "-show_entries",
                "stream=r_frame_rate,avg_frame_rate,duration:format=duration",
                "-of", "json",
                video_file
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
    except OSError:
        # ffprobe is not installed / not on PATH / not executable.
        return unknown
    if result.returncode != 0:
        return unknown

    try:
        info = json.loads(result.stdout)
    except ValueError:  # includes json.JSONDecodeError (e.g. empty output)
        return unknown
    if not isinstance(info, dict):
        return unknown
    streams = info.get("streams") or []
    if not streams:
        return unknown  # file has no video stream
    stream = streams[0]

    def parse_fps(fps_str):
        # Rates are normally "num/den" (e.g. "30000/1001"); a bare number is
        # accepted too. "0/0", "N/A" and non-positive values yield None.
        if not fps_str:
            return None
        num, sep, den = str(fps_str).partition("/")
        try:
            rate = float(num) / float(den) if sep else float(num)
        except (ValueError, ZeroDivisionError):
            return None
        return rate if rate > 0 else None

    def parse_seconds(value):
        # Missing fields (None) and "N/A" both map to None.
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    r_fps = parse_fps(stream.get("r_frame_rate"))
    avg_fps = parse_fps(stream.get("avg_frame_rate"))
    duration = parse_seconds(stream.get("duration"))
    if duration is None:
        duration = parse_seconds((info.get("format") or {}).get("duration"))

    mode = "Unknown"
    if r_fps and avg_fps:
        # Heuristic: nominal and average rate agree => constant frame rate.
        mode = "CFR" if abs(r_fps - avg_fps) < 0.01 else "VFR"

    return r_fps, avg_fps, duration, mode

def extract_timestamps(video_file):
    """Return a timestamp (in seconds) for every frame of the first video stream.

    ffprobe is asked for both ``pts_time`` and ``best_effort_timestamp_time``
    in CSV form. For each output line, the first field that parses as a
    number is used. In practice the PTS is used when present, and the
    best-effort timestamp stands in for frames whose PTS is "N/A"; this keeps
    frame indices aligned instead of crashing on such frames. Only the
    leading fields are read, which tolerates the extra trailing separators
    that ffprobe may emit on some lines.

    Args:
        video_file: Path to the media file.

    Returns:
        List of float timestamps in ffprobe's output order. Returns an empty
        list if ffprobe is missing or exits with an error.
    """
    print("[✓] Extracting timestamps using ffprobe...")
    try:
        result = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-select_streams", "v:0",
                "-show_entries", "frame=pts_time,best_effort_timestamp_time",
                "-of", "csv=p=0",
                video_file
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
    except OSError as exc:
        # ffprobe is not installed / not on PATH / not executable.
        print("❌ ffprobe failed:", exc)
        return []

    if result.returncode != 0:
        print("❌ ffprobe failed:", result.stderr)
        return []

    def parse_line(line):
        # Try the requested fields in order; "N/A" or empty fields are skipped.
        for field in line.split(",")[:2]:
            try:
                return float(field)
            except ValueError:
                continue
        return None

    pts_list = []
    for line in result.stdout.splitlines():
        ts = parse_line(line.strip())
        if ts is not None:
            pts_list.append(ts)
    return pts_list

def analyze_timestamps(pts_list, *, tolerance=None, drop_multiplier=None):
    """Classify the gaps between consecutive frame timestamps.

    Args:
        pts_list: Frame timestamps in seconds, in output order.
        tolerance: Intervals whose absolute value is below this are treated
            as duplicate ("paused") frames. Defaults to the module-level
            TOLERANCE.
        drop_multiplier: An interval larger than
            ``expected_interval * drop_multiplier`` is reported as dropped.
            Defaults to the module-level DROP_THRESHOLD_MULTIPLIER.

    Returns:
        A tuple (intervals, expected_interval, paused, dropped, out_of_order):
        - intervals[i] is pts_list[i+1] - pts_list[i].
        - expected_interval is the median of the positive intervals that are
          at least ``tolerance``. It is 0 when there are none, or when fewer
          than two timestamps are given.
        - paused lists (i, pts_list[i]) for near-zero intervals.
        - dropped lists (i, intervals[i]) for intervals above the drop
          threshold.
        - out_of_order lists (i, pts_list[i], pts_list[i+1]) wherever the
          timestamp decreases.
    """
    if len(pts_list) < 2:
        print("❌ Not enough frames to analyze.")
        return [], 0, [], [], []

    if tolerance is None:
        tolerance = TOLERANCE
    if drop_multiplier is None:
        drop_multiplier = DROP_THRESHOLD_MULTIPLIER

    print(f"[✓] Analyzing {len(pts_list)} frames...")

    pairs = zip(pts_list, pts_list[1:])
    # Size is checked first so the progress-bar flags are only consulted when
    # a bar could actually be shown.
    if len(pts_list) > 1000 and SHOW_PROGRESS and HAS_TQDM:
        pairs = tqdm(pairs, total=len(pts_list) - 1, desc="Analyzing intervals")
    intervals = [curr - prev for prev, curr in pairs]

    # The median must describe the nominal cadence. Negative steps are
    # excluded, and so are sub-tolerance intervals, which are already
    # classified as paused below. Otherwise tiny jitter could drag the
    # expected interval toward zero.
    valid_intervals = [iv for iv in intervals if iv > 0 and iv >= tolerance]
    if not valid_intervals:
        print("❌ No valid frame intervals found.")
        return intervals, 0, [], [], []

    expected_interval = statistics.median(valid_intervals)
    drop_threshold = expected_interval * drop_multiplier

    # Duplicate / paused frames: timestamp (almost) equal to the next one.
    paused = [(i, pts_list[i]) for i, iv in enumerate(intervals) if abs(iv) < tolerance]

    # Dropped frames: gap noticeably larger than the expected interval.
    dropped = [(i, iv) for i, iv in enumerate(intervals) if iv > drop_threshold]

    # Out-of-order timestamps: the next frame is earlier than this one.
    out_of_order = [
        (i, prev, curr)
        for i, (prev, curr) in enumerate(zip(pts_list, pts_list[1:]))
        if curr < prev
    ]

    return intervals, expected_interval, paused, dropped, out_of_order

def plot_intervals(intervals, expected_interval, drop_threshold, output_path="frame_intervals.png"):
    """Plot per-frame intervals with horizontal expected-interval and drop-threshold lines.

    When SAVE_PLOT is true, the figure is written to ``output_path`` and then
    closed, so it does not stay registered in pyplot. Otherwise it is shown
    interactively.

    Args:
        intervals: Sequence of frame-to-frame intervals in seconds.
        expected_interval: Value drawn as the green dashed reference line.
        drop_threshold: Value drawn as the red dashed threshold line.
        output_path: File the figure is saved to when SAVE_PLOT is true.
    """
    fig, ax = plt.subplots(figsize=(12, 5))
    ax.plot(intervals, label='Frame Interval (s)', marker='o', linestyle='-')
    ax.axhline(expected_interval, color='green', linestyle='--', label='Expected Interval')
    ax.axhline(drop_threshold, color='red', linestyle='--', label='Drop Threshold')
    ax.set_xlabel("Frame Number")
    ax.set_ylabel("Interval (s)")
    ax.set_title("Frame Interval Analysis")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    if SAVE_PLOT:
        fig.savefig(output_path)
        print(f"[✓] Plot saved as {output_path}")
        # Release the figure; otherwise it stays open in pyplot's registry.
        plt.close(fig)
    else:
        plt.show()

def export_results(pts_list, expected_interval, paused, dropped, out_of_order,
                   r_fps, avg_fps, mode, effective_fps, duration, expected_duration,
                   *, export_json=None, export_csv=None,
                   json_path="analysis.json", csv_path="analysis.csv"):
    """Write the analysis results to a JSON summary and/or a CSV event list.

    The JSON file holds the frame count, frame rates, mode, durations and the
    three event lists. The CSV has a header row (Type, FrameIndex, Value1,
    Value2) followed by one row per paused, dropped and out-of-order event.

    Args:
        pts_list: Frame timestamps; only its length is exported.
        expected_interval: Median frame interval in seconds.
        paused: (index, timestamp) pairs.
        dropped: (index, gap) pairs.
        out_of_order: (index, previous_pts, current_pts) triples.
        r_fps, avg_fps, effective_fps: Frame rates, or None.
        mode: Frame-rate mode label ("CFR", "VFR" or "Unknown").
        duration: Duration reported for the video, or None.
        expected_duration: Duration implied by the frames, or None.
        export_json, export_csv: Override EXPORT_JSON / EXPORT_CSV when not
            None.
        json_path, csv_path: Output file paths.
    """
    if export_json is None:
        export_json = EXPORT_JSON
    if export_csv is None:
        export_csv = EXPORT_CSV

    # Compare against None so a legitimate 0.0 value is not treated as missing.
    if duration is not None and expected_duration is not None:
        duration_difference = duration - expected_duration
    else:
        duration_difference = None

    results = {
        "total_frames": len(pts_list),
        "r_frame_rate": r_fps,
        "avg_frame_rate": avg_fps,
        "effective_fps": effective_fps,
        "mode": mode,
        "expected_interval": expected_interval,
        "video_duration": duration,
        "expected_duration": expected_duration,
        "duration_difference": duration_difference,
        "paused_frames": paused,
        "dropped_frames": dropped,
        "out_of_order": out_of_order
    }

    if export_json:
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"[✓] Results exported to {json_path}")

    if export_csv:
        # newline="" is required by the csv module to avoid blank lines on Windows.
        with open(csv_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            writer.writerow(["Type", "FrameIndex", "Value1", "Value2"])
            writer.writerows(["Paused", idx, ts, ""] for idx, ts in paused)
            writer.writerows(["Dropped", idx, gap, ""] for idx, gap in dropped)
            writer.writerows(
                ["OutOfOrder", idx, prev_pts, curr_pts]
                for idx, prev_pts, curr_pts in out_of_order
            )
        print(f"[✓] Results exported to {csv_path}")

def _print_preview(header, entries, describe, limit=10):
    """Print *header*, the first *limit* entries via *describe*, then "  ..." if more remain."""
    if not entries:
        return
    print(header)
    for entry in entries[:limit]:
        print(describe(*entry))
    if len(entries) > limit:
        print("  ...")


def main():
    """Interactive entry point.

    Asks the user for a video file, then gathers stream info and per-frame
    timestamps via ffprobe. It analyses the frame intervals, prints a summary,
    and finally plots and exports the results. The chosen path is stored in
    the module-level VIDEO_FILE.
    """
    global VIDEO_FILE
    VIDEO_FILE = select_video_file()
    if not VIDEO_FILE:
        print("❌ No file selected.")
        return
    if not os.path.isfile(VIDEO_FILE):
        print(f"❌ File not found: {VIDEO_FILE}")
        return

    r_fps, avg_fps, duration, mode = get_video_info(VIDEO_FILE)

    pts_list = extract_timestamps(VIDEO_FILE)
    if not pts_list:
        print("❌ No timestamps extracted.")
        return

    intervals, expected_interval, paused, dropped, out_of_order = analyze_timestamps(pts_list)

    effective_fps = None
    if expected_interval > 0:
        effective_fps = 1 / expected_interval
    expected_duration = None
    if effective_fps:
        expected_duration = len(pts_list) / effective_fps

    # ---- Summary ----
    print("\n📋 SUMMARY")
    print(f"Video file: {VIDEO_FILE}")
    rate_rows = (
        ("Nominal FPS (r_frame_rate)", r_fps),
        ("Average FPS (avg_frame_rate)", avg_fps),
        ("Effective FPS (calculated)", effective_fps),
    )
    for label, rate in rate_rows:
        if rate:
            print(f"{label}: {rate:.3f}")
    print(f"Detected mode: {mode}")
    if duration:
        print(f"Video duration (s): {duration:.3f}")
    if expected_duration:
        print(f"Expected duration from frames (s): {expected_duration:.3f}")
        if duration:
            print(f"Duration difference (s): {duration - expected_duration:.6f}")
    print(f"Total frames: {len(pts_list)}")
    print(f"Expected frame interval (median): {expected_interval:.6f} s")
    count_rows = (
        ("Paused (duplicate PTS) frames", paused),
        ("Dropped/irregular frames", dropped),
        ("Out-of-order PTS frames", out_of_order),
    )
    for label, events in count_rows:
        print(f"{label}: {len(events)}")

    # ---- Details (first 10 of each kind) ----
    _print_preview(
        "\n⏸ Paused frames (same PTS as previous):",
        paused,
        lambda idx, ts: f"  Frame {idx} and {idx+1} at {ts:.6f} s",
    )
    _print_preview(
        "\n🕳 Dropped/irregular frame intervals:",
        dropped,
        lambda idx, gap: f"  Frame {idx} → {idx+1} gap: {gap:.6f} s",
    )
    _print_preview(
        "\n🔁 Out-of-order PTS detected:",
        out_of_order,
        lambda i, prev_pts, curr_pts: f"  Frame {i} PTS: {prev_pts:.6f} > Frame {i+1} PTS: {curr_pts:.6f}",
    )

    plot_intervals(intervals, expected_interval, expected_interval * DROP_THRESHOLD_MULTIPLIER)
    export_results(pts_list, expected_interval, paused, dropped, out_of_order, r_fps, avg_fps, mode, effective_fps, duration, expected_duration)

if __name__ == "__main__":
    main()

