Skip to content

Commit

Permalink
Add multi-periods to create_data_frame_by_entity
Browse files (browse the repository at this point in the history)
  • Loading branch information
benjello committed Mar 17, 2023
1 parent f1ce096 commit ad7c2aa
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions openfisca_survey_manager/scenarios.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def compute_pivot_table(self, aggfunc = 'mean', columns = None, difference = Fal
return data_frame.pivot_table(index = index, columns = columns, values = weight_variable, aggfunc = 'sum')

def create_data_frame_by_entity(self, variables = None, expressions = None, filter_by = None, index = False,
period = None, use_baseline = False, merge = False):
period = None, periods = None, use_baseline = False, merge = False):
"""Create dataframe(s) of computed variables for every entity (optionally merged into a single dataframe).
Args:
Expand All @@ -591,6 +591,39 @@ def create_data_frame_by_entity(self, variables = None, expressions = None, filt
assert simulation is not None
assert tax_benefit_system is not None

if periods is not None:
assert period is None, "periods and period canno't be simultaneously not None"

entities = set(
tax_benefit_system.get_variable(variable).entity.key
for variable in variables
)
variables_by_entity = {
entity: [
variable
for variable in variables
if tax_benefit_system.get_variable(variable).entity.key == entity
]
for entity in entities
}

return {
entity: pd.concat({
period: self.create_data_frame_by_entity(
variables = variables_by_entity[entity],
expressions = expressions,
filter_by = filter_by,
index = index,
period = period,
periods = None,
use_baseline = use_baseline,
merge = False,
)[entity]
for period in periods
}).rename_axis(["period", "index"]).reset_index(level = 0)
for entity in entities
}

if period is None:
period = simulation.period

Expand Down Expand Up @@ -1595,7 +1628,7 @@ def init_variable_in_entity(simulation, entity, variable_name, series, period):
series.isnull().sum(), series.notnull().sum(), variable_name))
log.debug('We convert these NaN values of variable {} to {} its default value'.format(
variable_name, variable.default_value))
series.fillna(variable.default_value, inplace = True)
series = series.fillna(variable.default_value)
assert series.notnull().all(), \
'There are {} NaN values for {} non NaN values in variable {}'.format(
series.isnull().sum(), series.notnull().sum(), variable_name)
Expand Down

0 comments on commit ad7c2aa

Please sign in to comment.