% RunFaceRecog.m
% Main function that reads in files, applies pre-processing, calls face
% recognition algorithm, and then plots resulting accuracy.
%
% Assumptions:
%   - you have a set of pre-processed face images (cropped, normalized, ..)
%   - multiple photographs of each individual to be recognized are
%   available
%   - (maybe?) non-face training data is available?
%
% Alan Brooks and Li Gao
% Northwestern University - ECE 432 Advanced Computer Vision
 
% Mod history:
% 2004:
%    5-May  initial version
%   11-May  added UMIST database capability
%   12-May  prettied the plots, lowered memory requirements
%   13-May  added subfigure, speed up
%   14-May  faster again (caching training data & svd)
%   18-May  began FisherFace technique (Li Gao)
%   19-May  tweaked FisherFace, more work (AB)
%   20-May  tried smaller pictures (AB)
%   25-May  more tries on FisherFace (AB)
%   26-May  added class to ALAN database, smaller UMIST, fix Fisher (AB)
%   31-May  added maxFishEigFace, figure saving, produced final plots (AB)
%    1-Jun  tried unpreproc ALAN database (AB)

% =========================================================================
function RunFaceRecog()
% RUNFACERECOG  Train and evaluate an eigenface / fisherface recognizer.
%
% Takes no inputs and returns nothing; every setting is one of the
% "User Parameters" below.  Processing steps:
%   1. pick the database location from dbType
%   2. read the training images (or load them from storedInputs.mat in the
%      training directory) and group them into classes by the first two
%      characters of each filename
%   3. read the images to be recognized
%   4. run the selected algorithm ('eigen' or 'fisher') and classify
%   5. plot each recog image next to its closest training images
%
% Side effects: writes storedInputs.mat into the training directory,
% optionally saves figures as PNG files, saves all local variables with
% "save workspace_dump", and prints 'done'.

% ------ User Parameters ------
% database type
dbType = 'UMIST'; % 'ALAN' or 'UMIST' supported so far

% algorithm & feature extraction parameters
frAlgorithm = 'eigen';     % 'eigen' 'fisher' ...?
nEigFace = 21; %21;         % pick # of principal components (PCA)
maxFishEigFace = 41; %Inf;  % pick max number of PCA components to use
                            % in first part of Fisherface algorithm
threshFace = 15; %10; %5;   % thresholds
threshClass = 8; %5; %3;

% caching parameters
train.forceReread = 0;      % 0 = use cached inputs if found, 1 = don't
train.forceNewSvd = 0;      % 0 = use cached SVD, 1 = don't (PCA)
train.forceNewScat = 0;     % 0 = use cached scatters, 1 = don't (LDA)

% plotting parameters
plots.intermediateOn = 1;
plots.finalOn = 1;
plots.savePlotsOn = 1;

