function [freq_cent, wt] = wavelet_transform(lfps, fs, freq)
% WAVELET_TRANSFORM:
%   Compute the continuous wavelet transform (complex Morlet, 'cmor1-1')
%   of many LFP trials, using reflection padding to reduce edge effects.
%
%   Input:
%       lfps = the lfp snippets (trial x time)
%       fs = sampling freq (kHz)
%       freq = the frequencies to look at (presumably Hz, since the scales
%              are computed as fs * 1e3 ./ freq)
%
%   Output:
%       freq_cent = the center frequencies of the wavelets as reported by
%                   scal2frq. 'cmor1-1' has a nominal center frequency of 1,
%                   so these should closely match freq (verify numerically).
%       wt = the wavelet transformation (trial x freq x time); the time
%            dimension has the same length as size(lfps, 2).
%

    % CWT prep: one scale (in samples) per requested frequency.
    fs_hz = fs * 1e3;
    scales = fs_hz ./ freq;
    freq_cent = scal2frq(scales, 'cmor1-1', 1 / fs_hz);

    n_trials = size(lfps, 1);
    n_time = size(lfps, 2);

    % Pad lfps with reflected data: the left pad mirrors the first
    % ceil(n_time/2) samples, the right pad mirrors the remaining
    % floor(n_time/2) samples, so the padded length is 2 * n_time.
    n_left = ceil(n_time / 2);
    lfp_pad = [fliplr(lfps(:, 1:n_left)), lfps, fliplr(lfps(:, (n_left+1):end))];

    % Columns of the padded signal that hold the original samples. The right
    % pad has n_time - n_left columns, so cropping with "end - n_left" would
    % be off by one whenever n_time is even.
    keep = n_left + (1:n_time);

    % Preallocate so the output is defined (and correctly shaped) even when
    % there are no trials, and to avoid growing the array inside the loop.
    wt = zeros(n_trials, numel(scales), n_time);

    % Compute wavelet transforms trial by trial.
    % NOTE(review): this uses the legacy cwt(x, scales, wname) syntax; newer
    % Wavelet Toolbox releases may warn about or reject it -- confirm.
    for i_trial = 1:n_trials
        wt_pad = cwt(lfp_pad(i_trial, :), scales, 'cmor1-1');
        wt(i_trial, :, :) = wt_pad(:, keep);
    end

end
