import time
import zmq

import numpy as np

from msgpack import loads
from matplotlib import pyplot as plt

SAMPLING_FREQUENCY = 120 # Hz
SECONDS_TO_RECORD = 3

def request_pupil_time(socket):
    """Request the current Pupil time over ``socket`` and return it as a float.

    Sends the string ``"t"`` on the socket, waits for the reply and parses
    that reply as a floating point number.

    Adapted from
    https://github.com/pupil-labs/pupil-helpers/blob/master/python/simple_realtime_time_sync.py
    """
    socket.send_string("t")
    # The reply is parsed directly; float() accepts the raw reply as received.
    return float(socket.recv())


def measure_clock_offset(socket, clock_function):
    """Estimate the offset between the Pupil clock and a local clock.

    The local clock is read once right before and once right after a single
    Pupil time request. The midpoint of those two readings is used as the
    local time that corresponds to the reported Pupil time.

    Args:
        socket: Socket passed on to ``request_pupil_time``.
        clock_function: Zero-argument callable returning the local time.

    Returns:
        The Pupil time minus the local midpoint time.

    Adapted from
    https://github.com/pupil-labs/pupil-helpers/blob/master/python/simple_realtime_time_sync.py
    """
    sent_at = clock_function()
    remote_now = request_pupil_time(socket)
    received_at = clock_function()

    # Midpoint of the two local readings taken around the request.
    midpoint = (sent_at + received_at) / 2.0
    return remote_now - midpoint

def measure_clock_offset_stable(socket, clock_function, nsamples=10):
    """Return the mean of ``nsamples`` clock offset measurements.

    Args:
        socket: Socket passed on to ``measure_clock_offset``.
        clock_function: Zero-argument callable returning the local time.
        nsamples: Number of offset measurements to average; must be positive.

    Returns:
        The arithmetic mean of the measured offsets.

    Raises:
        ValueError: If ``nsamples`` is not positive.

    Adapted from
    https://github.com/pupil-labs/pupil-helpers/blob/master/python/simple_realtime_time_sync.py
    """
    # Raise explicitly instead of using ``assert``: assertions are removed when
    # Python runs with -O, which would turn nsamples=0 into a ZeroDivisionError.
    if nsamples <= 0:
        raise ValueError("Requires at least one sample")
    offsets = [measure_clock_offset(socket, clock_function) for _ in range(nsamples)]
    return sum(offsets) / len(offsets)  # mean offset

# Connection settings.
addr = "127.0.0.1"  # remote ip or localhost
req_port = "50020"  # same as in the pupil remote gui
surface_name = "screen"

context = zmq.Context()

# Request/reply socket used for the SUB_PORT query and the time requests below.
pupil_remote = context.socket(zmq.REQ)
pupil_remote.connect(f"tcp://{addr}:{req_port}")

# Ask which port to subscribe to for data.
pupil_remote.send_string("SUB_PORT")
sub_port = pupil_remote.recv_string()
print(sub_port)

# open a sub port to listen to pupil
subscriber = context.socket(zmq.SUB)
subscriber.connect(f"tcp://{addr}:{sub_port}")
subscriber.subscribe("gaze.")  # receive all gaze messages

# NOTE: zmq.CONFLATE is deliberately not set on the subscriber. libzmq only
# applies that option to connections made after it is set, so setting it after
# connect() had no effect, and the option does not support multipart messages
# such as the (topic, payload) pairs read from this socket. Getting only the
# newest sample requires the reader to drain the receive queue itself.

# Local clock against which the Pupil clock offset is measured.
local_clock = time.perf_counter

# offset = pupil time - local time, averaged over several requests.
offset = measure_clock_offset_stable(pupil_remote, clock_function=local_clock)
print(offset)

def _recv_latest_multipart(sub_socket):
    """Return the newest message currently queued on ``sub_socket``.

    Blocks until at least one message has arrived, then keeps reading without
    blocking until the receive queue is empty. Only the last message read is
    returned; older queued messages are discarded.
    """
    newest = sub_socket.recv_multipart()
    while True:
        try:
            newest = sub_socket.recv_multipart(flags=zmq.NOBLOCK)
        except zmq.Again:  # nothing left in the queue
            return newest


frame_period = 1 / SAMPLING_FREQUENCY
frames_to_record = SAMPLING_FREQUENCY * SECONDS_TO_RECORD

timestamps = []
frames = 0
# perf_counter is monotonic, so the measured intervals cannot be distorted by
# system clock adjustments the way time.time() differences can.
start_time = time.perf_counter()
while frames < frames_to_record:
    inner_start_time = time.perf_counter()
    # A plain recv_multipart() returns the oldest queued message, so a loop
    # paced slower than the publisher falls further and further behind.
    # Draining the queue yields the most recent message instead.
    topic, payload = _recv_latest_multipart(subscriber)
    gaze_data = loads(payload, raw=False)
    # offset is pupil time minus local time, so subtracting it expresses the
    # gaze timestamp on the local clock.
    this_timestamp = gaze_data["timestamp"] - offset
    print(f'packet timestamp: {this_timestamp}')
    timestamps.append(this_timestamp)
    frames += 1
    inner_end_time = time.perf_counter()
    # Sleep for whatever remains of this sample period (never a negative time).
    sleep_time = frame_period - (inner_end_time - inner_start_time)
    time.sleep(max(0.0, sleep_time))

end_time = time.perf_counter()

time_elapsed = end_time - start_time
timestamps_elapsed = timestamps[-1] - timestamps[0]
# Because every iteration reads the most recent queued message, time_elapsed
# should roughly equal timestamps_elapsed.

print(time_elapsed)
print(timestamps_elapsed)

# This graph is expected to show timestamps increasing from sample to sample.
# NOTE(review): sporadic small decreases may remain because the 'gaze.' prefix
# could match several gaze topics whose samples interleave - verify against the
# topics actually received, and sort by timestamp if strict ordering is needed.
plt.plot(timestamps)
plt.ylabel('timestamp value')
plt.xlabel('sample')
plt.show()
