diff --git a/thoi/measures/gaussian_copula.py b/thoi/measures/gaussian_copula.py index a71b503..bffe84d 100644 --- a/thoi/measures/gaussian_copula.py +++ b/thoi/measures/gaussian_copula.py @@ -1,13 +1,16 @@ from typing import Optional, Callable from tqdm.autonotebook import tqdm +from functools import partial +import pandas as pd import scipy as sp import numpy as np import torch from torch.utils.data import DataLoader from ..dataset import CovarianceDataset +from ..collectors import batch_to_csv def gaussianCopula(X): @@ -142,14 +145,14 @@ def multi_order_measures(X: np.ndarray, pd.DataFrame: DataFrame containing computed metrics. """ + T, N = X.shape + max_order = N if max_order is None else max_order + if batch_aggregation is None: - batch_aggregation = lambda X: [x for x in X if x is not None] + batch_aggregation = pd.concat if batch_data_collector is None: - batch_data_collector = lambda *args: args - - T, N = X.shape - max_order = N if max_order is None else max_order + batch_data_collector = partial(batch_to_csv, N=N) assert max_order <= N, f"max_order must be lower or equal than N. {max_order} > {N})" assert min_order <= max_order, f"min_order must be lower or equal than max_order. {min_order} > {max_order}"