switch dbType
    case 'UMIST'
        % data location & type
        basePath = ['C:\Documents and Settings\alanb\My Documents\NWU\' ...
                'ECE 432\final project\face_databases\UMIST'];
        %basePath = ['D:\My Documents\NWU\ECE 432 Comp Vision\' ...
        %        'face_databases\UMIST'];
        %basePath = ['C:\Documents and Settings\lgao\My Documents\' ...
        %        'ECE432\Projects\Facial Recognition\'];
        %db = 'croppedConsolPng';
        %db = 'croppedConsolPng46x56';
        %db = 'croppedConsolPng23x28';
        db = 'croppedConsolPng23x28_subset'; % 3 train, 1 recog image
        imType = 'png';
    case 'ALAN'
        % data location & type
        if isunix
            basePath = ['/Users/alanb/Documents/Riting & Labs/' ...
                    'Northwestern/Year 2 Q3 (Spring)/' ...
                    'EE 432 Computer Vision/final project/' ...
                    'face_databases/MUstudentPhotos'];        
        else
            basePath = ['C:\Documents and Settings\alanb\My Documents\' ...
                    'NWU\ECE 432\final project\face_databases\' ...
                    'MUstudentPhotos'];
            %basePath = ['D:\My Documents\NWU\ECE 432 Comp Vision\' ...
            %        'face_databases\millikin'];
        end
        %db = 'thumbsMoreThan2prepro40x60';
        %db = 'thumbsMoreThan2prepro10x15';
        %db = 'tmbwClassOval40x60';
        db = 'tmbwClassOval40x60_2'; % changed which ones in recog
        %db = 'tmbwClassppNo60x80g'; % no pre-processing
        imType = 'png';
    otherwise
        error('%s database type not understood.',dbType)
end

train.path = fullfile(basePath,db,'train');
recog.path = fullfile(basePath,db,'recog');


% ------ Pre-Processing ------
% read all training images
di = dir(fullfile(train.path,['*.' imType]));
if isempty(di)
    fprintf('Err: Couldn''t find any %s files in training path:\n\t%s\n',...
        imType,train.path)
    return
end

% Decide whether the cached inputs can be used.  The cache is rejected
% when the number of image files on disk no longer matches the number of
% images that were cached.
inputTrainFile = fullfile(train.path,'storedInputs.mat');
useCache = false;
if exist(inputTrainFile,'file') && ~train.forceReread
    cached = load(inputTrainFile);
    useCache = (cached.M == length(di));
    if ~useCache
        warning(sprintf(['Cached inputs hold %d images but %d files ' ...
                'were found.\n  -> Re-reading training images.'],...
            cached.M,length(di)))
        clear cached
    end
end

if useCache
    % take only the image data from the cache; path and force* flags set
    % above stay as they are
    train.I = cached.train.I;
    train.classNameTrue = cached.train.classNameTrue;
    train.classTrue = cached.train.classTrue;
    Nx = cached.Nx;
    Ny = cached.Ny;
    M = cached.M;
    mp = cached.mp;
    trainClass = cached.trainClass;
    clear cached
else
    % read training images
    info = imfinfo(fullfile(train.path,di(1).name));
    Nx = info.Height;
    Ny = info.Width;
    M = length(di); % number of training images
    train.I = uint8(zeros(Nx,Ny,M)); % init for speed
    numOfClass = 0; % initialize the number of classes 
    prevIndex = 1;  % index of the first image of the class being built
    for i = 1:M % for each training image
        % Read and store class
        [I,mp] = imread(fullfile(train.path,di(i).name),imType);
        train.I(:,:,i) = I;     % image
        %train.I(:,:,i) = adapthisteq(I,'clipLimit',0.001);
        %train.I(:,:,i) = imadjust(I);
        %train.I(:,:,i) = histeq(I);
        train.classNameTrue{i} = di(i).name(1:end-4); % specific for my names
        train.classTrue(i) = i;
        
        % Group & calculate statistics for each class
        %   looks for new classes by 1st 2 letters in filename
        % classEnds lists, in order, one-past-the-end indices of every
        % class that is complete after reading image i.  Both conditions
        % can hold at once (the last image starts its own class), in
        % which case two classes are closed instead of being merged.
        classEnds = [];
        if i>1 && ~strcmp(di(i-1).name(1:2),di(i).name(1:2)) % class bndry
            classEnds(end+1) = i;
        end
        if i==M % force class bndry for last image
            classEnds(end+1) = i+1;
        end
        for j = classEnds
            numOfClass = numOfClass + 1;
            trainClass.className{numOfClass} = ['class_' di(j-1).name(1:2)];
            trainClass.num(numOfClass) = j-prevIndex;
            X = double(reshape(...
                train.I(:,:,prevIndex:(j-1)),[Nx*Ny j-prevIndex]))./256;
            trainClass.mean{numOfClass} = mean(X,2);
            trainClass.classStartIndex(numOfClass) = prevIndex;
            trainClass.classEndIndex(numOfClass) = j-1;
            prevIndex = j;
            clear X;
        end
    end
    
    % save cached images for next run
    save(inputTrainFile,'train','Nx','Ny','M','mp','trainClass')
end

% read all images to be recognized
di = dir(fullfile(recog.path,['*.' imType]));
if isempty(di)
    fprintf('Err: Couldn''t find any %s files in recog path:\n\t%s\n',...
        imType,recog.path)
    return
end
info = imfinfo(fullfile(recog.path,di(1).name));
if (info.Height ~= Nx) || (info.Width ~= Ny)
    error('recog images must be same size as training images')
end
M2 = length(di); % number of recog images
recog.I = uint8(zeros(Nx,Ny,M2)); % init for speed
for i = 1:M2
    [I,mp] = imread(fullfile(recog.path,di(i).name),imType);
    recog.I(:,:,i) = I;     % image
    %recog.I(:,:,i) = adapthisteq(I,'clipLimit',0.001);
    %recog.I(:,:,i) = imadjust(I);
    %recog.I(:,:,i) = histeq(I);
    recog.classNameTrue{i} = di(i).name(1:end-4); % specific for my naming
    % classTrue{i} = indices of all training images whose name starts
    % with the identifying part of this recog image's name
    switch dbType
        case 'ALAN'
            recog.classTrue{i} = ... % matches by filename
                strmatch(recog.classNameTrue{i}(1:end-1),train.classNameTrue); 
        case 'UMIST'
            recog.classTrue{i} = ... % matches by filename
                strmatch(recog.classNameTrue{i}(1:2),train.classNameTrue); 
    end
    recog.classEst(i) = NaN; % NaN = not (yet) assigned to a training image
end

% plot input images
if plots.intermediateOn
    figure,montage(reshape(train.I,[Nx Ny 1 M]),mp),title('training images')
    if plots.savePlotsOn, saveas(gcf,'training_images','png'), end
    %could use >>unix('/sw/bin/montage -h')
    figure,montage(reshape(recog.I,[Nx Ny 1 M2]),mp),title('recog images')
    if plots.savePlotsOn, saveas(gcf,'recog_images','png'), end
end

% ------ Call Face Recog Alorithm ------
switch frAlgorithm
    case 'eigen'
        % -- Compute EigenFaces using "training" faces --
        % Outputs:
        %   P       => eigenfaces (eigenvectors)
        %   train.wt=> weights for each training face
        %   train.mean => average image
        [P,train] = computeEigenfaces(train,nEigFace,plots);
        
        % -- Classify the "recognition" faces --
        % Outputs:
        %   recog.classNameEst  => estimated class name
        %   recog.classEst      => estimated class number
        %   recog.isCorrectClass=> 1/0 for corr/incorr classification
        [recog] = classifyFaces(recog,train,P,threshFace,threshClass,plots);
        
    case 'fisher'
        % -- Compute EigenFaces (PCA) using "training" faces --
        %[P1,train1] = computeEigenfaces(train,nEigFace,plots);
        MpFish = M-length(trainClass.num);
        MpFish = min(MpFish,maxFishEigFace); % acb - impose face space limit
        if MpFish < 1
            % happens when every class holds a single training image
            error(['Fisherfaces need more training images than ' ...
                    'classes (%d images, %d classes).'],...
                M,length(trainClass.num))
        end
        [P1,train1] = computeEigenfaces(train,MpFish,plots);
        
        % -- Compute FisherFaces using "training" faces and "EigenFaces" --
        % Outputs:
        %   P2       => fisherfaces (eigenvectors)
        %   train2.wt => weights for each training face
        [P2,train2] = computeFisherfaces(train,trainClass,plots,P1);
        
        % -- Classify the "recognition" faces --
        [recog] = classifyFaces(recog,train2,P2,threshFace,threshClass,plots);
        
    otherwise
        error(['I don''t know face recognition algorithm "' ...
                frAlgorithm '"'])
end

% ------ Plot/Display Results ------

% Plot nice recog -> top matches figure with subplots!
if plots.finalOn
    %sc = [.5 .8]; % portion of screen [width height]  % [Left Bot Wid Hei]
    %figure('Units','Normalized','Position',[(1-sc)/2 sc],'Color',1*[1 1 1])
    nTopMatch = 3;  % training images shown per recog image
    nCh = 6;        % max characters of a name printed over an image
    p = 1; % current subplot
    f = 0; % current subfigure
    nRowsMax = 7; % max input faces per subfigure
    nSubF = ceil(M2/nRowsMax);
    for i = 1:M2 % each recog image
        % create new subfigure when necessary
        if rem(i-1,nRowsMax) == 0
            f = f+1;
            hf(f) = subfigure(1,nSubF,f);
            set(hf(f),'Color',1*[1 1 1])
            p = 1;
        end
        
        % plot input image to be recognized
        subplot(nRowsMax,nTopMatch+2,p)
        imshow(recog.I(:,:,i),mp)
        if p==1, title('inputs'); end
        nc = min(nCh,length(recog.classNameTrue{i}));
        text(0,0,recog.classNameTrue{i}(1:nc),'FontSize',7,'Color','b')
        p = p+2; % leave one empty column between input and matches
        
        % plot closest matching images from training set
        [val,ndx] = sort(recog.euDis(:,i));
        for j = 1:nTopMatch
            subplot(nRowsMax,nTopMatch+2,p)
            imshow(train.I(:,:,ndx(j)),mp)
            if p==3 && j==1
                title(sprintf('top %d matches ...',nTopMatch));
            end
            % blue label when the first two name characters agree with
            % the recog image, red otherwise
            if strncmp(train.classNameTrue{ndx(j)},recog.classNameTrue{i},2)
                c = 'b'; %[0 0 1];
            else
                c = 'r';
            end
            nc = min(nCh,length(train.classNameTrue{ndx(j)}));
            text(0,0,train.classNameTrue{ndx(j)}(1:nc),...
                'FontSize',7,'Color',c)
            p = p+1;
        end
    end
    
    if plots.savePlotsOn
        for i=1:f
            saveas(hf(i),sprintf('top3_matches_%d',i),'png')
        end
    end
end

save workspace_dump
disp('done')


% =======================================================================
function [P,train] = computeEigenfaces(train,Mp,plots)
% COMPUTEEIGENFACES  Principal components of the training faces.
%
% Inputs:
%   train.I           => Nx-by-Ny-by-M stack of training images
%   train.path        => directory holding the cache file storedSvd.mat
%   train.forceNewSvd => nonzero to ignore an existing cache file
%   Mp                => requested number of principal components
%   plots.intermediateOn => nonzero to draw the diagnostic figures
%   plots.savePlotsOn    => nonzero to also save those figures as PNG
% Outputs:
%   P       => eigenfaces, one per column
%   train.wt=> weights for each training face
%   train.mean => average image
%
% Mp is reduced (with a warning) when more components are requested than
% there are training images or than the decomposition provides.
% The cache is rejected only when its dimensions do not fit train.I; a
% changed image set of the same size is not detected (use forceNewSvd).

% Find the image size, [Nx Ny], and the number of training images, M
[Nx,Ny,M] = size(train.I);

if Mp>M
    warning(sprintf(...
        ['Can''t use more principal comp than input imgaes!\n'...
            '  -> Using %d components.'],M))
    Mp = M;
end

% Use cached version of slow eigenvalue computations if found
svdFile = fullfile(train.path,'storedSvd.mat');
haveSvd = false;
if exist(svdFile,'file') && ~train.forceNewSvd
    cached = load(svdFile);
    cacheFits = isfield(cached,'eigVecs') && isfield(cached,'eigVals') ...
        && isfield(cached,'me') && isfield(cached,'A');
    if cacheFits
        cacheFits = isequal(size(cached.A),[Nx*Ny M]);
    end
    if cacheFits
        eigVecs = cached.eigVecs;
        eigVals = cached.eigVals;
        me = cached.me;
        A = cached.A;
        haveSvd = true;
    else
        warning(sprintf(['Cached SVD does not fit the training ' ...
                'images.\n  -> Recomputing.']))
    end
    clear cached
end

if ~haveSvd
    % Compute EigenFaces using "training" faces
    % learn principal components from {x1, x2, ..., xn}
    % (1) find mean face, me, and
    %     differences from means, A
    X = double(reshape(train.I,[Nx*Ny M]))./256; % 1 column per face
    me = mean(X,2);
    A = X - repmat(me,[1 M]);
    clear X
    
    % (2) covariance matrix, S = A*A' (skip computing by using svd)
    % (3) partial eigenvalue decomposition S = U'*E*U
    [U,E,V] = svd(A,0); % singular val decomp much faster
    
    % (4) get sorted eigenvalues (diag of E) and eigenvectors (U)
    eigVals = diag(E);
    eigVecs = U;
    clear U V
    
    % store cache for future runs
    save(svdFile,'eigVecs','eigVals','me','A')
end

% The decomposition yields at most min(Nx*Ny,M) eigenvectors, so a request
% that passed the Mp>M test can still be too large for small images.
nAvail = size(eigVecs,2);
if Mp>nAvail
    warning(sprintf(...
        ['Only %d eigenfaces are available.\n'...
            '  -> Using %d components.'],nAvail,nAvail))
    Mp = nAvail;
end
    
% (5) P' = [u1' u2' ... um'] % pick Mp principal components
P = eigVecs(:,1:Mp);        % ouput eigenfaces

