function [S] = mt_dtft_gram(lfps, fs, time_window, time_step, freq, TW, n_tapers)
% MULTI-TAPER DISCRETE-TIME-FOURIER-TRANSFORM SPECTROGRAM:
%   Compute a time-frequency plot using a DTFT with a multitaper method.
%
%   Input:
%       lfps = lfp signals (trials x samples)
%       fs = lfp sampling rate (kHz).
%       time_window = sliding window size (ms)
%       time_step = sliding window step size (ms)
%       freq = frequencies (Hz).
%       TW = time-half-bandwidth product
%       n_tapers = number of tapers
%
%   Output:
%       S = Spectrogram (trial x frequency x time x taper): the complex
%           per-taper DTFT estimates returned by mt_dtft for each window.
%
%   Windows start at 0, time_step, 2*time_step, ... ms (while a whole
%   window still fits in the data). Each window spans exactly
%   floor(time_window * fs) samples, so every time bin is computed with
%   tapers of the same length.
%

    % This is hard coded. Change the value here in the RARE CASE we don't
    % want to use the GPU
    USE_GPU = 1;

    % Set up windowing (times in ms, fs in kHz => ms * kHz = samples).
    [n_trials, n_samps] = size(lfps);
    n_freq = numel(freq);
    n_samps_per_window = floor(time_window * fs);
    duration = n_samps / fs;
    window_start = 0:time_step:(duration - time_window);
    n_windows = numel(window_start);

    % First (1-based) sample of each window. Sample k sits at (k-1)/fs ms.
    % Selecting whole samples avoids floating-point comparisons on a time
    % grid, which could make window lengths differ by one sample. The
    % clamp keeps the final window inside the data.
    first_samp = round(window_start * fs) + 1;
    first_samp = min(first_samp, n_samps - n_samps_per_window + 1);
    window_offsets = 0:(n_samps_per_window - 1);

    % Preallocate (trial x freq x taper x time). This also defines S when
    % no window fits in the data.
    S = complex(zeros(n_trials, n_freq, n_tapers, n_windows));

    % Step through each window, computing DTFTs.
    for i_time = 1:n_windows
        samps = first_samp(i_time) + window_offsets;
        S(:, :, :, i_time) = mt_dtft(lfps(:, samps), TW, n_tapers,   ...
                                     freq, fs, USE_GPU);
    end

    % Re-order matrix dimensions (trial x freq x time x taper)
    S = permute(S, [1, 2, 4, 3]);
end


function [pxx_out] = mt_dtft(xin,nw,nTapers,fin,fs,use_gpu)
% DISCRETE TIME FOURIER TRANSFORM MULTITAPER METHOD
%   Compute signals' DTFT with a multitaper approach. The per-taper
%   estimates are then re-weighted iteratively. The weights follow the
%   adaptive multitaper scheme:
%       w_k = (S / (v_k*S + (1-v_k)*varx))^2 * v_k,
%   iterated until the averaged spectrum S stops changing.
%
%   INPUT:
%       xin     = The time signal (Trials x Time).
%       nw      = Time-halfbandwidth product.
%       nTapers = Number of DPSS tapers.
%       fin     = Frequencies to resolve (Hz).
%       fs      = Sampling rate for xin (kHz; sample n is at n/(fs*1e3) s).
%       use_gpu = Nonzero to do the computation on gpuArrays.
%
%   OUTPUT:
%       pxx_out = Weighted DTFT estimates for each signal, frequency and
%                 taper (Trial x Freq x Taper), gathered to host memory.
%

