"""Receives world camera frames and gaze data from Pupil Capture over the
network, matches them by timestamp and displays the frames with the gaze
position overlaid."""
import zmq
from msgpack import unpackb, packb
import numpy as np
import cv2
from multiprocessing import Queue, Process


def notify(notification, req):
    """Sends ``notification`` to Pupil Remote"""
    topic = 'notify.' + notification['subject']
    payload = packb(notification, use_bin_type=True)
    req.send_string(topic, flags=zmq.SNDMORE)
    req.send(payload)
    return req.recv_string()

def recv_from_sub(sub):
    '''Recv a message with topic, payload.
    Topic is a utf-8 encoded string. Returned as a unicode object.
    Payload is a msgpack serialized dict. Returned as a python dict.
    Any additional message frames will be added as a list
    in the payload dict with key: '__raw_data__'.
    '''
    topic = sub.recv_string()
    # raw=False makes msgpack decode strings as str (replaces the removed encoding kwarg)
    payload = unpackb(sub.recv(), raw=False)
    extra_frames = []
    while sub.get(zmq.RCVMORE):
        extra_frames.append(sub.recv())
    if extra_frames:
        payload['__raw_data__'] = extra_frames
    return topic, payload
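
# What recv_from_sub typically returns for a world frame (field values are
# illustrative, not exact):
#   topic   -> 'frame.world'
#   payload -> {'width': 1280, 'height': 720, 'format': 'bgr',
#               'timestamp': 1234.5678, '__raw_data__': [<raw BGR pixel bytes>]}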

def denormalize(pos, size, flip_y=False):
    '''Denormalize gaze coordinates.

    `pos` is a normalized (x, y) position, `size` is the target image size
    as (width, height). With `flip_y` the y axis is flipped so the origin
    moves from the bottom-left to the top-left corner (image convention).
    '''
    width, height = size
    x = pos[0]
    y = pos[1]
    x *= width
    if flip_y:
        y = 1 - y
    y *= height
    return int(x), int(y)
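
# Worked example (illustrative values): a normalized gaze position of (0.5, 0.25)
# on a 1280x720 world frame maps to pixel coordinates (640, 540) with
# flip_y=True, since image coordinates have their origin at the top-left:
#   x = 0.5 * 1280 = 640
#   y = (1 - 0.25) * 720 = 540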



def network_process(image_queue, gaze_queue):
    """Calls all networking procedurs and stores images and gazes into queue
    
    Arguments:
        image_queue {Queue} -- [queue to store images]
        gaze_queue {Queue} -- [queue to store gaze coordinates]
    """
    context = zmq.Context()
    # open a req port to talk to pupil
    addr = '127.0.0.1'  # remote ip or localhost
    req_port = "50020"  # same as in the pupil remote gui
    req = context.socket(zmq.REQ)
    req.connect("tcp://{}:{}".format(addr, req_port))
    
    # Set Pupil Capture's time base to 0.0 so all timestamps start from zero
    req.send_string('T 0.0')
    print(req.recv_string())

    # ask for the sub port
    req.send_string('SUB_PORT')
    sub_port = req.recv_string()
    # Start frame publisher with format BGR
    notify({'subject': 'start_plugin', 'name': 'Frame_Publisher', 'args': {'format': 'bgr'}}, req)

    # open a sub port to listen to pupil
    sub = context.socket(zmq.SUB)
    sub.connect("tcp://{}:{}".format(addr, sub_port))
    

    # set subscriptions to topics:
    # recv just world camera frames and binocular 3d gaze data
    # (ZMQ subscriptions are prefix matches; the published gaze topic is 'gaze.3d.01.')
    sub.setsockopt_string(zmq.SUBSCRIBE, 'frame.world')
    sub.setsockopt_string(zmq.SUBSCRIBE, 'gaze.3d.01')
    sizes = None
    while True:
        topic, msg = recv_from_sub(sub)
        if topic == 'frame.world':
            img = np.frombuffer(msg['__raw_data__'][0], dtype=np.uint8).reshape(msg['height'], msg['width'], 3)
            timestamp = msg['timestamp']
            sizes = (msg['width'], msg['height'])  # denormalize expects (width, height)
            image_queue.put((img, timestamp))
        elif topic == 'gaze.3d.01.':  # binocular gaze topic (note the trailing dot)
            if sizes is not None:
                gaze_coords = msg['norm_pos']
                gaze_coords = denormalize(gaze_coords, sizes, flip_y=True)
                timestamp = msg['timestamp']
                gaze_queue.put((gaze_coords, timestamp))

def matching_process(image_queue, gaze_queue, output_images_queue):
    """Processes images and gazes from queue, matches them by timestep and 
    outputs them onto the output_images_queue
    
    Arguments:
        image_queue {Queue} -- [queue to store images]
        gaze_queue {Queue} -- [queue to store gaze coordinates]
        output_images_queue {Queue} -- [queue to store matched images and coordinates]
    """
    tolerance = 0.01  # max. timestamp difference in seconds for a match
    while True:
        (img, img_timestamp) = image_queue.get()
        (gaze_coords, gaze_timestamp) = gaze_queue.get()
        # Advance whichever stream lags behind until the timestamps line up
        while gaze_timestamp > img_timestamp + tolerance:
            (img, img_timestamp) = image_queue.get()
        while img_timestamp > gaze_timestamp + tolerance:
            (gaze_coords, gaze_timestamp) = gaze_queue.get()
        if abs(gaze_timestamp - img_timestamp) <= tolerance:
            cv2.circle(img, gaze_coords, 5, (0, 255, 0))
            output_images_queue.put((img, img_timestamp))


    

def output_process(output_images_queue):
    """Placeholder for a dedicated display process; display currently happens in the main block"""
    pass

if __name__ == '__main__':
    # Queues for passing data between the processes
    image_queue = Queue(100)
    gaze_queue = Queue(100)
    output_images_queue = Queue(100)

    network_p = Process(target=network_process, args=(image_queue, gaze_queue))
    matching_p = Process(target=matching_process, args=(image_queue, gaze_queue, output_images_queue))
    
    network_p.start()
    matching_p.start()

    # Display matched frames as they arrive
    while True:
        (img, img_timestamp) = output_images_queue.get()
        cv2.imshow('image', img)
        cv2.waitKey(1)  # waitKey(0) would block until a key press for every frame
    
