# AGPL-3 Karl Semich 2022
import numpy as np

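# At the moment, wavelets must be:
# - Zero at a quarter period: f(0.25) == 0 (so the quarter-shifted imaginary
#   part built by complex_wavelet is zero at x == 0)
# - Even functions: f(x) == f(-x)
# - Odd functions when slid by 90 degrees (a quarter period): f(0.25 - x) == -f(0.25 + x)
# Presently x is passed in the modular range [0,1), but this is arbitrary; the
# modulo can simply be removed if that is more convenient.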
def COSINE(x):
    '''Cosine wavelet with a period of 1: returns cos(2 * pi * x).'''
    return np.cos(2 * np.pi * x)


def STEP(x):
    '''Stepped wavelet: COSINE(x) rounded to the nearest integer (-1, 0 or 1).'''
    return COSINE(x).round()

def complex_wavelet(wavelet, x):
    '''
    Evaluate a real periodic wavelet as a complex one.

    The real part samples `wavelet` at x and the imaginary part samples it at
    x - 0.25 (shifted by a quarter period). Both arguments are wrapped into
    [0, 1) with a modulo before being passed to `wavelet`. For a cosine
    wavelet this yields exp(2j * pi * x).
    '''
    in_phase = x % 1
    quarter_shifted = (x - 0.25) % 1
    return wavelet(in_phase) + 1j * wavelet(quarter_shifted)

def fftfreq(freq_count = None, sample_rate = None,
            min_freq = None, max_freq = None,
            dc_offset = True, complex = True, sample_time = None,
            repetition_time = None, repetition_samples = None,
            freq_sample_rate = None, freq_sample_time = None):
    '''
    Calculates and returns the frequency bins used to convert between the time
    and frequency domains with a discrete Fourier transform.

    With no optional arguments, this function should be equivalent to numpy.fft.fftfreq .

    To specify frequencies in terms of repetitions / sample_count, set
    sample_rate to sample_count. If frequencies were in Hz (cycles/sec), this
    implies that sample_count should be considered to last 1 second, so the
    frequency becomes equal to the cycle count.

    Parameters:
        - freq_count: the number of frequency bins to generate
        - sample_rate: the time-domain sample rate (default: 1 sample/time_unit)
        - min_freq: the minimum frequency the signal contains
        - max_freq: the maximum frequency the signal contains
        - dc_offset: whether to include a DC offset component (0 Hz)
        - complex: whether to generate negative frequencies
        - sample_time: sample_rate as the duration of a single sample
        - repetition_time: min_freq as a period time, possibly of a subsignal
        - repetition_samples: min_freq as a period size in samples
        - freq_sample_rate: convert to or from a different frequency-domain sample rate
        - freq_sample_time: freq_sample_rate as the duration of a single sample

    At most one of min_freq, repetition_time and repetition_samples may be
    given, and at least one of them or freq_count must be given. When
    freq_count is given with complex=False, it counts the real (non-negative)
    bins and is expanded internally to the equivalent complex count; e.g.
    fftfreq(16, complex=False) matches numpy.fft.rfftfreq(30).

    Returns a vector of sinusoid time scalings that can be used to perform or
    analyse a discrete Fourier transform.  Multiply this vector by the time-domain
    sample rate to get the frequencies. The layout follows numpy: the DC bin
    (if dc_offset), then ascending positive bins, then (if complex) negative
    bins from most negative to least negative.

    Warns (via warnings.warn) when complex=False and the implied complex count
    is odd; the bins must then be used with odd_repetition_samples=True.
    '''
    # Each quantity may be given directly or as its reciprocal, not both.
    assert not sample_time or not sample_rate
    assert not freq_sample_time or not freq_sample_rate
    assert (min_freq, repetition_time, repetition_samples).count(None) >= 2
    assert freq_count or min_freq or repetition_time or repetition_samples
    if sample_time is not None:
        sample_rate = 1 / sample_time
    if freq_sample_time is not None:
        freq_sample_rate = 1 / freq_sample_time
    # Each rate falls back to the other, and then to 1.
    sample_rate = sample_rate or freq_sample_rate or 1
    freq_sample_rate = freq_sample_rate or sample_rate or 1
    # A repetition period determines the fundamental (lowest) frequency.
    if repetition_time:
        min_freq = 1 / repetition_time
    elif repetition_samples:
        min_freq = freq_sample_rate / repetition_samples
    # From here on freq_count holds the equivalent complex bin count including
    # DC (effectively a sample count); a dedicated name such as sample_count
    # might be clearer, and the branches below could then be consolidated.
    if freq_count is None:
        freq_count = int(np.ceil(freq_sample_rate / min_freq))
    else:
        if not dc_offset:
            freq_count += 1
        if not complex:
            # real bins -> complex count; this always yields an even count
            freq_count = int((freq_count - 1) * 2)
        if not min_freq:
            min_freq = freq_sample_rate / freq_count
    if not max_freq:
        # Highest bin: half of freq_count * min_freq (the Nyquist bin), or
        # half a bin below it when the count is odd.
        max_freq = freq_count * min_freq / 2
        if freq_count % 2 != 0:
            max_freq -= min_freq / 2
    # Express both limits relative to the time-domain sample rate.
    min_freq /= sample_rate
    max_freq /= sample_rate
    if freq_count % 2 == 0:
        if complex:
            # numpy convention: the +/-max bin appears only in the negative
            # half, so the positive half omits it.
            neg_freqs = np.linspace(-max_freq, -min_freq, num=freq_count // 2, endpoint=True)
            pos_freqs = -neg_freqs[:0:-1]
        else:
            # real bins up to and including max_freq; no negative bins
            pos_freqs = np.linspace(min_freq, max_freq, num=freq_count // 2, endpoint=True)
            neg_freqs = pos_freqs[:0]
    else:
        # odd count: the positive and negative halves mirror each other exactly
        pos_freqs = np.linspace(min_freq, max_freq, num=freq_count // 2, endpoint=True)
        neg_freqs = -pos_freqs[::-1] if complex else pos_freqs[:0]
        if not complex:
            import warnings
            warnings.warn('An odd-valued complex frequency count was generated for real-valued frequencies. Be sure to pass odd_repetition_samples=True when using the frequencies.', stacklevel=2)
        # The freq_count-given branch above only produces even counts for
        # complex=False, so an odd real count here can only come from a count
        # derived from min_freq / repetition_*; the matrix functions cannot
        # infer this from the bins, hence the warning.
    return np.concatenate([
        np.array([0] if dc_offset else []),
        pos_freqs,
        neg_freqs
    ])

def create_freq2time(time_count = None, freqs = None, wavelet = COSINE, odd_repetition_samples = None):
    '''
    Creates a matrix that will perform an inverse discrete Fourier transform
    when it post-multiplies a vector of complex frequency magnitudes.

    Example
        time_data = spectrum @ create_freq2time(len(spectrum))

    Parameters:
        - time_count: size of the output vector, defaults to the frequency bincount
        - freqs: frequency bins to convert, defaults to traditional FFT frequencies
          (expected to be a numpy array, since `freqs < 0` is evaluated)
        - wavelet: real periodic function evaluated through complex_wavelet in
          place of a sinusoid; defaults to COSINE
        - odd_repetition_samples: only consulted for real-valued freqs (no
          negative bins); True when the time-domain period has an odd number
          of samples, which decides whether the last bin is treated as a
          Nyquist bin. Must not be passed when freqs is None.

    Real-valued freqs are handled like numpy.fft.irfft: the input vector must
    then hold real and imaginary parts interleaved, [re0, im0, re1, im1, ...]
    (e.g. a complex128 array viewed as float), and the negative bins are
    supplied implicitly as the conjugate mirror of the positive ones.

    Returns:
        - an inverse discrete Fourier matrix of shape (len(freqs), time_count)
          for complex freqs, or (2 * len(freqs), time_count) for real freqs,
          divided by the equivalent complex bin count to match numpy's scaling
    '''
    assert (time_count is not None) or (freqs is not None)
    if freqs is None:
        assert odd_repetition_samples is None
        odd_repetition_samples = (time_count % 2 == 1)
        freqs = fftfreq(time_count)
    # Any negative bin means the caller supplied a full complex spectrum.
    is_complex = (freqs < 0).any()
    has_dc_offset = bool(freqs[0] == 0)
    # freq_count becomes the equivalent complex bin count including DC; it is
    # both the default time_count and the normalisation divisor.
    freq_count = len(freqs)
    if not has_dc_offset:
        freq_count += 1
    if not is_complex:
        # Whether the last bin is a Nyquist bin cannot be read reliably from
        # freqs once max_freq has been customised, so it is derived from the
        # caller's odd_repetition_samples flag instead.
        has_nyquist = (bool(odd_repetition_samples) != has_dc_offset)
        freq_count = (freq_count - 1) * 2 + (not has_nyquist)
    if time_count is None:
        time_count = freq_count
    offsets = np.arange(time_count)
    if is_complex:
        mat = complex_wavelet(wavelet, np.outer(freqs, offsets))
    else:
        # Bins in [head_idx, tail_idx) have a distinct negative-frequency
        # mirror; the DC bin (before head_idx) and Nyquist bin (at tail_idx)
        # do not.
        head_idx = 1 if has_dc_offset else 0
        tail_idx = len(freqs) - has_nyquist

        # Full complex bin list: the given bins (with any Nyquist bin
        # negated), followed by the mirrored negative bins in reverse order.
        complex_freqs = np.concatenate([
            freqs[:tail_idx],
            -freqs[tail_idx:],
            -freqs[head_idx:tail_idx][::-1]
        ])
        complex_mat = complex_wavelet(wavelet, np.outer(complex_freqs, offsets))

        # Fold each mirrored negative-frequency row back onto its positive bin
        # by adding its conjugate, so the input only needs non-negative bins.
        # With COSINE the conjugate equals the positive-frequency row, so
        # interior bins count twice while DC and Nyquist count once.
        interim_mat = complex_mat.copy()
        interim_mat[head_idx:tail_idx] += interim_mat[len(freqs):][::-1].conj()
        # Interleave per bin k: row 2k is Re(row_k) and row 2k+1 is -Im(row_k),
        # so for input [re0, im0, re1, im1, ...] the product equals
        # Re(sum_k (re_k + 1j*im_k) * row_k), a real-valued time signal.
        mat = np.stack([
            interim_mat[:len(freqs)].real.T,
            -interim_mat[:len(freqs)].imag.T
        ], axis=-1).reshape(time_count, len(freqs)*2).T
    return mat / freq_count # scaled to match numpy convention

def create_time2freq(time_count = None, freqs = None, wavelet = COSINE, odd_repetition_samples = None):
    '''
    Creates a matrix that will perform a forward discrete Fourier transform
    when it post-multiplies a vector of time series data.

    If time_count is too small or large, the minimal least squares solution
    over all the data passed will be produced.

    For complex freqs this is numpy.linalg.pinv of the matrix returned by
    create_freq2time. For real-valued freqs (no negative bins) the pinv is
    taken of the full complex inverse (with the negative bins mirrored in),
    and only the columns for the given bins are kept, so the output is
    comparable to numpy.fft.rfft. If the return value is single-use, it is
    more efficient and accurate to use numpy.linalg.lstsq .

    Example
        spectrum = time_data @ create_time2freq(len(time_data))

    Parameters:
        - time_count: size of the input vector, defaults to the frequency bincount
        - freqs: frequency bins to produce, defaults to traditional FFT frequencies
        - wavelet: real periodic function used in place of a sinusoid;
          defaults to COSINE
        - odd_repetition_samples: only consulted for real-valued freqs; True
          when the time-domain period has an odd number of samples, which
          decides whether the last bin is treated as a Nyquist bin. Must not
          be passed when freqs is None.

    Returns:
        - a discrete Fourier matrix of shape (time_count, len(freqs))
    '''
    assert (time_count is not None) or (freqs is not None)
    if freqs is None:
        assert odd_repetition_samples is None
        odd_repetition_samples = (time_count % 2 == 1)
        freqs = fftfreq(time_count)
    is_complex = (freqs < 0).any()
    # TODO: provide for DC offsets and nyquists not at index 0 by excluding items in (0,0.5)
    has_dc_offset = bool(freqs[0] == 0)
    if not is_complex:
        # The Nyquist bin is inferred from the caller's flag rather than from
        # freqs, since a customised max_freq hides it.
        has_nyquist = (bool(odd_repetition_samples) != has_dc_offset)
        head_idx = 1 if has_dc_offset else 0
        tail_idx = len(freqs)-1 if has_nyquist else len(freqs)
        neg_start_idx = len(freqs)
        # Expand to the equivalent complex bin list: DC, positive bins,
        # negated Nyquist bin, then the mirrored negative bins.
        freqs = np.concatenate([
            freqs[:head_idx],
            freqs[head_idx:tail_idx],
            -freqs[tail_idx:neg_start_idx],
            -freqs[head_idx:tail_idx][::-1]
        ])
    inverse_mat = create_freq2time(time_count, freqs, wavelet, odd_repetition_samples = odd_repetition_samples)
    forward_mat = np.linalg.pinv(inverse_mat)
    if not is_complex:
        # Keep only the columns for the caller's bins: DC and positive bins
        # as-is, plus the Nyquist column, which was computed for the negated
        # frequency and is conjugated here (presumably to report it as the
        # matching positive-frequency bin).
        forward_mat = np.concatenate([
            forward_mat[:,:tail_idx],
            forward_mat[:,tail_idx:neg_start_idx].conj()
        ], axis=1)
    return forward_mat

def peak_pair_idcs(freq_data, dc_offset=True, sum=True):
    '''
    Find the pair of adjacent frequency bins with the largest combined magnitude.

    Parameters:
        - freq_data: array-like of (possibly complex) frequency-domain data;
          the last axis indexes frequency bins
        - dc_offset: whether bin 0 is a DC offset, which is excluded from pairing
        - sum: if True, magnitudes are first summed over all leading axes so a
          single pair is chosen for the whole batch

    Returns:
        - an integer array whose last axis has length 2, holding
          [left_idx, left_idx + 1]. With sum=True (or 1-D input) this is a
          plain length-2 array that can be unpacked as `left, right = ...`.

    Raises:
        - ValueError if fewer than two non-DC bins are available to pair
    '''
    dc_offset = int(dc_offset)
    # magnitude of each (possibly complex) bin
    freq_heights = np.abs(np.asarray(freq_data))
    if sum:
        # collapse every leading axis so only the frequency axis remains
        while freq_heights.ndim > 1:
            freq_heights = freq_heights.sum(axis=0)
    if freq_heights.shape[-1] - dc_offset < 2:
        raise ValueError('at least two non-DC frequency bins are required to form a pair')
    # combined height of bins (i, i + 1) for every i from the first non-DC bin
    paired_heights = freq_heights[..., dc_offset:-1] + freq_heights[..., dc_offset + 1:]
    peak_idx = paired_heights.argmax(axis=-1)[..., np.newaxis] + dc_offset
    return np.concatenate([peak_idx, peak_idx + 1], axis=-1)

def improve_peak(time_data, min_freq, max_freq):
    '''
    Transform time_data onto complex frequency bins spanning min_freq..max_freq
    and return the spectral values of the strongest adjacent pair of bins.

    Parameters:
        - time_data: time-domain samples; the last axis is time
        - min_freq, max_freq: frequency range passed through to fftfreq

    Returns:
        - (left, right): the frequency-domain values at the two bins chosen by
          peak_pair_idcs, taken along the last (frequency) axis
    '''
    time_data = np.asarray(time_data)
    sample_count = time_data.shape[-1]
    freqs = fftfreq(sample_count, min_freq = min_freq, max_freq = max_freq)
    # The first positional parameter of create_freq2time is time_count, so the
    # bins must be passed second; with complex bins the matrix is square.
    freq2time = create_freq2time(sample_count, freqs)
    freq_data = time_data @ np.linalg.inv(freq2time)
    left, right = peak_pair_idcs(freq_data)
    # index the frequency axis so batched (2-D+) input also works
    return freq_data[..., left], freq_data[..., right]
    

def test():
    '''
    Self-checks for this module, comparing its matrices against numpy.fft.

    Seeds numpy's global random generator with 0 so the random inputs are
    reproducible, and raises AssertionError on the first mismatch. The
    15-sample real case is expected to trigger fftfreq's odd-count warning.
    '''
    np.random.seed(0)

    # Complex transforms over 16 samples must match numpy.fft.fft/ifft, and
    # each matrix must act as the inverse of the other.
    randvec = np.random.random(16)
    freqs16 = fftfreq(16)
    ift16 = create_freq2time(16, freqs=freqs16)
    ft16 = create_time2freq(16)
    randvec2time = randvec @ ift16
    randvec2freq = randvec @ ft16
    randvec2ifft = np.fft.ifft(randvec)
    randvec2fft = np.fft.fft(randvec)
    assert np.allclose(freqs16, np.fft.fftfreq(16))
    assert np.allclose(randvec2ifft, randvec2time)
    assert np.allclose(randvec2fft, randvec2freq)
    assert np.allclose(randvec2ifft, np.linalg.solve(ft16.T, randvec))
    assert np.allclose(randvec2fft, np.linalg.solve(ift16.T, randvec))

    # Real-valued transforms with an odd (15-sample) period, compared against
    # numpy.fft.rfft/irfft. The inverse matrices take interleaved [re, im]
    # input, hence the complex128 <-> float views.
    rfreqs15t = fftfreq(repetition_samples=15, complex=False)
    rfreqs15f = fftfreq(15, complex=False)
    irft_from_15 = create_freq2time(freqs=rfreqs15f)
    rft_from_15 = create_time2freq(15, freqs=rfreqs15t, odd_repetition_samples=True)
    irft_to_15 = create_freq2time(15, freqs=rfreqs15t, odd_repetition_samples=True)
    randvec2rtime15 = randvec[:15].astype(np.complex128).view(float) @ irft_from_15
    randvec2rfreq15 = randvec[:15] @ rft_from_15
    randvec2irfft = np.fft.irfft(randvec[:15])
    randvec2rfft = np.fft.rfft(randvec[:15])
    randvec2rfreq152randvec = randvec2rfreq15.view(float) @ irft_to_15
    randvec2rfreq152irfft = np.fft.irfft(randvec2rfreq15, 15)
    assert np.allclose(rfreqs15t, np.fft.rfftfreq(15))
    assert np.allclose(rfreqs15f, np.fft.rfftfreq(28))
    assert np.allclose(randvec2rtime15, randvec2irfft)
    assert np.allclose(randvec2rfreq15, randvec2rfft)
    assert np.allclose(randvec2rfreq152randvec, randvec2rfreq152irfft)
    assert np.allclose(randvec2rfreq152randvec, randvec[:15])

    # Real-valued transforms with an even (16-sample) period.
    rfreqs16t = fftfreq(repetition_samples=16, complex=False)
    rfreqs16f = fftfreq(16, complex=False)
    irft16 = create_freq2time(freqs=rfreqs16f)
    rft16 = create_time2freq(16, freqs=rfreqs16t)
    randvec2rtime16 = randvec[:16].astype(np.complex128).view(float) @ irft16
    randvec2rfreq16 = randvec[:16] @ rft16
    randvec2irfft = np.fft.irfft(randvec[:16])
    randvec2rfft = np.fft.rfft(randvec[:16])
    assert np.allclose(rfreqs16t, np.fft.rfftfreq(16))
    assert np.allclose(rfreqs16f, np.fft.rfftfreq(30))
    assert np.allclose(randvec2rtime16, randvec2irfft)
    assert np.allclose(randvec2rfreq16, randvec2rfft)

    # TODO(review): earlier drafts sketched tests for real transforms whose bin
    # count differs from the sample count (e.g. 9 real bins over 16 samples,
    # create_freq2time(16, fftfreq(9, complex = False))) and for a real round
    # trip (randvec @ rft16) @ irft16 == randvec; neither is implemented yet.

    # Sample data at a differing rate: build the time signal for a spectrum
    # sampled at time_rate by direct summation of complex exponentials, and
    # check the rescaled matrices reproduce and invert it.
    time_rate = np.random.random() * 2
    freq_rate = 1.0
    freqs = np.fft.fftfreq(len(randvec))
    rescaling_freqs = fftfreq(len(randvec), freq_sample_rate = freq_rate, sample_rate = time_rate)
    rescaling_ift = create_freq2time(freqs = rescaling_freqs)
    rescaling_ft = create_time2freq(freqs = rescaling_freqs)
    rescaled_time_data = np.array([
        np.mean([
            randvec[freqidx] * np.exp(2j * np.pi * freqs[freqidx] * sampleidx / time_rate)
            for freqidx in range(len(randvec))
        ])
        for sampleidx in range(len(randvec))
    ])
    assert np.allclose(rescaled_time_data, randvec @ rescaling_ift)
    assert np.allclose(rescaled_time_data, np.linalg.solve(rescaling_ft.T, randvec))
    unscaled_freq_data = rescaled_time_data @ rescaling_ft
    unscaled_time_data = unscaled_freq_data @ ift16
    assert np.allclose(unscaled_freq_data, randvec)
    assert np.allclose(unscaled_time_data, randvec2time)
    assert np.allclose(np.linalg.solve(rescaling_ift.T, rescaled_time_data), randvec)

    # Extract a repeating wave: stretch randvec over a longer signal, then use
    # STEP-wavelet transforms to insert and extract it.
    longvec = np.empty(len(randvec)*100)
    short_count = np.random.random() * 4 + 1
    short_duration = len(longvec) / short_count
    for idx in range(len(longvec)):
        # NOTE(review): this index advances once every short_count samples, so
        # one repetition of randvec spans len(randvec) * short_count samples,
        # while longspace_freqs below assume a period of short_duration
        # samples; the two agree only when short_count == 10 here. Possibly
        # intended idx / short_duration * len(randvec) — confirm.
        longvec[idx] = randvec[int((idx / len(longvec) * short_duration) % len(randvec))]
    shortspace_freqs = fftfreq(None, complex = False, dc_offset = True, repetition_samples = len(randvec))
    longspace_freqs = fftfreq(len(shortspace_freqs), complex = False, dc_offset = True, repetition_samples = short_duration)
    assert np.allclose(longspace_freqs, shortspace_freqs / (short_duration / len(randvec)))
    inserting_ift = create_freq2time(len(longvec), longspace_freqs, wavelet = STEP)
    extracting_ft = create_time2freq(len(longvec), longspace_freqs, wavelet = STEP)
    extracting_ift = create_freq2time(freqs = shortspace_freqs, wavelet = STEP)
    unextracting_ft = create_time2freq(freqs = shortspace_freqs, wavelet = STEP)
    inserting_spectrum = randvec @ unextracting_ft
    assert np.allclose(inserting_spectrum @ extracting_ift, randvec)
    assert np.allclose(longvec, inserting_spectrum @ inserting_ift)
    extracted_freqs = longvec @ extracting_ft
    extracted_randvec = extracted_freqs @ extracting_ift
    assert np.allclose(extracted_randvec, randvec)

if __name__ == "__main__":
    # Executing this file directly runs the built-in self-checks.
    test()