train.mean = me;

% Project each face in training set onto eigenfaces, storing weight
train.wt = P'*A;

if plots.intermediateOn % >> help truesize
    % Plot average face, eigenvals
    figure,imshow(reshape(train.mean,[Nx Ny])),title('avg face')
    if plots.savePlotsOn, saveas(gcf,'avg_face','png'), end
    figure,plot([1:length(eigVals)], eigVals,'x-'),title('\lambda strength')
    if plots.savePlotsOn, saveas(gcf,'eigval_strength','png'), end

    % Plot eigenfaces, each stretched to the range [0,1]
    I = reshape(P,[Nx Ny 1 Mp]);
    for i = 1:Mp % scale for plot
        mx = max(P(:,i));
        mi = min(P(:,i));
        I(:,:,1,i) = (I(:,:,1,i)-mi)./(mx-mi);
    end 
    figure,montage(I),title('eigenfaces'); % eigenfaces
    if plots.savePlotsOn, saveas(gcf,'eigenfaces','png'), end

    % Reconstruct projected faces (only needed for this figure)
    R = P*train.wt + repmat(train.mean,[1 M]);
    I = reshape(R,[Nx Ny 1 M]);
    for i = 1:M % scale for plot
        mx = max(R(:,i));
        mi = min(R(:,i));
        I(:,:,1,i) = (I(:,:,1,i)-mi)./(mx-mi);
    end 
    figure,montage(I),title('reconst training images')
    if plots.savePlotsOn, saveas(gcf,'reconst_training_images','png'), end
