import os
import glob
import pyxdf
import numpy as np
import soundfile as sf

def extract_audio_from_xdf(xdf_file_path):
    """Return the time stamps and samples of the first audio stream in an XDF file.

    The file is loaded with ``pyxdf.load_xdf`` and its streams are scanned in
    order; the first stream whose ``info['type'][0]`` equals ``'Audio'`` is
    selected. Any later audio streams are ignored.

    Args:
        xdf_file_path: Path of the ``.xdf`` file to read.

    Returns:
        A ``(time_stamps, time_series)`` pair taken from the selected stream,
        or ``(None, None)`` -- after printing a notice -- when no stream has
        type ``'Audio'``.
    """
    streams, _header = pyxdf.load_xdf(xdf_file_path)

    # First stream advertising the 'Audio' type wins.
    found = next(
        (
            candidate
            for candidate in streams
            if candidate['info']['type'][0] == 'Audio'
        ),
        None,
    )

    if found is None:
        print("Audio stream not found in the XDF file:", xdf_file_path)
        return None, None

    return found['time_stamps'], found['time_series']

def save_audio_to_wav(timestamps, audio_data, output_file, sample_rate=48000,
                      subtype='PCM_16'):
    """Write audio samples to a WAV file using soundfile.

    Args:
        timestamps: Per-sample time stamps. Not used when writing; accepted so
            the output of ``extract_audio_from_xdf`` can be passed straight
            through.
        audio_data: Sample array handed unchanged to ``soundfile.write``
            (soundfile expects frames along the first axis and, for
            multichannel data, channels along the second). Float input is
            expected to be scaled to [-1.0, 1.0] per soundfile's convention.
        output_file: Destination path (or file object) for the WAV file.
        sample_rate: Sample rate in Hz written into the file header. Defaults
            to 48000, the value previously hard-coded here. Pass the stream's
            real rate when it differs; a wrong value changes playback speed
            and pitch.
        subtype: libsndfile sample encoding; defaults to 16-bit PCM.

    Raises:
        ValueError: If ``sample_rate`` cannot be parsed as a number or is not
            positive after rounding to an integer.
    """
    # libsndfile stores an integer rate; accept floats or numeric strings
    # (e.g. a nominal rate read from stream metadata) by rounding.
    rate = int(round(float(sample_rate)))
    if rate <= 0:
        raise ValueError(f"sample_rate must be positive, got {sample_rate!r}")
    sf.write(output_file, audio_data, rate, subtype=subtype)

if __name__ == "__main__":
    # Convert every .xdf file located next to this script into a same-named
    # .wav file in the same directory.
    import sys

    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_directory = script_dir
    output_directory = script_dir

    # glob() yields files in filesystem-dependent order; sort so runs are
    # reproducible.
    xdf_files = sorted(glob.glob(os.path.join(input_directory, "*.xdf")))

    if not xdf_files:
        print("No XDF files found in the directory:", input_directory)

    failures = 0
    for xdf_file_path in xdf_files:
        stem = os.path.splitext(os.path.basename(xdf_file_path))[0]
        output_wav_file = os.path.join(output_directory, stem + ".wav")

        # One unreadable or malformed file must not abort the whole batch:
        # report it and continue with the next file.
        try:
            timestamps, audio_data = extract_audio_from_xdf(xdf_file_path)
            if timestamps is None or audio_data is None:
                continue
            save_audio_to_wav(timestamps, audio_data, output_wav_file)
        except Exception as exc:  # top-level batch boundary
            failures += 1
            print("Failed to convert", xdf_file_path, "-", exc, file=sys.stderr)
            continue
        print("Audio saved to:", output_wav_file)

    # Surface partial failure to the caller (shell, CI) via the exit status.
    if failures:
        sys.exit(1)
