function [ network, inputList, outputList ] = setIoByModule( network, nInput, nOutput )
% SETIOBYMODULE set input and output neurons in separated modules
%
%   [network, inputList, outputList] = setIoByModule(network, nInput, nOutput)
%   First, this function detects modules with MODULARITY_DIR and randomly
%   splits them into two non-empty groups: input modules and output
%   modules. Then it randomly selects nInput neuron labels from the input
%   modules and nOutput neuron labels from the output modules, so no input
%   neuron shares a module with an output neuron. The split is redrawn
%   until both groups contain enough neurons.
%
%   The returned network has its nodes reordered so that the inputs come
%   first, followed by the outputs. It also gains the fields moduleId,
%   inputList, outputList, inputNumber and outputNumber.
%
%   An error is raised when fewer than two modules are found, or when no
%   split of the modules can supply nInput and nOutput neurons.


%   ---------
%   Yen-Nan Lin, NTHU, 2010-2014, Matlab 2012a

moduleId = modularity_dir(full(network.matrix))';
network.moduleId = moduleId;
uniModuleId = unique(moduleId);
nModule = numel(uniModuleId);
if nModule < 2
    % randi(nModule - 1) below needs at least one module on each side.
    error('setIoByModule:tooFewModules', ...
        'Need at least 2 modules to separate inputs from outputs, found %d.', ...
        nModule);
end

% Labels of the neurons in each module, indexed by position in uniModuleId
% (not by the module id itself), so the ids need not be 1..nModule.
neuronIdEachModule = cell(1, nModule);
moduleSize = zeros(1, nModule);
for iModule = 1:nModule
    tmpIx = find(moduleId == uniModuleId(iModule));
    neuronIdEachModule{iModule} = network.label(tmpIx);
    moduleSize(iModule) = numel(tmpIx);
end

% Without this check the rejection loop below would never terminate when
% no split can provide enough neurons on both sides.
if ~hasFeasibleSplit(moduleSize, nInput, nOutput)
    error('setIoByModule:noFeasibleSplit', ...
        ['No split of the %d modules gives at least %d input and %d ' ...
        'output neurons.'], nModule, nInput, nOutput);
end

moduleOrder = 1:nModule;
isValid = false;
while ~isValid
    % Number of input modules is 1..nModule-1, so both groups are non-empty.
    nModuleInput = randi(nModule - 1);

    moduleOrder = moduleOrder(randperm(nModule));
    inputModuleIx = moduleOrder(1:nModuleInput);
    outputModuleIx = moduleOrder((nModuleInput + 1):end);

    inputModuleNeuronId = [neuronIdEachModule{inputModuleIx}];
    outputModuleNeuronId = [neuronIdEachModule{outputModuleIx}];
    if (numel(inputModuleNeuronId) < nInput) || ...
            (numel(outputModuleNeuronId) < nOutput)
        continue;   % not enough neurons on one side: redraw the split
    end

    % Shuffle, then take the first nInput / nOutput labels of each group.
    inputModuleNeuronId = ...
        inputModuleNeuronId(randperm(numel(inputModuleNeuronId)));
    outputModuleNeuronId = ...
        outputModuleNeuronId(randperm(numel(outputModuleNeuronId)));
    inputList = inputModuleNeuronId(1:nInput);
    outputList = outputModuleNeuronId(1:nOutput);
    isValid = true;
end
network = sortByIO(network, [inputList, outputList]);
network.inputList = inputList;
network.outputList = outputList;
network.inputNumber = nInput;
network.outputNumber = nOutput;

function [ tf ] = hasFeasibleSplit( moduleSize, nInput, nOutput )
% HASFEASIBLESPLIT true if some non-empty proper subset of the modules holds
% at least nInput neurons while the remaining modules hold at least nOutput.
%   This is a subset sum over the module sizes: reachable(s + 1) is true
%   when some subset of modules contains exactly s neurons. Every module is
%   non-empty, so any sum strictly between 0 and the total comes from a
%   non-empty proper subset, i.e. a split the main loop can draw.
total = sum(moduleSize);
reachable = false(1, total + 1);
reachable(1) = true;   % the empty subset has sum 0
for iModule = 1:numel(moduleSize)
    shift = moduleSize(iModule);
    previous = reachable;   % 0/1 subset sum: each module used at most once
    reachable((shift + 1):end) = ...
        reachable((shift + 1):end) | previous(1:(end - shift));
end
lowSum = max(nInput, 1);            % input group must be non-empty
highSum = total - max(nOutput, 1);  % output group must be non-empty
tf = (lowSum <= highSum) && any(reachable((lowSum:highSum) + 1));

function [ network ] = sortByIO( network, ioList )
% SORTBYIO move the nodes listed in ioList to the front of the network.
%   For ix = 1..length(ioList), the node whose label equals ioList(ix) is
%   swapped into position ix. The same swap is applied to the rows and
%   columns of network.matrix and to network.label and network.moduleId,
%   so all four stay aligned.
%   ioList must be numeric; char labels are not supported yet. An error is
%   raised when a requested label is not present in network.label.
if ischar(ioList)
    error('Still cannot access string label');
elseif ~isnumeric(ioList)
    error('sortByIO:badLabelType', ...
        'ioList must be numeric, got %s.', class(ioList));
end
for ix = 1:length(ioList)
    wantNode = ioList(ix);
    % Look up the node's current position; earlier swaps may have moved it.
    wantNodeIndex = find(network.label == wantNode, 1);
    if isempty(wantNodeIndex)
        % An empty index would make every swap below a silent no-op.
        error('sortByIO:labelNotFound', ...
            'Label %g is not in network.label.', wantNode);
    end
    network.matrix = swapCol(network.matrix, ix, wantNodeIndex);
    network.matrix = swapRow(network.matrix, ix, wantNodeIndex);
    network.label = swapCol(network.label, ix, wantNodeIndex);
    network.moduleId = swapCol(network.moduleId, ix, wantNodeIndex);
end

function [ matrix ] = swapCol( matrix, col1, col2 )
% SWAPCOL return matrix with columns col1 and col2 exchanged.
%   Works on row vectors as well, where each "column" is a single entry.
targetIx = [col1, col2];
sourceIx = [col2, col1];
matrix(:, targetIx) = matrix(:, sourceIx);

function [ matrix ] = swapRow( matrix, row1, row2 )
% SWAPROW return matrix with rows row1 and row2 exchanged.
%   Row counterpart of SWAPCOL; the arguments are row indices.
matrix([row1, row2], :) = matrix([row2, row1], :);