end


% =======================================================================
function [P,train] = computeFisherfaces(train,trainClass,plots,P1)
% COMPUTEFISHERFACES  Fisher discriminant directions of the training faces.
%
% Inputs:
%   train.I            => Nx-by-Ny-by-M stack of training images
%   train.path         => directory holding the cache file storedScat.mat
%   train.forceNewScat => nonzero to ignore an existing cache file
%   trainClass.num     => number of images per class (its length is the
%                         number of classes)
%   trainClass.mean    => cell array, mean image vector of each class
%   trainClass.classStartIndex / classEndIndex
%                      => first / last image index of each class
%   plots.intermediateOn => nonzero to draw the diagnostic figures
%   plots.savePlotsOn    => nonzero to also save those figures as PNG
%   P1                 => PCA basis (one eigenface per column) used to
%                         reduce the scatter matrices before solving
% Outputs:
%   P       => fisherfaces
%   train.wt=> weights for each training face
%   train.mean => average image
% Summary:
%   The idea behind this approach is the maximize the ratio of
% between-class scatter to that of within-class scatter.
%
% At most numOfClass-1 fisherfaces are kept, and never more than the
% number of columns in P1.

% Find the image size, [Nx Ny], and the number of training images, M
[Nx,Ny,M] = size(train.I);

