Skip to content

Commit

Permalink
[ENH] Improve Data fusion (#21)
Browse files Browse the repository at this point in the history
* improve data fusion

* bugfix data is not a cell anymore

* more typo :)

* remove data parameter from data_fusion (not used)

* refactor be_fusion_of_modalities

* update doc

* improve display

* final clean-up

* typo
  • Loading branch information
Edouard2laire authored Sep 28, 2024
1 parent c0c1f70 commit 7c561ae
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 85 deletions.
2 changes: 1 addition & 1 deletion best/cmem/solver/be_cmem_solver.m
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@

%% ===== Fuse modalities ===== %%

obj = be_fusion_of_modalities( [], obj, OPTIONS);
obj = be_fusion_of_modalities(obj, OPTIONS);

%% ===== Solve the MEM ===== %%

Expand Down
8 changes: 4 additions & 4 deletions best/mem/be_launch_mem.m
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
Data = [];

if ~isempty(OPTIONS.automatic.selected_samples)
Data = [Data;obj.data{1}(:,OPTIONS.automatic.selected_samples(1,:))];
Data = obj.data(:,OPTIONS.automatic.selected_samples(1,:));
elseif ~strcmp(OPTIONS.mandatory.pipeline,'wMEM')
Data = obj.data;
end
Expand Down Expand Up @@ -140,11 +140,11 @@
OPTIONS.automatic.wActivation = full(ImageSourceAmp);

elseif strcmp(OPTIONS.mandatory.pipeline, 'wMEM') && ~OPTIONS.wavelet.single_box
ImageGridAmp = zeros(obj.nb_dipoles, size(obj.data{1},2));
wav = zeros( nbSmp, size(obj.data{1},2) );
ImageGridAmp = zeros(obj.nb_dipoles, size(obj.data,2));
wav = zeros( nbSmp, size(obj.data,2) );

for ii = 1 : nbSmp
nbSmpTime = size(obj.data{1},2) ;
nbSmpTime = size(obj.data,2) ;
scale = OPTIONS.automatic.selected_samples(2,ii);
transl = OPTIONS.automatic.selected_samples(3,ii);
wav(ii, nbSmpTime/2^scale + transl ) = 1;
Expand Down
105 changes: 31 additions & 74 deletions best/misc/be_fusion_of_modalities.m
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
function [obj] = be_fusion_of_modalities(data, obj, OPTIONS)
%BE_FUSION_OF_MODALITIES fuses data and leadfields from EEG and MEG for
% multimodal sources estimation using MEM
function [obj] = be_fusion_of_modalities(obj, OPTIONS)
%BE_FUSION_OF_MODALITIES fuses data and leadfields from different modalities
% for multimodal sources estimation using MEM
%
% INPUTS:
% - data
% - obj
% - OPTIONS
%
% OUTPUTS:
% - OPTIONS
% - obj
% - obj
%
%% ==============================================
% Copyright (C) 2011 - LATIS Team
Expand All @@ -34,87 +32,46 @@
% -------------------------------------------------------------------------


if isempty(data)
data = {OPTIONS.automatic.Modality.data};
obj.data = data{1};

end
if isfield(OPTIONS.automatic.Modality(1),'idata')
idata = {OPTIONS.automatic.Modality.idata};
obj.idata = idata{1};
% Display information
if OPTIONS.optional.verbose && length(OPTIONS.mandatory.DataTypes) > 1
fprintf('%s, MULTIMODAL data ... %s found \n',OPTIONS.mandatory.pipeline, strjoin(OPTIONS.mandatory.DataTypes,', '));
elseif OPTIONS.optional.verbose && length(OPTIONS.mandatory.DataTypes) == 1
fprintf('%s, No multimodalities ... \n',OPTIONS.mandatory.pipeline);
end

if isfield( OPTIONS.automatic.Modality(1), 'gain' )
obj.gain = OPTIONS.automatic.Modality(1).gain;
% Concatenate data
if isfield(obj, 'data') % wavelet
data = vertcat(obj.data{:});
else % Time-series
data = vertcat(OPTIONS.automatic.Modality.data);
end
obj.data = data;

% If the covariance matrix is scale dependent, no multimodality possible
if size(OPTIONS.automatic.Modality(1).covariance,3) > 1
obj.noise_var = OPTIONS.automatic.Modality(1).covariance;
else
obj.noise_var = diag( OPTIONS.automatic.Modality(1).covariance );
% Concatenate idata(complex data) if present
if isfield(OPTIONS.automatic.Modality(1),'idata')
obj.idata = vertcat(OPTIONS.automatic.Modality.idata);
end

obj.baseline = OPTIONS.automatic.Modality(1).baseline;
obj.channels = OPTIONS.automatic.Modality(1).channels;
% Concatenate Gain
obj.gain = vertcat(OPTIONS.automatic.Modality.gain);

if length(OPTIONS.mandatory.DataTypes)>1 % fusion of modalities if requested
if OPTIONS.optional.verbose
fprintf('%s, MULTIMODAL data ... %s ',OPTIONS.mandatory.pipeline, OPTIONS.mandatory.DataTypes{1});
end
if strcmp(OPTIONS.mandatory.pipeline,'wMEM')
obj.data = data{1};
end
% Concatenate noise covariance
if size(OPTIONS.automatic.Modality(1).covariance,3) > 1 % we concatanate for each covariance matrix

