function cohSpikeLFP(suffix)
% COHSPIKELFP  Spike-LFP phase consistency over the frequency spectrum.
%   cohSpikeLFP(SUFFIX) determines the phase consistency between LFPs and
%   spikes over the frequency spectrum. LFPs are decomposed with a wavelet
%   transform. For each trial, the unit phase vectors of the transform at
%   the spike samples are averaged. For each trial type, the pairwise
%   phase consistency (PPC) across trials is then computed at every
%   frequency (derived from Vinck et al. 2012).
%
%   Input:
%     suffix - char vector. It is passed to read_params to obtain the
%              parameters, and appended to 'data' to form the output
%              file name.
%
%   Output (written to the file ['data', suffix]):
%     line 1    Params.command_line
%     line 2    Params.data_directory
%     line 3    Params.params_str
%     line 4    number of trials of each trial type (counted BEFORE
%               trials without spikes are dropped)
%     line 5    (blank)
%     then      one line per trial type with its PPC values, one per
%               column of the wavelet transform. The value is NaN when
%               fewer than two trials of that type have spikes.
%
%   Errors:
%     cohSpikeLFP:fopen - the output file cannot be opened for writing.

    % Read parameters
    Params = read_params(suffix);
    freq = Params.freq_start_1 : Params.freq_step_1 : Params.freq_end_1;

    % Load LFP data
    [lfps, fs, trial_type, trial_index] = ...
            load_lfps(Params.data_directory, Params.command_line,   ...
                      Params.trial_file, Params.time_interval,      ...
                      Params.alpha_plex, Params.electrode_1, [],    ...
                      Params.remove_mean);

    % Load spike data
    spikes = load_spikes(Params.trial_file, Params.spike_file,      ...
                         Params.time_interval, fs, trial_index);

    % Thin spikes: for each trial type, if there are more spikes than
    %   the thinning number, remove the excess. (Done exactly once.)
    if Params.thin_spikes
        spikes = limit_spikes(spikes, trial_type, Params.n_spikes_limit,    ...
                              Params.seed_rng);
    end

    % Get all trial types and count the trials of each type. These counts
    %   are taken before spike-less trials are dropped below; they are the
    %   counts written to the output file.
    type_list = unique(trial_type);
    n_trials = arrayfun(@(x) sum(trial_type == x), type_list);

    % Compute wavelet transform. The indexing below treats wt as
    %   (trial, column, sample), where the sample axis lines up with the
    %   columns of spikes. The middle axis is presumably frequency — confirm
    %   against wavelet_transform. The centre frequencies are not needed.
    [~, wt] = wavelet_transform(lfps, fs, freq);

    % Only keep trials with spikes.
    has_spikes = sum(spikes, 2) > 0;
    wt = wt(has_spikes, :, :);
    spikes = spikes(has_spikes, :);
    trial_type = trial_type(has_spikes);
    if not(all(has_spikes))
        fprintf('%d trials have no spikes. Dropped them.\n',     ...
                sum(not(has_spikes)));
    end

    n_kept = size(wt, 1);
    n_cols = size(wt, 2);

    %% Start PPC calculation
    %   This is derived from Vinck et al. 2012, specifically from the
    %   equation for P2.

    % First, get the mean unit phase vector at spike times for each trial.
    %   Preallocated so that the PPC step below also works when no trial
    %   survived the spike filter. (Assigning complex values converts the
    %   array to complex.)
    vec_mean = zeros(n_kept, n_cols);
    for i_trial = 1:n_kept
        wt_spike = wt(i_trial, :, find(spikes(i_trial, :)));
        vec = wt_spike ./ abs(wt_spike);       % unit vectors (phase only)
        vec_mean(i_trial, :) = mean(vec, 3);   % average over spikes
    end

    % Then, sum the mean vectors and the mean vector square magnitudes.
    %   Use those values to compute PPC:
    %     PPC = (|sum_m v_m|^2 - sum_m |v_m|^2) / (M * (M - 1)),
    %   i.e. the average of the real dot products over all pairs of
    %   distinct trials.
    ppc = nan(numel(type_list), n_cols);
    for i_type = 1:numel(type_list)
        is_type = trial_type == type_list(i_type);
        M = sum(is_type);
        if M < 2
            % PPC needs at least two trials; leave this row as NaN.
            continue;
        end
        S = abs(sum(vec_mean(is_type, :), 1)).^2;
        SS = sum(abs(vec_mean(is_type, :)).^2, 1);
        ppc(i_type, :) = (S - SS) / M / (M - 1);
    end

    % Output data
    out_file = ['data', suffix];
    fid = fopen(out_file, 'w');
    if fid == -1
        error('cohSpikeLFP:fopen', ...
              'Could not open output file "%s" for writing.', out_file);
    end
    % Close the file on return, even if a write below errors.
    close_file = onCleanup(@() fclose(fid)); %#ok<NASGU>

    fprintf(fid, '%s\n', Params.command_line);
    fprintf(fid, '%s\n', Params.data_directory);
    fprintf(fid, '%s\n', Params.params_str);
    fprintf(fid, '%s\n', sprintf('%d ', n_trials));
    fprintf(fid, '\n');
    for i_type = 1:numel(type_list)
        fprintf(fid, '%s\n', sprintf('%g ', ppc(i_type, :)));
    end

end