% Find mean face, me, and
%     differences from means, A
X = double(reshape(train.I,[Nx*Ny M]))./256; % 1 column per face
me = mean(X,2);
A = X - repmat(me,[1 M]);

% Use cached version of slow scattermatrix computations if found
scatFile = fullfile(train.path,'storedScat.mat');
numOfClass = length(trainClass.num);
haveScat = false;
if exist(scatFile,'file') && ~train.forceNewScat
    cached = load(scatFile);
    cacheFits = isfield(cached,'Sb') && isfield(cached,'Sw');
    if cacheFits
        cacheFits = isequal(size(cached.Sb),[Nx*Ny Nx*Ny]) && ...
            isequal(size(cached.Sw),[Nx*Ny Nx*Ny]);
    end
    if cacheFits
        Sb = cached.Sb;
        Sw = cached.Sw;
        haveScat = true;
    else
        warning(sprintf(['Cached scatter matrices do not fit the ' ...
                'training images.\n  -> Recomputing.']))
    end
    clear cached
end

if ~haveScat
    % Calculate the between-class scatter matrix, Sb
    %       and the within-class scatter matrix,  Sw
    % NOTE(review): Sb sums one outer product per class without weighting
    % by trainClass.num(i); the textbook definition weights each term by
    % the class size -- confirm whether the unweighted form is intended.
    Sb = zeros(Nx*Ny);
    for i = 1:numOfClass
        dev = trainClass.mean{i} - me;
        Sb = Sb + dev * dev';
    end
    
    Sw = zeros(Nx*Ny);
    for i = 1:numOfClass
        for j = (trainClass.classStartIndex(i)):(trainClass.classEndIndex(i))
            dev = X(:,j) - trainClass.mean{i};
            Sw = Sw + dev * dev';
        end
    end
    clear dev
    
    % store cache for future runs
    save(scatFile,'Sb','Sw')
