import logging
import queue
import socket
import sys
import threading
import time
from typing import Any, Dict, Optional, Tuple


import msgpack as serializer
import zmq

# Root logging setup: emit only the bare message text at DEBUG verbosity.
# NOTE: datefmt is only used when the format contains %(asctime)s, so with the
# current "%(message)s" format it has no visible effect.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="[%X]",
    format="%(message)s",
)


def check_capture_exists(
    ip_address: str, port: int, timeout: Optional[float] = None
) -> None:
    """Exit the process unless a TCP listener accepts connections at ``ip_address:port``.

    This only proves that *something* is listening at the address (expected
    to be Pupil Remote); it does not verify the protocol spoken there.

    Args:
        ip_address: Host name or IP address of the machine running Pupil Capture.
        port: Pupil Remote port.
        timeout: Optional connect timeout in seconds. ``None`` (the default)
            keeps the socket blocking, so an unreachable host waits for the
            operating system's connect timeout.

    Raises:
        SystemExit: With exit status 1 if the connection cannot be made,
            including when ``ip_address`` cannot be resolved.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        try:
            # connect_ex reports connection failures as an errno, but address
            # resolution failures are raised (socket.gaierror is an OSError).
            error = sock.connect_ex((ip_address, port))
        except OSError as exc:
            logging.error(f"Cannot find Pupil Capture: {exc}")
            sys.exit(1)
        if error != 0:
            logging.error("Cannot find Pupil Capture")
            # Non-zero status so calling shells/CI can detect the failure;
            # a bare sys.exit() would report success (status 0).
            sys.exit(1)
        logging.info("Found Pupil Capture")


def setup_pupil_remote_connection(
    ip_address: str,
    port: int,
    topics: Tuple[str, ...] = (
        "notify.recording",
        "notify.calibration",
        "notify.validation",
    ),
) -> Tuple[zmq.Socket, zmq.Socket, zmq.Socket]:
    """Connect to Pupil Remote and return its REQ socket plus PUB and SUB sockets.

    Pupil Remote (REQ on ``port``) is asked for ``PUB_PORT`` and ``SUB_PORT``.
    A PUB socket is connected to the former, for sending notifications. A SUB
    socket is connected to the latter, for receiving messages.

    Args:
        ip_address: Host running Pupil Capture.
        port: Pupil Remote port.
        topics: Topic prefixes the SUB socket subscribes to. ZMQ subscriptions
            are prefix matches, so ``"notify.recording"`` also matches e.g.
            ``"notify.recording.started"``.

    Returns:
        ``(pupil_remote, pub_socket, sub_socket)``, all from the shared
        ``zmq.Context.instance()``.

    Note:
        ZMQ connects and propagates subscriptions asynchronously. Messages
        published or sent immediately after this returns may be dropped
        (the "slow joiner" effect), so callers should allow a short settle
        time before relying on the PUB/SUB sockets.
    """
    ctx = zmq.Context.instance()
    created = []  # sockets to close if setup fails part-way
    try:
        pupil_remote = ctx.socket(zmq.REQ)
        created.append(pupil_remote)
        pupil_remote.connect(f"tcp://{ip_address}:{port}")

        pupil_remote.send_string("PUB_PORT")
        pub_port = pupil_remote.recv_string()
        pub_socket = ctx.socket(zmq.PUB)
        created.append(pub_socket)
        pub_socket.connect(f"tcp://{ip_address}:{pub_port}")

        pupil_remote.send_string("SUB_PORT")
        sub_port = pupil_remote.recv_string()
        sub_socket = ctx.socket(zmq.SUB)
        created.append(sub_socket)
        sub_socket.connect(f"tcp://{ip_address}:{sub_port}")
        for topic in topics:
            sub_socket.setsockopt_string(zmq.SUBSCRIBE, topic)
    except BaseException:
        # BaseException so that Ctrl-C while blocked in recv_string also
        # releases the sockets. Open sockets would otherwise leak and make
        # Context.term() block.
        for sock in created:
            sock.close(linger=0)
        raise

    return pupil_remote, pub_socket, sub_socket


def notify(pub_socket: zmq.Socket, notification: Dict[str, Any]) -> None:
    """Publish a notification dict to Pupil Capture.

    The message goes out as two frames:
    1. the UTF-8 topic ``"notify." + notification["subject"]``;
    2. the notification itself, msgpack-encoded with ``use_bin_type=True``.
    """
    frames = [
        ("notify." + notification["subject"]).encode("utf-8"),
        serializer.dumps(notification, use_bin_type=True),
    ]
    # send_multipart flags every frame but the last with SNDMORE.
    pub_socket.send_multipart(frames)


def listen_for_messages(
    sub_socket: zmq.Socket, message_queue: "queue.Queue[Tuple[str, Dict[str, Any]]]"
) -> None:
    """Receive messages from ``sub_socket`` and enqueue them as ``(topic, payload)``.

    Meant to run in a background thread. Each multipart message is expected
    to have these frames:
    - frame 0: the topic;
    - frame 1: the msgpack payload;
    - further frames, if any: raw binary attachments.

    Extra frames are not fed to the msgpack decoder. When the payload is a
    dict, they are attached to it under ``"__raw_data__"`` (as a list of
    bytes). NOTE(review): this key mirrors Pupil's own receiver convention;
    confirm if consumers depend on it.

    Messages with fewer than two frames, or whose payload cannot be decoded,
    are logged at DEBUG level and skipped. The loop ends on the first
    ``zmq.ZMQError`` from ``recv_multipart``, e.g. when the socket is closed
    or its context is terminated.
    """
    while True:
        try:
            parts = sub_socket.recv_multipart()
        except zmq.ZMQError as exc:
            logging.debug(f"Message listener stopping: {exc}")
            break
        if len(parts) < 2:
            logging.debug("Received message with less than 2 parts, skipping")
            continue
        # errors="replace": a malformed topic must not kill the listener thread.
        topic = parts[0].decode("utf-8", errors="replace")
        try:
            msg = serializer.loads(parts[1], raw=False)
        except Exception:
            logging.debug(f"Error decoding message: {topic}", exc_info=True)
            continue
        extra_frames = parts[2:]
        if extra_frames and isinstance(msg, dict):
            msg["__raw_data__"] = extra_frames
        message_queue.put((topic, msg))
        logging.debug(f"{topic}: {msg}")


def wait_for_specific_message(
    message_queue: "queue.Queue[Tuple[str, Dict[str, Any]]]",
    expected_subject: str,
    timeout: int = 5000,
) -> Optional[Dict[str, Any]]:
    """Return the first queued payload whose ``"subject"`` equals ``expected_subject``.

    Messages are taken from ``message_queue`` in FIFO order. Every message
    consumed before the match, or before the deadline, is discarded.

    Args:
        message_queue: Queue of ``(topic, payload)`` tuples, as filled by
            ``listen_for_messages``.
        expected_subject: Subject to wait for, e.g. ``"recording.started"``.
        timeout: Maximum time to wait, in milliseconds. Values <= 0 return
            ``None`` immediately without consuming anything.

    Returns:
        The matching payload dict, or ``None`` if none arrived in time.
    """
    # monotonic() is immune to wall-clock adjustments (NTP, manual changes).
    deadline = time.monotonic() + timeout / 1000.0
    while True:
        remaining = deadline - time.monotonic()
        # Checking here, rather than in a loop condition, guarantees the timeout
        # passed to Queue.get is positive; a negative value raises ValueError.
        if remaining <= 0:
            return None
        try:
            _, msg = message_queue.get(timeout=remaining)
        except queue.Empty:
            return None
        # Decoded msgpack payloads are not guaranteed to be dicts.
        if isinstance(msg, dict) and msg.get("subject") == expected_subject:
            return msg


def main(
    ip_address: str = "127.0.0.1", port: int = 50020, settle_time: float = 1.0
) -> None:
    """Start a Pupil Capture recording, then stop it, logging the outcome of each step.

    Args:
        ip_address: Host running Pupil Capture.
        port: Pupil Remote port.
        settle_time: Seconds to wait after connecting before sending the first
            notification. ZMQ sets up PUB/SUB connections and subscriptions
            asynchronously ("slow joiner"), and messages published before that
            completes are silently dropped. Without a pause, the
            ``recording.should_start`` notification, or its
            ``recording.started`` reply, can be lost.
    """
    check_capture_exists(ip_address, port)
    pupil_remote, pub_socket, sub_socket = setup_pupil_remote_connection(
        ip_address, port
    )
    try:
        message_queue: "queue.Queue[Tuple[str, Dict[str, Any]]]" = queue.Queue()
        listener_thread = threading.Thread(
            target=listen_for_messages, args=(sub_socket, message_queue), daemon=True
        )
        listener_thread.start()
        if settle_time > 0:
            time.sleep(settle_time)

        # Start recording
        notify(pub_socket, {"subject": "recording.should_start"})
        if wait_for_specific_message(message_queue, "recording.started", 5000):
            logging.info("Recording started successfully.")
        else:
            logging.error("Failed to start recording.")

        # Stop recording
        notify(pub_socket, {"subject": "recording.should_stop"})
        if wait_for_specific_message(message_queue, "recording.stopped", 10000):
            logging.info("Recording stopped successfully.")
        else:
            logging.error("Failed to stop recording.")
    finally:
        # linger=0 discards unsent messages, so Context.term() cannot block
        # forever if Pupil Capture went away.
        # NOTE(review): sub_socket is still being read by the daemon listener
        # thread, and ZMQ sockets are not thread-safe, so closing it here is
        # not strictly safe. A cooperative stop signal for the listener would
        # remove this.
        for sock in (sub_socket, pub_socket, pupil_remote):
            sock.close(linger=0)
        zmq.Context.instance().term()


if __name__ == "__main__":
    # Script entry point: run with main()'s defaults (127.0.0.1, port 50020).
    main()