obj.noise_var = OPTIONS.automatic.Modality(1).covariance;
for ii=2:length(OPTIONS.mandatory.DataTypes)

obj.data = [obj.data ; data{ii}];


if exist('idata', 'var'); obj.idata = [obj.idata; idata{ii}]; end
data{ii} = [];
idata{ii} = [];

if isfield( OPTIONS.automatic.Modality(1), 'gain' )
obj.gain = [obj.gain ; OPTIONS.automatic.Modality(ii).gain];
end
if size(OPTIONS.automatic.Modality(1).covariance,3) > 1 && OPTIONS.optional.baseline_shuffle
% we concatanate for each covariance matrix
tmp = [];
for ibaseline = 1:size(obj.noise_var,3)
tmp(:,:,ibaseline) = blkdiag(obj.noise_var(:,:,ibaseline),OPTIONS.automatic.Modality(ii).covariance(:,:,ibaseline));
end
obj.noise_var = tmp;

else
obj.noise_var = [obj.noise_var; diag(OPTIONS.automatic.Modality(ii).covariance)];
end

obj.baseline = [obj.baseline; OPTIONS.automatic.Modality(ii).baseline];
obj.channels = [obj.channels; OPTIONS.automatic.Modality(ii).channels];
if OPTIONS.optional.verbose
fprintf('... %s found ', OPTIONS.mandatory.DataTypes{ii})
tmp = [];
for ibaseline = 1:size(obj.noise_var,3)
tmp(:,:,ibaseline) = blkdiag(obj.noise_var(:,:,ibaseline),OPTIONS.automatic.Modality(ii).covariance(:,:,ibaseline));
end
obj.noise_var = tmp;
end

if strcmp(OPTIONS.mandatory.pipeline,'wMEM')
temp=obj.data;
obj.data = [];
obj.data{1} = temp;
end


else
if OPTIONS.optional.verbose
fprintf('%s, No multimodalities ...',OPTIONS.mandatory.pipeline);
end
end

% If the covariance matrix is scale dependent, noise_var is already a matrix
if size(OPTIONS.automatic.Modality(1).covariance,3)==1
obj.noise_var = diag( obj.noise_var );
obj.noise_var = blkdiag(OPTIONS.automatic.Modality.covariance);
end

if OPTIONS.optional.verbose, fprintf('\n'); end
% Concatenate baseline and channels
obj.baseline = vertcat(OPTIONS.automatic.Modality.baseline);
obj.channels = vertcat(OPTIONS.automatic.Modality.channels);

end
4 changes: 2 additions & 2 deletions best/rmem/main/be_main_rmem.m
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@


if strcmp(OPTIONS.mandatory.pipeline,'rMEM')
new_obj = be_fusion_of_modalities( [], obj, OPTIONS);
new_obj = be_fusion_of_modalities(obj, OPTIONS);
new_obj.noise_var = real(new_obj.noise_var);
[obj.ImageGridAmp, OPTIONS] = be_launch_mem(new_obj, OPTIONS);
% - imaginary part
if ~isempty(new_obj.idata)
new_obj = be_fusion_of_modalities( [], obj, OPTIONS);
new_obj = be_fusion_of_modalities(obj, OPTIONS);
if ~isreal( new_obj.noise_var )
new_obj.noise_var = imag(new_obj.noise_var);
end
Expand Down
4 changes: 2 additions & 2 deletions best/rmem/solver/be_rmem_solver.m
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,10 @@

% Retrieve ridges infos
if isempty(ii) && ~isempty(OPTIONS.automatic.rMEMfiles)
[obj] = be_fusion_of_modalities([], obj, OPTIONS);
[obj] = be_fusion_of_modalities(obj, OPTIONS);

elseif isempty(ii) && isempty(OPTIONS.automatic.rMEMfiles)
[obj] = be_fusion_of_modalities([], obj, OPTIONS);
[obj] = be_fusion_of_modalities(obj, OPTIONS);

% single precision
[OPTIONS] = be_switch_precision( OPTIONS, 'single' );
Expand Down
2 changes: 1 addition & 1 deletion best/rmem/wavelet/analytical_wavelet/be_ridgefilter.m
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@
% Fuse signals
idCH = 1 : size( OPTIONS.mandatory.Data,1);
if ~OPTIONS.automatic.stand_alone || OPTIONS.automatic.process
FUS = be_fusion_of_modalities( [], [], OPTIONS );
FUS = be_fusion_of_modalities([], OPTIONS );
OPTIONS.automatic.nb_channels
% Get study channel info
nChan = bst_get('ChannelForStudy', OPTIONS.automatic.iStudy);
Expand Down
2 changes: 1 addition & 1 deletion best/wmem/main/be_main_wmem.m
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
% -------------------------------------------------------------------------

if strcmp(OPTIONS.mandatory.pipeline,'wMEM')
obj = be_fusion_of_modalities(obj.data,obj,OPTIONS);
obj = be_fusion_of_modalities(obj, OPTIONS);
[obj.ImageGridAmp, OPTIONS] = be_launch_mem(obj, OPTIONS);

% si j=0, on remplace obj.data = obj.scaling_data, on corrige ? la matrice de variance covariance
Expand Down

0 comments on commit 7c561ae

Please sign in to comment.