end

% Use PCA to project into subspace
Sbb = P1.'*Sb*P1; 
Sww = P1.'*Sw*P1;
clear Sb Sw % save memory

% Current decomposition method: (from class)
% Find generalized eigenvalues & eigenvectors using eig(A,B)
[V,D] = eig(Sbb,Sww);

% Another possible method: (from class)
% 1. Note that we only care about the direction of Sw*W on m1-m2
% 2. Guess w = Sw^-1 * (m1-m2), then iterate ???

% One more possible method: (from Duda book)
% 1. Find the eigenvalues as the roots of the characteristic
%    polynomial:  
%       det(Sb - lambda(i)*Sw) = 0
% 2. Then solve for the eigenvectors w(i) directly using:
%       (Sb - lambda(i)*Sw)*w(i) = 0

% Extract eigenvalues and sort largest to smallest
Ds = diag(D);
[tmp,ndx] = sort(abs(Ds));
ndx = flipud(ndx);
% get sorted eigenvalues (diag of E) and 
% eigenvectors (project V back into full space using P1)
eigVals = Ds(ndx);
eigVecs = P1*V(:,ndx);
clear D Ds V % save a little memory

% Only keep numOfClass-1 weights (fewer when P1 supplies fewer columns,
% which would otherwise index past the end of eigVecs), and
% Scale to make eigenvectors normalized => sum(P(:,1).^2)==1
Mp = min(numOfClass-1,size(eigVecs,2));
P = eigVecs(:,1:Mp);    % ouput fisherfaces
P = P./repmat(sum(P.^2).^0.5,Nx*Ny,1); % normalize

train.mean = me;

% Project each face in training set onto fisherfaces, storing weight
train.wt = P.'*A;

if plots.intermediateOn % >> help truesize
    % Plot eigenvals
    figure,plot([1:length(eigVals)], eigVals,'x-'),title('\lambda strength')
    if plots.savePlotsOn, saveas(gcf,'fish_eigval_strength','png'), end

    % Plot fisherfaces, each stretched to the range [0,1]
    I = reshape(P,[Nx Ny 1 Mp]);
    for i = 1:Mp % scale for plot
        mx = max(P(:,i));
        mi = min(P(:,i));
        I(:,:,1,i) = (I(:,:,1,i)-mi)./(mx-mi);
    end 
    figure,montage(I),title('fisherfaces') % fisherfaces
    if plots.savePlotsOn, saveas(gcf,'fisherfaces','png'), end

    % Reconstruct projected faces (only needed for this figure)
    R = P*train.wt + repmat(train.mean,[1 M]);
    I = reshape(R,[Nx Ny 1 M]);
    for i = 1:M % scale for plot
        mx = max(R(:,i));
        mi = min(R(:,i));
        I(:,:,1,i) = (I(:,:,1,i)-mi)./(mx-mi);
    end 
    figure,montage(I),title('reconst training images')
    if plots.savePlotsOn, saveas(gcf,'fish_reconst_trainign_images','png')
    end
    
    %figure
    %imagesc(I(:,:,1,1))
    %set(gca,'Units','pixels','Position',[100 100 3*[Ny Nx]])
    %colormap gray
end


