# AGPL-3 Karl Semich 2022
import numpy as np

def create_freq2time(freq_rate, time_rate, freq_count, time_count):
    """Build the matrix that turns frequency bins into time samples.

    Right-multiplying a length-``freq_count`` row vector of frequency-domain
    values by the result (``spectrum @ matrix``) yields ``time_count``
    time-domain samples.  With equal rates and equal counts this reproduces
    ``np.fft.ifft``.

    Parameters
    ----------
    freq_rate, time_rate
        Only the ratio ``freq_rate / time_rate`` is used: it is the spacing
        between successive time sample positions.
    freq_count
        Number of frequency bins (rows of the result).
    time_count
        Number of time samples (columns of the result).

    Returns
    -------
    Complex array of shape ``(freq_count, time_count)``.
    """
    # Positions at which the signal is evaluated, stretched by the rate ratio.
    sample_positions = np.arange(time_count) * freq_rate / time_rate
    # Bin frequencies in numpy's FFT ordering (cycles per unit sample spacing).
    bin_freqs = np.fft.fftfreq(freq_count)
    # phase[k, n] is the complex phase of bin k at sample position n.
    phase = 2j * np.pi * np.outer(bin_freqs, sample_positions)
    # numpy puts the 1/n normalisation on the inverse transform, so do we.
    return np.exp(phase) / freq_count

def create_time2freq(time_rate, freq_rate, time_count, freq_count):
    """Build the matrix that turns time samples back into frequency bins.

    The result is the matrix inverse of
    ``create_freq2time(freq_rate, time_rate, freq_count, time_count)``, so
    ``samples @ matrix`` recovers the frequency-domain values that produced
    ``samples``.  With equal rates this reproduces ``np.fft.fft``.

    Parameters
    ----------
    time_rate, freq_rate
        Rates forwarded to ``create_freq2time``.
    time_count, freq_count
        Number of time samples and frequency bins; they must be equal.

    Returns
    -------
    Complex square array of shape ``(time_count, freq_count)``.

    Raises
    ------
    NotImplementedError
        If ``time_count`` and ``freq_count`` differ (only the square,
        directly invertible case is handled).
    """
    if freq_count != time_count:
        raise NotImplementedError("differing input and output samplecounts")
    # The forward matrix is inverted as-is; np.linalg.inv reports a singular
    # forward matrix with its own exception.
    return np.linalg.inv(
        create_freq2time(freq_rate, time_rate, freq_count, time_count)
    )

def test():
    """Check the transform matrices against numpy's FFT and a naive sum.

    Part one uses equal rates and verifies that the matrices reproduce
    ``np.fft.ifft`` / ``np.fft.fft``, both by multiplication and by solving
    against the transposed opposite matrix.  Part two samples the same
    spectrum at a random time rate, compares against a term-by-term
    reference computation, and checks that the round trip recovers the
    original data.

    NOTE(review): the time rate is drawn from [0, 2) without a seed.  Rates at
    which two frequency rows coincide (e.g. 0.5 with 16 bins) make the forward
    matrix singular, so draws close to such rates may fail numerically --
    confirm whether that is acceptable for this check.
    """
    # No seed is set, so each run draws fresh data; calling
    # np.random.seed(0) here makes a run reproducible.
    randvec = np.random.random(16)
    count = len(randvec)
    ift16 = create_freq2time(1, 1, 16, 16)
    ft16 = create_time2freq(1, 1, 16, 16)

    # Equal rates: the matrices must agree with numpy's own transforms.
    expected_time = np.fft.ifft(randvec)
    expected_freq = np.fft.fft(randvec)
    via_ift = randvec @ ift16
    via_ft = randvec @ ft16
    assert np.allclose(expected_time, via_ift)
    assert np.allclose(expected_freq, via_ft)
    # Solving against the opposite matrix must give the same answers.
    assert np.allclose(expected_time, np.linalg.solve(ft16.T, randvec))
    assert np.allclose(expected_freq, np.linalg.solve(ift16.T, randvec))

    # Differing rates: sample the same spectrum at a random time rate.
    time_rate = np.random.random() * 2
    freq_rate = 1.0
    bin_freqs = np.fft.fftfreq(count)
    rescaling_ift = create_freq2time(freq_rate, time_rate, count, count)
    rescaling_ft = create_time2freq(time_rate, freq_rate, count, count)

    # Reference result built one term at a time, independent of the matrices.
    naive_samples = []
    for sampleidx in range(count):
        terms = [
            amplitude * np.exp(2j * np.pi * freq * sampleidx / time_rate)
            for amplitude, freq in zip(randvec, bin_freqs)
        ]
        naive_samples.append(np.mean(terms))
    rescaled_time_data = np.array(naive_samples)

    assert np.allclose(rescaled_time_data, randvec @ rescaling_ift)
    assert np.allclose(
        rescaled_time_data, np.linalg.solve(rescaling_ft.T, randvec)
    )

    # Round trip: rescaled samples -> original spectrum -> unit-rate samples.
    recovered_freq = rescaled_time_data @ rescaling_ft
    recovered_time = recovered_freq @ ift16
    assert np.allclose(recovered_freq, randvec)
    assert np.allclose(recovered_time, via_ift)
    assert np.allclose(
        np.linalg.solve(rescaling_ift.T, rescaled_time_data), randvec
    )

# Run the self-check when this file is executed as a script.
if __name__ == '__main__':
    test()