% AUTHOR: Charles D Holmes
% EMAIL:  chuck@eye-hand.wustl.edu
%

    %%%% CONSTANTS %%%%
    SIZE_MAX = 16e6;    % bytes per DTFT chunk (budget assumes 8-byte elements)
    MAX_ITER = 100;     % cap on adaptive-weighting iterations


    %%%% PROCEDURE %%%%

    % Get dimensions from inputs
    [nTrials,nTime] = size(xin);
    nFreq   = length(fin);

    % Compute tapers (v = concentration eigenvalue of each taper, column)
    [tapers,v] = dpss(nTime,nw,nTapers);

    % Force vectors to be column vectors; move to the GPU if requested.
    fin = fin(:);
    tin = ((0:(nTime-1))/(fs * 1e3)).';
    if use_gpu
        xin    = gpuArray(xin);
        fin    = gpuArray(fin);
        tin    = gpuArray(tin);
        tapers = gpuArray(tapers);
        v      = gpuArray(v);
    end

    % Set up for job splitting - Minimize GPU usage. Chunks of trials are
    % contiguous and non-overlapping. Each chunk holds at least one trial,
    % even when a single trial alone exceeds the budget.
    numelMax   = SIZE_MAX/8;     % elements
    maxTrials  = max(1, floor(numelMax/(nFreq*nTime*nTapers)));
    chunkStart = 1:maxTrials:nTrials;
    chunkEnd   = [chunkStart(2:end)-1, nTrials];

    % Split up dtft calculation
    xx = NaN(nFreq,nTapers,nTrials,'Like',xin);
    for iSplit = 1:numel(chunkStart)
        idxArr = chunkStart(iSplit):chunkEnd(iSplit);
        xx(:,:,idxArr) = dtftMtmHelper(xin(idxArr,:),fin,tin,tapers);
    end

    % Adaption setup
    varx  = mean(xin.^2,2).';                  % 1 x Trial, mean signal power
    Sk    = abs(xx).^2;                        % Freq x Taper x Trial
    % Initial estimate: average of the first two eigenspectra (or the only
    % one when nTapers == 1).
    S     = mean(Sk(:,1:min(2,nTapers),:),2);  % Freq x 1 x Trial
    S1    = zeros(size(S),'Like',S);
    tol   = 5e-4*varx/nFreq;                   % 1 x Trial
    a     = (1-v)*varx;                        % Taper x Trial
    aPage = permute(a,[3,1,2]);                % 1 x Taper x Trial
    vRow  = v.';                               % 1 x Taper
    pxx   = xx;

    % Adaptive algorithm. bsxfun works on both host arrays and gpuArrays,
    % so one code path serves both cases.
    nIter = 0;
    while any(squeeze(sum(abs(S-S1)/nFreq,1))>tol.')
        if nIter >= MAX_ITER
            warning('mt_dtft:noConvergence', ...
                'Adaptive weighting did not converge in %d iterations.', ...
                MAX_ITER);
            break
        end
        nIter = nIter + 1;

        B  = bsxfun(@plus, bsxfun(@times, S, vRow), aPage); % Freq x Taper x Trial
        b  = bsxfun(@rdivide, S, B);
        wk = bsxfun(@times, b.^2, vRow);
        Wk = bsxfun(@rdivide, wk, mean(wk,2));  % normalized weights
        %S1 = sum(wk.*Sk,2)./sum(wk,2);
        S1 = mean(Wk.*Sk,2);
        Stemp=S1; S1=S; S=Stemp;  % swap: S = new estimate, S1 = previous
        pxx = sqrt(Wk).*xx;
    end

    pxx_out = gather(permute(pxx, [3,1,2])); % Switch to: Trial x Freq x Taper

end

function xx = dtftMtmHelper(xin,fin,tin,tapin)
% DISCRETE-TIME FOURIER TRANSFORM MULTITAPER METHOD HELPER
%   Evaluate the tapered DTFT of every trial at every requested frequency.
%
%   Input:
%       xin   = signals (trials x time samples)
%       fin   = frequencies to evaluate (vector, length nFreq)
%       tin   = sample times (vector, length nTime), in units reciprocal to fin
%       tapin = tapers (time samples x tapers)
%
%   Output:
%       xx = DTFT estimates (Freq x Taper x Trial)
%
%   Each operand is reshaped into the shared 4-D frame
%   Freq x Taper x Trial x Time. The operands are combined by singleton
%   expansion, and the time axis (dim 4) is then summed away.
%

    [nTrials, nTime] = size(xin);
    nFreq   = numel(fin);
    nTapers = size(tapin, 2);

    sigPage   = reshape(xin,     [1, 1, nTrials, nTime]);   % 1 x 1 x Trial x Time
    taperPage = reshape(tapin.', [1, nTapers, 1, nTime]);   % 1 x Taper x 1 x Time
    freqPage  = reshape(fin,     [nFreq, 1, 1, 1]);         % Freq x 1 x 1 x 1
    timePage  = reshape(tin,     [1, 1, 1, nTime]);         % 1 x 1 x 1 x Time

    % Complex exponential kernel, Freq x 1 x 1 x Time
    kernel = exp(bsxfun(@times, (-2*pi*1i)*freqPage, timePage));

    % Tapered signal, 1 x Taper x Trial x Time
    tapered = bsxfun(@times, taperPage, sigPage);

    % DTFT: weight by the kernel and integrate over time -> Freq x Taper x Trial
    xx = sum(bsxfun(@times, tapered, kernel), 4);
end