% =======================================================================
function [recog] = classifyFaces(recog,train,P,threshFace,threshClass,plots)
% CLASSIFYFACES  Nearest-neighbour classification in face space.
%
% Inputs:
%   recog.I             => Nx-by-Ny-by-M2 stack of images to classify
%   recog.classNameTrue => cell array, true name of each recog image
%   recog.classTrue     => cell array, indices of the training images
%                          that count as a correct match for each image
%   train.I             => Nx-by-Ny-by-M stack of training images
%   train.mean          => average training image (column vector)
%   train.wt            => weights of each training face (one column each)
%   train.classNameTrue => cell array, name of each training image
%   P                   => basis vectors (eigen/fisherfaces), one per column
%   threshFace          => distance above which an image is 'NonFace'
%   threshClass         => distance above which a face is 'UnknownFace'
%   plots.intermediateOn / plots.savePlotsOn => figure display / saving
% Outputs:
%   recog.wt            => weights of each recog face
%   recog.euDis         => M-by-M2 distances, training face x recog face
%   recog.classNameEst  => estimated class name
%   recog.classEst      => estimated class number (index of the nearest
%                          training image), NaN when no match is declared
%   recog.isCorrectClass=> 1 or 0 for correct/incorrect classification

% Find the image size, [Nx Ny], and the number of training images, M
[NxT,NyT,M] = size(train.I);
% Find the image size, [Nx Ny], and the number of recog images, M2
[Nx,Ny,M2] = size(recog.I);
if (NxT ~= Nx) || (NyT ~= Ny)
    error('recog images must be same size as training images')
end

% Init some values
X2 = double(reshape(recog.I,[Nx*Ny M2]))./256; % 1 column per face
A2 = X2 - repmat(train.mean,[1 M2]);
% Project each face in recog set onto eigenfaces, storing weight
recog.wt = P'*A2;

% Plot reconstructed images
if plots.intermediateOn
    % Reconstruct projected faces (only needed for this figure)
    R = P*recog.wt + repmat(train.mean,[1 M2]);
    I = reshape(R,[Nx Ny 1 M2]);
    for i = 1:M2 % scale for plot
        mx = max(R(:,i));
        mi = min(R(:,i));
        I(:,:,1,i) = (I(:,:,1,i)-mi)./(mx-mi);
    end 
    figure,montage(I),title('reconst recog images')
    if plots.savePlotsOn, saveas(gcf,'reconst_recog_images','png'), end
end

% Find euclidian distance from each recog face to each known face
recog.euDis = zeros(M,M2);
for i = 1:M2 % each recog face
    d = train.wt - repmat(recog.wt(:,i),[1 M]); % one column per known face
    recog.euDis(:,i) = sqrt(sum(d.^2,1)).';
end

% Classifiy to Nearest-Neighbor with two thresholds:
%   threshFace  => largest distance to the nearest training face at which
%                  the input is still considered a face
%   threshClass => largest distance to the nearest training face at which
%                  a match with that face is declared
% The dimension argument keeps this a per-column minimum even when there
% is a single training image (euDis is then a row vector).
[minDis,ndx] = min(recog.euDis,[],1);
    %recog.classNameTrue % truth
    %train.classNameTrue(ndx) % estimated classification

% Start from "no match" for every image so the result does not depend on
% what the caller stored in recog.classEst beforehand.
recog.classEst = repmat(NaN,[1 M2]);

fprintf('Results of face recognition:\n')
for i = 1:M2 % each recog face
    if minDis(i) > threshFace
        recog.classNameEst{i} = 'NonFace';
    elseif minDis(i) > threshClass
        recog.classNameEst{i} = 'UnknownFace';
    else
        recog.classNameEst{i} = train.classNameTrue{ndx(i)};
        recog.classEst(i) = ndx(i);
    end
    
    % NaN never compares equal, so unmatched images count as incorrect
    recog.isCorrectClass(i) = ...
        any(recog.classEst(i) == recog.classTrue{i});
    fprintf('\trecognized %s as %s\n',...
        recog.classNameTrue{i},recog.classNameEst{i});
end

numCorrect = sum(recog.isCorrectClass);
fprintf('\t%d of %d (%d%%) faces correctly classified\n',...
    numCorrect,M2,round(numCorrect/M2*100));