function spikes = limit_spikes(spikes, trial_type, limit, seed)
% LIMIT_SPIKES Randomly remove spikes so each trial type keeps at most LIMIT.
%   SPIKES = LIMIT_SPIKES(SPIKES, TRIAL_TYPE, LIMIT, SEED) returns SPIKES
%   with some entries that equal 1 set to 0. For each distinct value of
%   TRIAL_TYPE, if more than LIMIT entries equal to 1 fall under that
%   type, a uniformly random subset of the excess is zeroed so exactly
%   LIMIT remain. Types already at or under LIMIT are left untouched.
%
%   SPIKES      numeric or logical array. Only entries exactly equal to 1
%               are counted and removed.
%   TRIAL_TYPE  array of labels compared with == against each unique
%               value. It must be bsxfun-compatible with SPIKES (e.g. a
%               column vector with one label per row of SPIKES —
%               presumably one row per trial; confirm with callers).
%   LIMIT       non-negative scalar: maximum number of spikes kept per
%               trial type.
%   SEED        optional seed passed to rng(SEED, 'twister'). If omitted
%               or empty, the current global random stream is used.
%
%   Note: when SEED is given, this reseeds MATLAB's global random number
%   generator, which affects any later random calls by the caller.

    if ~isscalar(limit) || limit < 0
        error('limit_spikes:badLimit', 'LIMIT must be a non-negative scalar.');
    end

    % Random seed (optional, for reproducible thinning).
    if nargin >= 4 && ~isempty(seed)
        rng(seed, 'twister');
    end

    % Loop over every trial type
    type_list = unique(trial_type);
    for i_type = 1:numel(type_list)

        % Linear indices of all spikes (entries == 1) belonging to this
        % trial type; bsxfun broadcasts the type mask across SPIKES.
        is_type = trial_type == type_list(i_type);
        spike_index = find(bsxfun(@and, spikes == 1, is_type));
        n_spikes = numel(spike_index);

        % If any go past the limit (per type), remove the excess randomly:
        % pick (n_spikes - limit) distinct positions without replacement.
        if n_spikes > limit
            remove = randperm(n_spikes, n_spikes - limit);
            spikes(spike_index(remove)) = 0;
        end
    end
end
