class Connect2PupilCore(object):
    """Client for Pupil Capture's network API (Pupil Remote).

    Construction connects to Pupil Remote and opens four sockets:

    * ``socket``       - REQ socket for Pupil Remote commands ("R", "r",
      "C", "c", "t", "T ...", "v") and for sending notifications;
    * ``pub_socket``   - PUB socket used to publish annotations;
    * ``sub_socket``   - SUB socket subscribed to topics starting with
      ``"surface"``;
    * ``sub_socket_c`` - SUB socket subscribed to topics starting with
      ``"notify.calibration"``.

    Call :meth:`close` to release the sockets when done.
    """

    # Calibration targets as ((x, y) normalized position, description shown to
    # the subject); used by perform_choreography() when no locations are given.
    # NOTE(review): the corner targets use 0.0 and 0.8, which are not
    # symmetric around the 0.5 centre target - confirm this is intended.
    DEFAULT_CALIBRATION_LOCATIONS = (
        ((0.0, 0.8), "top left"),
        ((0.8, 0.8), "top right"),
        ((0.8, 0.0), "bottom right"),
        ((0.0, 0.0), "bottom left"),
        ((0.5, 0.5), "centre"),
    )

    def __init__(
        self,
        ip_address=None,
        port=None,
        surface_name="Surface 1"  # Default surface name
    ):
        """Connect to Pupil Remote and open the PUB/SUB sockets.

        Blocks until Pupil Remote answers the PUB_PORT and SUB_PORT requests.

        Args:
            ip_address: Host running Pupil Capture. ``None`` (default) uses
                ``get_computer_settings()["eyetracker_ip_address"]``.
            port: Pupil Remote port. ``None`` (default) uses
                ``get_computer_settings()["eyetracker_port_number"]``.
            surface_name: Surface whose gaze data get_gaze_on_surface()
                collects.
        """
        # Resolve the settings per call rather than in the signature: default
        # values are evaluated only once, when the class is defined (on import).
        if ip_address is None or port is None:
            settings = get_computer_settings()
            if ip_address is None:
                ip_address = settings["eyetracker_ip_address"]
            if port is None:
                port = settings["eyetracker_port_number"]

        self.ip_address = ip_address
        self.port = port
        self.surface_name = surface_name
        self.last_gaze_position = None  # last in-bounds (x, y) gaze point

        self.ctx = zmq.Context.instance()

        # REQ socket for Pupil Remote commands; it also reports which ports
        # to use for publishing (PUB_PORT) and subscribing (SUB_PORT).
        self.socket = self.ctx.socket(zmq.REQ)
        self.socket.connect(f"tcp://{self.ip_address}:{self.port}")
        self.pub_port = self._request("PUB_PORT")
        self.sub_port = self._request("SUB_PORT")

        # PUB socket for sending annotations to Pupil Capture.
        self.pub_socket = self.ctx.socket(zmq.PUB)
        self.pub_socket.connect(f"tcp://{self.ip_address}:{self.pub_port}")

        # SUB sockets: surface-mapped gaze data, and calibration notifications.
        self.sub_socket = self._make_sub_socket("surface")
        self.sub_socket_c = self._make_sub_socket("notify.calibration")

    def close(self):
        """Close the four sockets opened in __init__.

        The shared ``zmq.Context.instance()`` is left open because other code
        in the process may still use it.
        """
        for sock in (self.socket, self.pub_socket, self.sub_socket, self.sub_socket_c):
            sock.close()

    def _request(self, command):
        """Send one command on the Pupil Remote REQ socket and return the reply."""
        self.socket.send_string(command)
        return self.socket.recv_string()

    def _make_sub_socket(self, topic):
        """Return a SUB socket on Pupil's SUB port subscribed to ``topic``."""
        sub = self.ctx.socket(zmq.SUB)
        sub.connect(f"tcp://{self.ip_address}:{self.sub_port}")
        sub.setsockopt_string(zmq.SUBSCRIBE, topic)  # ZMQ topic prefix filter
        return sub

    def request_pupil_version(self):
        """Return Pupil Remote's reply to the "v" (version) command."""
        return self._request("v")

    def evaluate_time_sync(self):
        """Print single and averaged clock-offset measurements.

        Returns:
            Mean offset (Pupil time minus local ``time.perf_counter()``), in
            seconds, over 10 measurements.
        """
        offset = self.measure_clock_offset()
        print(f"Clock offset (1 measurement): {offset} seconds")

        # Averaging several measurements smooths out network latency jitter.
        number_of_measurements = 10
        stable_offset_mean = self.measure_clock_offset_stable(
            nsamples=number_of_measurements
        )
        print(
            f"Mean clock offset ({number_of_measurements} measurements): "
            f"{stable_offset_mean} seconds"
        )

        # Infer the current Pupil time from the local clock.
        local_time = time.perf_counter()
        pupil_time_calculated_locally = local_time + stable_offset_mean
        print(f"Local time: {local_time}")
        print(f"Pupil time (calculated locally): {pupil_time_calculated_locally}")

        return stable_offset_mean

    def request_pupil_time(self):
        """Return Pupil Capture's current time ("t" command) as a float."""
        return float(self._request("t"))

    def sync_pupil_time(self):
        """Set Pupil Capture's clock to the local ``time.perf_counter()`` value.

        Returns:
            Pupil Remote's reply string.
        """
        return self._request("T {}".format(str(time.perf_counter())))

    def measure_clock_offset(self):
        """Return one estimate of Pupil time minus local perf_counter time.

        The local time is taken as the midpoint of the request round trip.
        """
        local_time_before = time.perf_counter()
        pupil_time = self.request_pupil_time()
        local_time_after = time.perf_counter()

        local_time = (local_time_before + local_time_after) / 2.0
        return pupil_time - local_time

    def measure_clock_offset_stable(self, nsamples=10):
        """Return the mean of ``nsamples`` measure_clock_offset() results.

        Raises:
            ValueError: if ``nsamples`` is not positive.
        """
        # Explicit check: unlike ``assert`` it still runs under ``python -O``.
        if nsamples <= 0:
            raise ValueError("Requires at least one sample")
        offsets = [self.measure_clock_offset() for _ in range(nsamples)]
        return sum(offsets) / len(offsets)  # mean offset

    def send_time_stamp(self, label, time):
        """Print and publish an annotation ``label`` at timestamp ``time``
        with duration 1.

        NOTE(review): ``time`` is presumably expected in Pupil time - confirm.
        The parameter shadows the ``time`` module inside this method; the name
        is kept for keyword callers.
        """
        minimal_trigger = self.new_trigger(str(label), time, 1)
        print(minimal_trigger)
        self.send_trigger(minimal_trigger)

    def send_trigger(self, trigger):
        """Publish ``trigger`` on the PUB socket.

        The message has two frames: ``trigger["topic"]`` and the serialized
        dict.
        """
        payload = serializer.dumps(trigger, use_bin_type=True)
        self.pub_socket.send_string(trigger["topic"], flags=zmq.SNDMORE)
        self.pub_socket.send(payload)

    def new_trigger(self, label, time, duration):
        """Return an annotation dict suitable for send_trigger()."""
        return {
            "topic": "annotation",
            "label": label,
            "timestamp": time,
            "duration": duration,
            "custom_key": "t",
            "added_in_capture": True
        }

    def start_recording(self):
        """Start a Pupil Capture recording ("R") and return the reply."""
        return self._request("R")

    def stop_recording(self):
        """Stop the Pupil Capture recording ("r"); print and return the reply."""
        response = self._request("r")
        print(response)
        return response

    def get_gaze_on_surface(self, duration=5.0, discard_stale=True):
        """Collect gaze points on ``self.surface_name`` for ``duration`` seconds.

        Each gaze point whose normalized position lies within [0, 1] x [0, 1]
        is printed and kept. The most recent one is stored in
        ``self.last_gaze_position``.

        Args:
            duration: Collection time in seconds.
            discard_stale: Drop surface messages that were already queued
                before the call. The SUB socket stays subscribed between
                calls, so without this, old data would be counted as current.

        Returns:
            ``np.ndarray`` of shape (n, 2) with the collected (x, y) points,
            or an empty array when none were collected.
        """
        if discard_stale:
            self._drain_socket(self.sub_socket)

        gaze_points = []
        deadline = time.monotonic() + duration
        while True:
            remaining_ms = int((deadline - time.monotonic()) * 1000)
            if remaining_ms <= 0:
                break
            # Wait at most until the deadline, so the loop ends on time even
            # when Pupil Capture publishes no surface data.
            if not self.sub_socket.poll(timeout=remaining_ms):
                continue
            frames = self.sub_socket.recv_multipart()
            if len(frames) < 2:
                continue  # not a (topic, payload) message
            try:
                surface = serializer.loads(frames[1])
            except Exception:  # serializer's error types are not visible here
                continue  # skip payloads that cannot be decoded
            for norm_gp_x, norm_gp_y in self._surface_gaze_points(surface):
                print(f"Gaze on {self.surface_name}: ({norm_gp_x}, {norm_gp_y})")
                gaze_points.append([norm_gp_x, norm_gp_y])
                self.last_gaze_position = (norm_gp_x, norm_gp_y)

        print(f"Collected {len(gaze_points)} gaze points.")
        return np.array(gaze_points)

    def _surface_gaze_points(self, surface):
        """Return in-bounds ``[x, y]`` gaze points from one decoded surface message.

        Returns an empty list unless ``surface`` is a dict whose "name" equals
        ``self.surface_name``. Malformed entries in "gaze_on_surfaces" are
        skipped individually.
        """
        if not isinstance(surface, dict) or surface.get("name") != self.surface_name:
            return []
        points = []
        for gaze_pos in surface.get("gaze_on_surfaces", []):
            try:
                norm_gp_x, norm_gp_y = gaze_pos["norm_pos"]
                inside = 0 <= norm_gp_x <= 1 and 0 <= norm_gp_y <= 1
            except (KeyError, TypeError, ValueError):
                continue  # missing / non-numeric / wrongly shaped norm_pos
            if inside:
                points.append([norm_gp_x, norm_gp_y])
        return points

    def check_fixation_accuracy(self, gaze_points, fixation_point=(0.5, 0.5), threshold=0.05):
        """Check whether the mean gaze lies within ``threshold`` of ``fixation_point``.

        Returns:
            ``(False, None)`` when ``gaze_points`` is empty. Otherwise
            ``(is_close, mean_gaze)``: a bool, and the mean point as a list.
        """
        points = np.asarray(gaze_points, dtype=float)
        if points.size == 0:
            return False, None  # No valid gaze data collected

        mean_gaze = points.mean(axis=0)
        distance = np.linalg.norm(mean_gaze - np.asarray(fixation_point, dtype=float))
        return bool(distance < threshold), mean_gaze.tolist()

    def log_slippage_event(self, session_id, mean_gaze):
        """Save a slippage record (wall-clock time and mean gaze) for post hoc
        adjustment.

        The record goes to ``slippage_log.json`` in the session's "savedir",
        via ``SessionManager.save_to_file`` (whose append/overwrite behavior
        is defined elsewhere). ``session_id`` is currently unused.
        """
        save_dir = SessionManager.get_from_session('savedir')
        slippage_file = os.path.join(save_dir, 'slippage_log.json')

        slippage_entry = {
            "timestamp": time.time(),
            "mean_gaze": mean_gaze
        }

        SessionManager.save_to_file(slippage_file, slippage_entry)

    def start_calibration(self):
        """Run a 2D calibration using reference data sent by this client.

        Starts the ``ExternalCalibrationChoreography`` plugin with the
        ``Gazer2D`` gazer and starts calibration ("C"). It then walks the
        subject through perform_choreography(), stops calibration ("c"), and
        waits for the result notification.

        Returns:
            True if Pupil Capture reports success; False if it reports
            failure (the reason is printed).
        """
        # Use 2D calibration (instead of default 3D)
        self._send_notification({
            "subject": "start_plugin",
            "name": "ExternalCalibrationChoreography",
            "args": {"selected_gazer_class_name": "Gazer2D"},
        })

        # Drop stale calibration notifications so the waits below only see
        # messages from this run.
        self._clear_socket_queue()

        # Reference data must be timestamped in Pupil time. Measure the
        # offset from the local perf_counter clock once, before calibrating.
        clock_offset = self.measure_clock_offset_stable()

        self._request("C")  # "C" starts calibration
        self.wait_for_calibration_notification("started")
        print("Calibration started!")

        try:
            self.perform_choreography(clock_offset=clock_offset)
        finally:
            # Always stop calibration, even if the choreography is interrupted.
            self._request("c")  # "c" stops calibration
            print("Calibration stopped!")

        feedback = self.wait_for_calibration_notification("successful", "failed")
        if str(feedback.get("subject", "")).endswith("failed"):
            print(f"Calibration failed: {feedback.get('reason')}")
            return False
        print("Calibration successful!")
        return True

    def wait_for_calibration_notification(self, *topic_suffixes, timeout=None):
        """Block until a calibration notification whose topic ends with one of
        ``topic_suffixes`` arrives, and return it.

        Other notifications are printed and discarded.

        Args:
            *topic_suffixes: Accepted topic endings, e.g. "successful".
            timeout: Optional maximum wait in seconds. ``None`` (default)
                waits indefinitely.

        Raises:
            TimeoutError: if ``timeout`` elapses before a match arrives.
        """
        suffixes = tuple(topic_suffixes)
        deadline = None if timeout is None else time.monotonic() + timeout
        while True:
            if deadline is not None:
                remaining_ms = int((deadline - time.monotonic()) * 1000)
                if remaining_ms <= 0 or not self.sub_socket_c.poll(timeout=remaining_ms):
                    raise TimeoutError(
                        f"No calibration notification ending with {suffixes} "
                        f"within {timeout} seconds"
                    )
            topic, notification = self._recv_notification()
            if topic.endswith(suffixes):
                return notification
            print(f"Ignoring notification: {topic}")

    def perform_choreography(self, locations=None, duration_per_location_seconds=1.0,
                             clock_offset=None):
        """Walk the subject through the calibration targets and send reference data.

        For each target, the subject is prompted (blocking on input()).
        Reference data for both eyes is then collected at ~30 Hz for
        ``duration_per_location_seconds`` and sent as one
        "calibration.add_ref_data" notification.

        Args:
            locations: Sequence of ``(norm_pos, description)`` pairs. Defaults
                to DEFAULT_CALIBRATION_LOCATIONS.
            duration_per_location_seconds: Collection time per target.
            clock_offset: Pupil time minus local ``time.perf_counter()``, in
                seconds. Used to timestamp reference data in Pupil time. If
                ``None``, it is measured with measure_clock_offset_stable().
        """
        if locations is None:
            locations = self.DEFAULT_CALIBRATION_LOCATIONS
        if clock_offset is None:
            clock_offset = self.measure_clock_offset_stable()

        for location_coord, location_human in locations:
            self._instruct_subject(location_human, duration_per_location_seconds)
            norm_pos = list(location_coord)
            ref_data = []
            for _ in self._timer(duration_per_location_seconds):
                # Convert the local clock to Pupil time so Pupil Capture can
                # match this reference data to its pupil data.
                pupil_timestamp = time.perf_counter() + clock_offset
                # One entry per eye (0 = right, 1 = left); coordinates could
                # differ between the two.
                for eye_id in (0, 1):
                    ref_data.append(
                        {"norm_pos": norm_pos, "timestamp": pupil_timestamp, "eye_id": eye_id}
                    )
            self._send_notification(
                {"subject": "calibration.add_ref_data", "ref_data": ref_data},
            )

    def _instruct_subject(self, target_description, duration_seconds):
        """Prompt the subject to look at the target; blocks until Enter."""
        input(
            f"Look to the {target_description}, hit enter, and keep looking"
            f" at the target location for {duration_seconds} seconds"
        )

    def _timer(self, duration_seconds=1.0, sampling_rate_hz=30):
        """Yield ``int(duration_seconds * sampling_rate_hz)`` times.

        After each yield it sleeps ``duration_seconds / sampling_rate_hz``.
        The caller's work between yields is not subtracted, so the effective
        rate is slightly below ``sampling_rate_hz``.
        """
        num_samples = int(duration_seconds * sampling_rate_hz)
        duration_between_samples = duration_seconds / sampling_rate_hz
        for _ in range(num_samples):
            yield
            time.sleep(duration_between_samples)

    def _send_notification(self, notification):
        """Send ``notification`` via Pupil Remote as topic "notify.<subject>".

        Returns:
            Pupil Remote's reply string.
        """
        topic = "notify." + notification["subject"]
        payload = serializer.dumps(notification, use_bin_type=True)
        self.socket.send_string(topic, flags=zmq.SNDMORE)
        self.socket.send(payload)
        return self.socket.recv_string()

    def _recv_notification(self):
        """Receive one ``(topic, notification)`` pair from the calibration SUB
        socket (blocking)."""
        frames = self.sub_socket_c.recv_multipart()
        topic = frames[0].decode("utf-8")
        notification = serializer.unpackb(frames[1])
        return topic, notification

    def _clear_socket_queue(self):
        """Discard every message already queued on the calibration SUB socket."""
        self._drain_socket(self.sub_socket_c)

    @staticmethod
    def _drain_socket(sock):
        """Receive and discard all frames currently queued on ``sock``."""
        while sock.poll(timeout=0):
            try:
                sock.recv(zmq.NOBLOCK)
            except zmq.ZMQError:
                break

# Module-level connection and calibration. Constructing Connect2PupilCore
# opens sockets to the configured eye tracker and waits for Pupil Remote's
# replies. start_calibration() then prompts the operator via input() for each
# target. Both run whenever this module is executed or imported, and the
# True/False calibration result is discarded.
# NOTE(review): consider guarding these lines with
# `if __name__ == "__main__":` if no other module imports C2PC from here.
C2PC = Connect2PupilCore()
C2PC.start_calibration()