function [lfps, fs, trial_type, trial_index_extracted, class] = load_lfps(data_directory, command_line, trial_file, time_interval, alpha_plex, electrode, pass_band, remove_mean)
% LOAD LFPS: Load lfp data for each trial.
% class output added by JK 2020-12-25
%
%   Input:
%       data_directory = the data directory where all the data is
%       command_line   = the grab call. It is re-run with ' -R5' inserted
%                        after the program name; column 5 of that output is
%                        read as the reach trial times (presumably sec, they
%                        are multiplied by 1e3 to get msec).
%       trial_file     = the temp file with data for each trial
%                        (trial stamp, start time, trial type, stack, class)
%       time_interval  = trial duration after start time. It is multiplied by
%                        fs (kHz, i.e. samples per msec), so it is presumably
%                        in msec -- TODO confirm (older docs said sec).
%       alpha_plex     = type of lfp data (0 alpha-omega, 1 plexon,
%                        2 open ephys)
%       electrode      = electrode number
%       pass_band      = [] for no band-pass filter, otherwise
%                        [f_pass_low f_pass_high transition_width] in Hz for a
%                        cheby2 band-pass (80 dB stop attenuation, 0.1 dB ripple)
%       remove_mean    = if true, subtract the across-trial mean ("evoked")
%                        response from every trial
%
%   Output:
%       lfps = matrix of aligned and truncated LFP trials (one row per trial)
%       fs   = sampling rate of LFP signal (kHz)
%       trial_type = trial type (the trial file's stack column) of each row
%                    of lfps
%       trial_index_extracted = index into the trial file's trials for each
%                               row of lfps
%       class = class of EVERY trial in trial_file (it is NOT subset); use
%               class(trial_index_extracted) to align it with lfps
%

    % Load reach trial info. The 4th output of read_trial_file (stack) is used
    % as trial_type (JK 2020-12-25). An earlier variant used the 3rd output
    % for non "-mc" grab calls.
    [trial_stamps_reach, start_time, ~, trial_type, class] = read_trial_file(trial_file);

    % Load the lfp signal with its trial stamps/times (msec) and fs (kHz).
    switch alpha_plex
        case {0, 1}
            [trial_stamps_lfp, trial_times_lfp, lfp, fs] =   ...
                    load_lfp_data(electrode, data_directory, alpha_plex);
        case 2
            [trial_stamps_lfp, trial_times_lfp, lfp, fs] =   ...
                    load_oe_lfps(data_directory, electrode);
        otherwise
            error('load_lfps: unsupported alpha_plex value %g (expected 0, 1 or 2)', alpha_plex);
    end

    % Notch filter: remove 60 Hz line noise (0.1 Hz bandwidth). fs is in kHz,
    % so fs * 1e3 / 2 is the Nyquist frequency in Hz.
    nyquist_hz = fs * 1e3 / 2;
    [b_notch, a_notch] = iirnotch(60 / nyquist_hz, 0.1 / nyquist_hz);
    lfp = filtfilt(b_notch, a_notch, lfp);

    % Optional band-pass filter (zero phase via filtfilt).
    if ~isempty(pass_band)
        f_pass_hp = pass_band(1);
        f_pass_lp = pass_band(2);
        f_stop_hp = f_pass_hp - pass_band(3);
        f_stop_lp = f_pass_lp + pass_band(3);
        stop_attenuation = 80;
        pass_ripple = 0.1;
        D = fdesign.bandpass(f_stop_hp, f_pass_hp, f_pass_lp, f_stop_lp, ...
                             stop_attenuation, pass_ripple,              ...
                             stop_attenuation, fs * 1e3);
        H = design(D, 'cheby2');
        lfp = filtfilt(H.sosMatrix, H.ScaleValues, lfp);
    end

    % Sanity check for resets in reach trial stamps.
    if any(diff(trial_stamps_reach) < 1)
        error('reach trial trial_stamps are not monotonic')
    end

    % Get reach trial times by re-running the grab call with -R5 inserted
    % right after the program name. Output goes to a unique temp file so
    % nothing in the current directory is clobbered.
    idx_first_space = find(command_line == ' ', 1);
    if isempty(idx_first_space)
        idx_first_space = numel(command_line) + 1;
    end
    grab_call = [command_line(1:(idx_first_space - 1)), ' -R5', ...
                 command_line(idx_first_space:end)];
    grab_output_file = [tempname, '.txt'];
    status = system([grab_call, ' > "', grab_output_file, '"']);
    fid = fopen(grab_output_file, 'r');
    if fid < 0
        error('load_lfps: could not read output of "%s" (exit status %d)', grab_call, status);
    end
    grab_data = textscan(fid, '%s %s %s %s %s %s %s %s %s %s %s');
    fclose(fid);
    delete(grab_output_file);
    trial_times_reach = str2num(char(grab_data{5})) * 1e3;  % Convert to msec

    % Sanity check: Make sure you have as many reach trial times as stamps
    if numel(trial_times_reach) ~= numel(trial_stamps_reach)
        error('mismatch in REACH trial stamps and times')
    end

    % LFP trial stamps may reset. Split them into runs of increasing stamps
    % (column vectors, regardless of the stamps' orientation) and find the run
    % whose inter-trial intervals (ITIs) agree with reach's.
    run_starts = [1; find(diff(trial_stamps_lfp(:)) < 1) + 1];
    run_ends = [run_starts(2:end) - 1; numel(trial_stamps_lfp)];
    n_runs = numel(run_starts);
    n_iti_match = zeros(1, n_runs);   % per run: ITIs within tolerance (debug)
    n_iti = zeros(1, n_runs);         % per run: ITIs compared (debug)
    n_matches = 0;
    for i_run = 1:n_runs
        in_run = run_starts(i_run):run_ends(i_run);
        these_trial_stamps_lfp = trial_stamps_lfp(in_run);
        these_trial_times_lfp = trial_times_lfp(in_run);

        % Trials present in both reach and this run of lfp stamps.
        [~, idx_reach, idx_lfp] = intersect(trial_stamps_reach, ...
                                            these_trial_stamps_lfp);

        % Compare ITIs (msec); (:) avoids accidental row/column broadcasting.
        iti_lfp = diff(these_trial_times_lfp(idx_lfp));
        iti_reach = diff(trial_times_reach(idx_reach));
        iti_diff = abs(iti_lfp(:) - iti_reach(:));
        n_iti_match(i_run) = sum(iti_diff < 500);
        n_iti(i_run) = numel(iti_diff);

        % If ITIs between reach and lfps all have discrepancies < 500 ms,
        %   count this run as a match and remember which trials it covers.
        if ~isempty(idx_reach) && all(iti_diff < 500)
            n_matches = n_matches + 1;
            trial_times_lfp_extracted = these_trial_times_lfp(idx_lfp);
            trial_index_extracted = idx_reach;
        end
    end

    % Sanity check: Make sure one and only one match was found
    if n_matches < 1
        prct_match_str = sprintf('%d: %d/%d, ', [1:n_runs; n_iti_match; n_iti]);
        prct_match_str = prct_match_str(1:(end-2));
        error(['No matches between lfp and reach trials found!!\n', ...
               ' - %d potential matches\n', ...
               ' - Percent match for each partial: [%s]'], ...
               n_runs, prct_match_str)
    elseif n_matches > 1
        error('Multiple matches between lfp and reach trials found!!')
    end
    trial_type = trial_type(trial_index_extracted);
    start_time_extracted = start_time(trial_index_extracted);

    % Figure out start and stop sample of each trial (times in msec, fs in kHz).
    n_samples_per_interval = round(time_interval * fs);
    trial_start_idx = round(fs * (trial_times_lfp_extracted(:)     ...
                                  + start_time_extracted(:)));
    trial_end_idx = trial_start_idx + n_samples_per_interval - 1;

    % Remove trials that fall outside the extent of the lfp recording.
    is_outside = trial_start_idx < 1 | trial_end_idx > length(lfp);
    if any(is_outside)
        fprintf(['WARNING: %d of %d trials fall outside the LFP signal'    ...
                 ' and will be ignored\n'],                               ...
                sum(is_outside), numel(is_outside));
    end
    keep = ~is_outside;
    trial_type = trial_type(keep);
    trial_index_extracted = trial_index_extracted(keep);
    trial_start_idx = trial_start_idx(keep);

    % Extract and align data (n samples may vary by 1 sample from run to run).
    % Preallocated so that zero surviving trials still returns a (0 x n) matrix.
    n_trials = numel(trial_start_idx);
    lfps = zeros(n_trials, n_samples_per_interval);
    for i_trial = 1:n_trials
        lfps(i_trial, :) = lfp(trial_start_idx(i_trial) + (0:(n_samples_per_interval - 1)));
    end

    % Remove mean ("evoked") response.
    if remove_mean
        lfps = bsxfun(@minus, lfps, mean(lfps, 1));
    end
end


function [LFPTrialStamps,LFPTrialTimes,LFPSignal,LFPSamplingRate] = load_lfp_data(electrode, datadirectory, alpha_plex)
%Loads lfp data, decodes its binary trial stamps and performs some necessary
%checks and corrections.
%
% NOTE: 	For alpha-omega Harry's encoding (alpha_plex=0) function assumes that directory file listing
%		comes out in order of recording time.  This is will only be the case if the naming convention
%		used by Harry is preserved.  Specifically, when alpha_plex=0 only load files for which filenames
%		are the same date and for which the file number has the same number of total digits (3 in Harry's
%		data). E.g. 13-04-01-001.mat, 13-04-01-002.mat work together.  13-04-01-0003.mat will not work
%		with the -001.mat and -002.mat files, but will work with 13-04-01-0004.mat
%
% Function Parameters:
%   electrode     - electrode channel. For alpha_plex=0 the variables
%                   CLFP<electrode> and CLFP<electrode>_KHz are loaded from
%                   every .mat file. For alpha_plex=1 it indexes the list of
%                   non-empty channels among allad{1:16}.
%   datadirectory - directory of the REACH file. For alpha_plex=0 the LFP
%                   directory is [datadirectory(1:16) '/LFPs/' datadirectory(18:end) '/'].
%                   For alpha_plex=1 the text after the last '/' is the LFP
%                   .mat file name.
%   alpha_plex    - alpha-omega (0) or plexon (1) filetypes
%
% Output:
%   LFPTrialStamps  - decoded stamp of each trial (stamps equal to 0 removed)
%   LFPTrialTimes   - time of each trial (msec), aligned with LFPTrialStamps
%   LFPSignal       - raw LFP recording (alpha-omega files are concatenated
%                     horizontally in directory-listing order)
%   LFPSamplingRate - sampling rate of the LFP recording (kHz)
%
% Each trial is encoded by 17 clock-line pulses: the first gives the trial
% time, the remaining 16 are the stamp bits (MSB first). A bit is 1 when an
% encoding-line pulse lies within +/-2 samples of its clock pulse.
%
%%%%%NOTE:  SPURIOUS BIT and DROPPED PULSES corrections in the Alpha-Omega data assume 12.5KHz sampling rate in the encoding line and clock line%%%%%

    % load() of named variables warns for variables absent from a file (those
    % files are detected and skipped below), so silence warnings here, but
    % restore the caller's warning state on exit (also on error).
    saved_warning_state = warning;
    warning('off', 'all');
    restore_warnings = onCleanup(@() warning(saved_warning_state)); %#ok<NASGU>

    if alpha_plex == 0
        lfpdirectory = [datadirectory(1:16) '/LFPs/' datadirectory(18:end) '/'];
    elseif alpha_plex == 1
        % datadirectory is the full path of the LFP .mat file.
        slashes = strfind(datadirectory, '/');
        if isempty(slashes)
            lastslash = 0;   % bare file name: look in the current directory
        else
            lastslash = slashes(end);
        end
        lfpdirectory = datadirectory(1:lastslash);
        lfpfile = datadirectory(lastslash+1:end);
    else
        error('load_lfp_data: unsupported alpha_plex value %g (expected 0 or 1)', alpha_plex);
    end

    if alpha_plex == 0  %Read in Alpha Omega data

        lfpdirlist = dir([lfpdirectory '*.mat']);
        lfp_var = sprintf('CLFP%d', electrode);
        lfp_khz_var = [lfp_var '_KHz'];
        LFPSignal = [];
        any_file_used = false;

        for ii = 1:numel(lfpdirlist)

            %Load up the data from the correct electrode
            lfp_path = [lfpdirectory lfpdirlist(ii).name];
            S = load(lfp_path, lfp_var, lfp_khz_var, 'CDIG1_Up', 'CDIG2_Up', 'CDIG2_KHz', 'CDIG1_Down', 'CDIG2_Down');
            if ~isfield(S, lfp_var) || ~isfield(S, lfp_khz_var)
                error('load_lfp_data: %s / %s not found in %s', lfp_var, lfp_khz_var, lfp_path);
            end
            CLFP = S.(lfp_var);
            CLFP_KHz = S.(lfp_khz_var);

            % If digital pulses don't exist in this file, skip it.
            if ~all(isfield(S, {'CDIG1_Up', 'CDIG1_Down', 'CDIG2_Up', 'CDIG2_Down'}))
                fprintf('No pulses found in one of the 4 digital variables.  Skipping file %s\n', lfpdirlist(ii).name);
                continue;
            end
            any_file_used = true;

            if(isempty(LFPSignal))
                LFPSignal=CLFP;
                CDIG1_Up=S.CDIG1_Up;
                CDIG2_Up=S.CDIG2_Up;
                CDIG2_Down=S.CDIG2_Down;
                CDIG1_Down=S.CDIG1_Down;

            else
                % Shift this file's pulse indices by the length of the signal
                % loaded so far (converted from LFP samples to digital samples).
                LFPSignal=[LFPSignal CLFP];
                S.CDIG1_Up=S.CDIG1_Up+double(int32(lfpsize*CDIG2_KHz/LFPSamplingRate));
                CDIG1_Up=[CDIG1_Up S.CDIG1_Up];
                S.CDIG2_Up=S.CDIG2_Up+double(int32(lfpsize*CDIG2_KHz/LFPSamplingRate));
                CDIG2_Up=[CDIG2_Up S.CDIG2_Up];

                S.CDIG2_Down=S.CDIG2_Down+double(int32(lfpsize*CDIG2_KHz/LFPSamplingRate));
                CDIG2_Down=[CDIG2_Down S.CDIG2_Down];
                S.CDIG1_Down=S.CDIG1_Down+double(int32(lfpsize*CDIG2_KHz/LFPSamplingRate));
                CDIG1_Down=[CDIG1_Down S.CDIG1_Down];

            end

            lfpsize=length(LFPSignal);
            LFPSamplingRate=CLFP_KHz;
            CDIG2_KHz=S.CDIG2_KHz;

            %TRIAL STAMP BINARY CODE ERROR CHECKING- START

            %CORRECT SPURIOUS BITS - START - NOTE: ASSUMES 12.5KHz sampling rate%
            % Drop consecutive duplicate pulse indices.
            CDIG1_Down=[CDIG1_Down(diff(CDIG1_Down)~=0) CDIG1_Down(end)];
            CDIG2_Down=[CDIG2_Down(diff(CDIG2_Down)~=0) CDIG2_Down(end)];
            CDIG1_Up=[CDIG1_Up(diff(CDIG1_Up)~=0) CDIG1_Up(end)];
            CDIG2_Up=[CDIG2_Up(diff(CDIG2_Up)~=0) CDIG2_Up(end)];

            if(size(CDIG1_Down,2) > size(CDIG1_Up,2)) %If we start or stop recording during the time-stamp and thus have different size Up and Down vectors, remove the extra pulse.
                CDIG1_Down=CDIG1_Down(2:end);
            elseif(size(CDIG1_Down,2) < size(CDIG1_Up,2))
                CDIG1_Up=CDIG1_Up(2:end);
            end

            removeidx=find((CDIG1_Down-CDIG1_Up)<3); %Find all pulses that lasted for less than X samples.
            CDIG1_Down(removeidx)=[]; %Remove those pulses from CDIG1_Down
            CDIG1_Up(removeidx)=[]; %Remove those pulses from CDIG1_Up

            %CORRECT SPURIOUS BITS - END%

            %----------------------------%

            %CORRECT DROPPED PULSES - START - NOTE:  ASSUMES 12.5KHz sampling rate%
            %difference of more than groupDelta indicates a new (pseudo-)group
            groupDelta = 15;
            groupJump = [1 diff(CDIG2_Up) > groupDelta];

            %# number the groups
            groupNumber = cumsum(groupJump);

            %# count, for each group, the numbers.
            groupCounts = hist(groupNumber,1:groupNumber(end));

            %# if a group contains fewer than 17 entries, throw it out
            badGroup = find(groupCounts < 17);
            CDIG2_Up(ismember(groupNumber,badGroup)) = [];

            %CORRECT DROPPED PULSES - END%

            %TRIAL STAMP BINARY CODE ERROR CHECKING- END

        end

        if ~any_file_used
            error('load_lfp_data: no LFP file with digital pulses found in %s', lfpdirectory);
        end

        % Encoding line carries the stamp bits; clock line marks bit slots.
        encoding=CDIG1_Up;
        clock=CDIG2_Up;

    else  %alpha_plex == 1: Read in Plexon data

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%%%   NOTE: This chunk of code only reads one (specified by grab) LFP file  %%%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

        if exist([lfpdirectory lfpfile], 'file') ~= 2
            error(['The LFP MAT [' lfpfile '] file could not be found in ' lfpdirectory])
        end
        S=load([lfpdirectory lfpfile],'tsevs','allad','adfreq');
        tsevs=S.tsevs;
        allad=S.allad;

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%%  NOTE:  ERIC'S MAPPING   %%%%%%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        % electrode indexes the channels (of allad{1:16}) that hold data.

        channel=[];
        for i = 1:16
            if(length(allad{i})>1)
                channel=[channel i];
            end
        end
        electrode=channel(electrode);

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%%  NOTE:  ERIC'S MAPPING - END  %%%%%%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

        LFPSamplingRate=S.adfreq/1000;
        LFPSignal=allad{electrode}';

        PLEXDIG_Up=[];  %Create a simulated clock line for the plexon
        i=1;
        while i <= size(tsevs{2},1)
            PLEXDIG_Up=[PLEXDIG_Up tsevs{2}(i)+(0:16)*0.01]; %17 pulses 0.01 seconds appart
            i=find(tsevs{2} > tsevs{2}(i)+16*0.01,1); %find the next trial
        end

        % Event times are scaled by 1e4 and rounded to integer samples.
        encoding=int32(10000*tsevs{2}');
        clock=int32(10000*PLEXDIG_Up);
    end

    %Define range of clock line samples (+/-2) to check for encoding pulses
    rangeclock=[];
    for offset = -2:2
        rangeclock = [rangeclock; clock+offset];
    end

    %Decode LFP Trial Stamps
    Match = sum(ismember(rangeclock, encoding),1) > 0; %Check for encoding-line pulses at each clock-line pulse, including the pulse range
    if mod(numel(Match), 17) ~= 0
        error('load_lfp_data: %d clock pulses found, which is not a multiple of 17 pulses per trial code', numel(Match));
    end

    N=reshape(Match,17,[])'; %Each row is the binary code for a trial
    LFPTrialStamps=bin2dec(num2str(N(:,2:end)));  %convert trial numbers from binary to base 10

    %Calculate LFP Trial Times (first clock pulse of each code)
    if(alpha_plex==0)
        O=reshape(clock,17,[])'; %Use the clock to generate trial times
        LFPTrialTimes=O(:,1)/CDIG2_KHz; %Trial times in milliseconds
    else
        O=reshape(PLEXDIG_Up,17,[])'; %Use the simulated clock-line to generate trial times
        LFPTrialTimes=O(:,1)*1000; %Trial Times in milliseconds.
    end

    % Remove stamps of 0 (e.g. created by stray pulses within a trial) from
    % BOTH stamps and times so the two stay aligned for every file type.
    is_real_trial = LFPTrialStamps > 0;
    LFPTrialStamps = LFPTrialStamps(is_real_trial);
    LFPTrialTimes = LFPTrialTimes(is_real_trial);

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %%%%READ IN LFP DATA END%%%%
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%

end


function [trial_stamps_lfp, trial_times_lfp, lfp, fs] = load_oe_lfps(data_directory, electrode)
% LOAD OPEN EPHYS LFPS
%   Load "converted" lfps from OE binary files.
%
%   Input:
%       data_directory = directory containing an 'lfps' sub-directory
%       electrode      = electrode number; files are read from
%                        <data_directory>/lfps/chNN.lfp/*.lfp (NN zero padded)
%
%   Output:
%       trial_stamps_lfp = [] (trial decoding is not implemented yet)
%       trial_times_lfp  = [] (trial decoding is not implemented yet)
%       lfp              = samples of every file (read as int32), concatenated
%                          into one row vector in dir() listing order
%                          (presumably one file per run -- TODO confirm the
%                          listing order matches recording order)
%       fs               = hard coded to 1 -- placeholder, TODO set the real
%                          sampling rate (kHz)
%

    % Find the lfp files (each file belongs to a run)
    lfp_dir = sprintf('%s/lfps/ch%02d.lfp', data_directory, electrode);
    lfp_files = dir([lfp_dir, '/*.lfp']);
    if isempty(lfp_files)
        error('load_oe_lfps: no .lfp files found in %s', lfp_dir);
    end

    % Load the data and concatenate into a single vector.
    lfp_by_file = cell(1, numel(lfp_files));
    for i_file = 1:numel(lfp_files)
        file_name = [lfp_dir, '/', lfp_files(i_file).name];
        fid = fopen(file_name, 'r');
        if fid < 0
            error('load_oe_lfps: could not open %s', file_name);
        end
        lfp_by_file{i_file} = fread(fid, 'int32').';
        fclose(fid);
    end
    lfp = [lfp_by_file{:}];

    % Hard code fs for now
    fs = 1;
    trial_stamps_lfp = [];
    trial_times_lfp = [];

end
