# -*- coding: utf-8 -*-
from __future__ import unicode_literals, print_function, division, absolute_import
from os import linesep
import tempfile
import logging
import dpath
import numpy as np
from openfisca_core import periods
from openfisca_core.commons import empty_clone, stringify_array, basestring_type, to_unicode
from openfisca_core.tracers import Tracer, TracingParameterNodeAtInstant
from openfisca_core.indexed_enums import Enum, EnumArray
log = logging.getLogger(__name__)
# Exceptions
class NaNCreationError(Exception):
    """Raised in debug mode when a formula result contains NaN values."""
class CycleError(Exception):
    """Raised when a circular definition between formulas is detected."""
class Simulation(object):
    """
    Represents a simulation, and handles the calculation logic
    """
    debug = False
    period = None
    steps_count = 1
    tax_benefit_system = None
    trace = False

    # ----- Simulation construction ----- #

    def __init__(
            self,
            tax_benefit_system,
            simulation_json = None,
            debug = False,
            period = None,
            trace = False,
            opt_out_cache = False,
            memory_config = None,
            ):
        """
        If a ``simulation_json`` is given, initialises a simulation from a JSON dictionary.

        Note: This way of initialising a simulation, still under experimentation, aims at replacing the initialisation from `scenario.make_json_or_python_to_attributes`.

        If no ``simulation_json`` is given, initialises an empty simulation.

        :param tax_benefit_system: the tax and benefit system the simulation runs on (must not be ``None``)
        :param simulation_json: optional dict describing the entities of the situation, keyed by entity plural
        :param debug: when truthy, enables extra result checks and forces tracing on
        :param period: optional :any:`periods.Period` stored on the simulation
        :param trace: when truthy, a :any:`Tracer` records the calculations
        :param opt_out_cache: stored as-is on the simulation
        :param memory_config: stored as-is on the simulation
        """
        self.tax_benefit_system = tax_benefit_system
        assert tax_benefit_system is not None
        if period:
            assert isinstance(period, periods.Period)
        self.period = period

        # To keep track of the values (formulas and periods) being calculated to detect circular definitions.
        # See use in formulas.py.
        # The data structure of requested_periods_by_variable_name is: {variable_name: [period1, period2]}
        self.requested_periods_by_variable_name = {}
        self.max_nb_cycles = None

        self.debug = debug
        self.trace = trace or self.debug  # debug mode implies tracing
        if self.trace:
            self.tracer = Tracer()
        else:
            self.tracer = None
        self.opt_out_cache = opt_out_cache

        self.memory_config = memory_config
        self._data_storage_dir = None  # created lazily by the data_storage_dir property
        self.instantiate_entities(simulation_json)

    def instantiate_entities(self, simulation_json):
        """
        Create the person entity and every group entity of the tax and benefit system,
        and expose each of them both in ``self.entities`` and as an attribute named after its key.

        :param simulation_json: dict keyed by entity plural, or a falsy value to create empty entities
        :raises SituationParsingError: if ``simulation_json`` is not a dict, contains an entity that is not
            defined in the tax and benefit system, or defines no person
        """
        if simulation_json:
            check_type(simulation_json, dict, ['error'])
            allowed_entities = set(entity_class.plural for entity_class in self.tax_benefit_system.entities)
            unexpected_entities = [entity for entity in simulation_json if entity not in allowed_entities]
            if unexpected_entities:
                unexpected_entity = unexpected_entities[0]
                raise SituationParsingError([unexpected_entity],
                    # Sentences are joined with a space so that they do not run into each other.
                    ' '.join([
                        "Some entities in the situation are not defined in the loaded tax and benefit system.",
                        "These entities are not found: {0}.",
                        "The defined entities are: {1}."]
                        )
                    .format(
                        ', '.join(unexpected_entities),
                        # Sorted: a set has no stable order, and the message should be deterministic.
                        ', '.join(sorted(allowed_entities))
                        )
                    )
            persons_json = simulation_json.get(self.tax_benefit_system.person_entity.plural, None)

            if not persons_json:
                raise SituationParsingError([self.tax_benefit_system.person_entity.plural],
                    'No {0} found. At least one {0} must be defined to run a simulation.'.format(self.tax_benefit_system.person_entity.key))
            self.persons = self.tax_benefit_system.person_entity(self, persons_json)
        else:
            self.persons = self.tax_benefit_system.person_entity(self)

        self.entities = {self.persons.key: self.persons}
        setattr(self, self.persons.key, self.persons)  # create shortcut simulation.person (for instance)

        for entity_class in self.tax_benefit_system.group_entities:
            if simulation_json:
                entities_json = simulation_json.get(entity_class.plural)
                entities = entity_class(self, entities_json or {})
            else:
                entities = entity_class(self)
            self.entities[entity_class.key] = entities
            setattr(self, entity_class.key, entities)  # create shortcut simulation.household (for instance)

    @property
    def data_storage_dir(self):
        """
        Temporary folder used to store intermediate calculation data in case the memory is saturated

        The folder is created on first access and is not removed by the simulation.
        """
        if self._data_storage_dir is None:
            self._data_storage_dir = tempfile.mkdtemp(prefix = "openfisca_")
            # `Logger.warn` is a deprecated alias of `Logger.warning`; the message is passed as text
            # (not encoded bytes) so that it is not rendered as b'...' on Python 3.
            log.warning(
                "Intermediate results will be stored on disk in %s in case of memory overflow. "
                "You should remove this directory once you're done with your simulation.",
                self._data_storage_dir,
                )
        return self._data_storage_dir

    # ----- Calculation methods ----- #

    def calculate(self, variable_name, period, **parameters):
        """
        Calculate the variable ``variable_name`` for the period ``period``, using the variable formula if it exists.

        The value is looked up in the cache first; otherwise the formula, then the variable
        ``base_function``, then the holder default array are tried in that order, and the result is cached.

        :param variable_name: name of the variable to calculate
        :param period: a :any:`periods.Period`, or a value accepted by ``periods.period``
        :param parameters: optional ``extra_params`` and ``max_nb_cycles``; all of them are forwarded to the tracer
        :returns: A numpy array containing the result of the calculation
        :raises ValueError: if ``period`` does not match the variable ``definition_period``
        """
        entity = self.get_variable_entity(variable_name)
        holder = entity.get_holder(variable_name)
        variable = self.tax_benefit_system.get_variable(variable_name)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        if self.trace:
            self.tracer.record_calculation_start(variable.name, period, **parameters)

        self._check_period_consistency(period, variable)

        extra_params = parameters.get('extra_params', ())

        # First look for a value already cached
        cached_array = holder.get_array(period, extra_params)
        if cached_array is not None:
            if self.trace:
                self.tracer.record_calculation_end(variable.name, period, cached_array, **parameters)
            return cached_array

        max_nb_cycles = parameters.get('max_nb_cycles')
        if max_nb_cycles is not None:
            self.max_nb_cycles = max_nb_cycles

        # First, try to run a formula
        array = self._run_formula(variable, entity, period, extra_params, max_nb_cycles)

        # If no result, try a base function
        if array is None and variable.base_function:
            array = variable.base_function(holder, period, *extra_params)

        # If no result, use the default value
        if array is None:
            array = holder.default_array()

        self._clean_cycle_detection_data(variable.name)
        if max_nb_cycles is not None:
            # This call is the one that allowed cycles: stop allowing them once it is done.
            self.max_nb_cycles = None

        holder.put_in_cache(array, period, extra_params)
        if self.trace:
            self.tracer.record_calculation_end(variable.name, period, array, **parameters)

        return array

    def calculate_add(self, variable_name, period, **parameters):
        """
        Calculate ``variable_name`` on every sub-period of ``period`` and return the sum of the results.

        Sub-periods are obtained with ``period.get_subperiods(variable.definition_period)``.

        :raises ValueError: if a yearly variable is requested for a monthly period, or if the variable
            is defined neither monthly nor yearly
        """
        variable = self.tax_benefit_system.get_variable(variable_name)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # Check that the requested period matches definition_period
        if variable.definition_period == periods.YEAR and period.unit == periods.MONTH:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' can only be computed for year-long periods. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period,
                ).encode('utf-8'))

        if variable.definition_period not in [periods.MONTH, periods.YEAR]:
            raise ValueError("Unable to sum constant variable '{}' over period {}: only variables defined monthly or yearly can be summed over time.".format(
                variable.name,
                period).encode('utf-8'))

        return sum(
            self.calculate(variable_name, sub_period, **parameters)
            for sub_period in period.get_subperiods(variable.definition_period)
            )

    def calculate_divide(self, variable_name, period, **parameters):
        """
        Calculate a yearly variable for a one-month or one-year ``period``.

        For a one-month period, the value calculated for ``period.this_year`` is divided by 12.
        For a one-year period, the value is calculated as is.

        :raises ValueError: if the variable is not defined yearly, if ``period.size`` is not 1,
            or if the period unit is neither month nor year
        """
        variable = self.tax_benefit_system.get_variable(variable_name)

        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)

        # Check that the requested period matches definition_period
        if variable.definition_period != periods.YEAR:
            raise ValueError("Unable to divide the value of '{}' over time on period {}: only variables defined yearly can be divided over time.".format(
                variable_name,
                period).encode('utf-8'))

        if period.size != 1:
            raise ValueError("DIVIDE option can only be used for a one-year or a one-month requested period")

        if period.unit == periods.MONTH:
            computation_period = period.this_year
            return self.calculate(variable_name, period = computation_period, **parameters) / 12.
        elif period.unit == periods.YEAR:
            return self.calculate(variable_name, period, **parameters)

        raise ValueError("Unable to divide the value of '{}' to match period {}.".format(
            variable_name,
            period).encode('utf-8'))

    def calculate_output(self, variable_name, period):
        """
        Calculate the value of a variable using the ``calculate_output`` attribute of the variable.

        Falls back to :any:`calculate` when the variable has no ``calculate_output``.
        """
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)

        if variable.calculate_output is None:
            return self.calculate(variable_name, period)

        return variable.calculate_output(self, variable_name, period)

    def trace_parameters_at_instant(self, formula_period):
        """
        Return the parameters at ``formula_period`` wrapped in a ``TracingParameterNodeAtInstant``
        bound to the simulation tracer.
        """
        return TracingParameterNodeAtInstant(
            self.tax_benefit_system.get_parameters_at_instant(formula_period),
            self.tracer
            )

    def _run_formula(self, variable, entity, period, extra_params, max_nb_cycles):
        """
        Find the ``variable`` formula for the given ``period`` if it exists, and apply it to ``entity``.

        :returns: the checked and cast result of the formula, or ``None`` if there is no formula for
            ``period`` or if a cycle was detected while ``max_nb_cycles`` is not ``None``
        :raises CycleError: if a cycle is detected and ``max_nb_cycles`` is ``None``
        """
        formula = variable.get_formula(period)
        if formula is None:
            return None

        if self.trace:
            parameters_at = self.trace_parameters_at_instant
        else:
            parameters_at = self.tax_benefit_system.get_parameters_at_instant

        try:
            self._check_for_cycle(variable, period)
            # Formulas declared with two arguments do not receive the parameters getter.
            if formula.__code__.co_argcount == 2:
                array = formula(entity, period)
            else:
                array = formula(entity, period, parameters_at, *extra_params)
        except CycleError as error:
            self._clean_cycle_detection_data(variable.name)
            if max_nb_cycles is None:
                if self.trace:
                    self.tracer.record_calculation_abortion(variable.name, period, extra_params = extra_params)
                # Re-raise until reaching the first variable called with max_nb_cycles != None in the stack.
                raise error
            self.max_nb_cycles = None
            return None

        self._check_formula_result(array, variable, entity, period)
        return self._cast_formula_result(array, variable)

    def _check_period_consistency(self, period, variable):
        """
        Check that a period matches the variable definition_period

        :raises ValueError: if the period unit differs from a monthly or yearly ``definition_period``,
            or if ``period.size`` is not 1
        """
        if variable.definition_period == periods.ETERNITY:
            return  # For variables which values are constant in time, all periods are accepted

        if variable.definition_period == periods.MONTH and period.unit != periods.MONTH:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole month. You can use the ADD option to sum '{0}' over the requested period, or change the requested period to 'period.first_month'.".format(
                variable.name,
                period
                ).encode('utf-8'))

        if variable.definition_period == periods.YEAR and period.unit != periods.YEAR:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole year. You can use the DIVIDE option to get an estimate of {0} by dividing the yearly value by 12, or change the requested period to 'period.this_year'.".format(
                variable.name,
                period
                ).encode('utf-8'))

        if period.size != 1:
            raise ValueError("Unable to compute variable '{0}' for period {1}: '{0}' must be computed for a whole {2}. You can use the ADD option to sum '{0}' over the requested period.".format(
                variable.name,
                period,
                'month' if variable.definition_period == periods.MONTH else 'year'
                ).encode('utf-8'))

    def _check_formula_result(self, value, variable, entity, period):
        """
        Check that a formula result is a Numpy array of the size of ``entity``.

        In debug mode, also check that the array contains no NaN value.

        :raises AssertionError: if ``value`` is not a Numpy array, or if its size is not ``entity.count``
        :raises NaNCreationError: in debug mode, if ``value`` contains NaN values
        """
        # The lines are joined as text and encoded once: joining already-encoded lines with a text
        # separator raises a TypeError on Python 3 instead of reporting the assertion failure.
        assert isinstance(value, np.ndarray), linesep.join([
            "You tried to compute the formula '{0}' for the period '{1}'.".format(variable.name, str(period)),
            "The formula '{0}@{1}' should return a Numpy array;".format(variable.name, str(period)),
            "instead it returned '{0}' of {1}.".format(value, type(value)),
            "Learn more about Numpy arrays and vectorial computing:",
            "<http://openfisca.org/doc/coding-the-legislation/25_vectorial_computing.html.>"
            ]).encode('utf-8')

        assert value.size == entity.count, \
            "Function {}@{}<{}>() --> <{}>{} returns an array of size {}, but size {} is expected for {}".format(
                variable.name, entity.key, str(period), str(period), stringify_array(value),
                value.size, entity.count, entity.key).encode('utf-8')

        if self.debug:
            try:
                # cf https://stackoverflow.com/questions/6736590/fast-check-for-nan-in-numpy
                if np.isnan(np.min(value)):
                    nan_count = np.count_nonzero(np.isnan(value))
                    raise NaNCreationError("Function {}@{}<{}>() --> <{}>{} returns {} NaN value(s)".format(
                        variable.name, entity.key, str(period), str(period), stringify_array(value),
                        nan_count).encode('utf-8'))
            except TypeError:
                # np.isnan is not defined for every dtype: in that case the NaN check is skipped.
                pass

    def _cast_formula_result(self, value, variable):
        """
        Convert a formula result to the type expected by ``variable``.

        Enum variables are encoded with ``variable.possible_values`` unless ``value`` is already an
        ``EnumArray``; other results are cast to ``variable.dtype`` when their dtype differs.
        """
        if variable.value_type == Enum and not isinstance(value, EnumArray):
            return variable.possible_values.encode(value)

        if value.dtype != variable.dtype:
            return value.astype(variable.dtype)

        return value

    # ----- Handle circular dependencies in a calculation ----- #

    def _check_for_cycle(self, variable, period):
        """
        Register that ``variable`` is being calculated for ``period``, and detect circular definitions.

        This method returns nothing: a cycle is reported by raising.

        :raises AssertionError: if the variable is already being calculated for the same period,
            or is already being calculated and is defined for ETERNITY
        :raises CycleError: if the variable is already being calculated and ``self.max_nb_cycles``
            is ``None`` or is exceeded by the number of periods already requested
        """
        def get_error_message():
            # The list of formulas is kept as text: encoding it before formatting would insert
            # a b'...' literal in the message on Python 3.
            return "Circular definition detected on formula {}@{}. Formulas and periods involved: {}.".format(
                variable.name,
                period,
                ", ".join(sorted(set(
                    "{}@{}".format(requested_name, requested_period)
                    for requested_name, periods_of_variable in requested_periods_by_variable_name.items()
                    for requested_period in periods_of_variable
                    ))),
                )

        requested_periods_by_variable_name = self.requested_periods_by_variable_name
        variable_name = variable.name
        if variable_name in requested_periods_by_variable_name:
            # Make sure the formula doesn't call itself for the same period it is being called for.
            # It would be a pure circular definition.
            requested_periods = requested_periods_by_variable_name[variable_name]
            assert period not in requested_periods and (variable.definition_period != periods.ETERNITY), get_error_message()
            if self.max_nb_cycles is None or len(requested_periods) > self.max_nb_cycles:
                message = get_error_message()
                if self.max_nb_cycles is None:
                    message += ' Hint: use "max_nb_cycles = 0" to get a default value, or "= N" to allow N cycles.'
                raise CycleError(message)
            else:
                requested_periods.append(period)
        else:
            requested_periods_by_variable_name[variable_name] = [period]

    def _clean_cycle_detection_data(self, variable_name):
        """
        When the value of a formula has been computed, remove the last requested period from
        requested_periods_by_variable_name[variable_name] and delete the latter if empty.
        """
        requested_periods_by_variable_name = self.requested_periods_by_variable_name
        if variable_name in requested_periods_by_variable_name:
            requested_periods_by_variable_name[variable_name].pop()
            if len(requested_periods_by_variable_name[variable_name]) == 0:
                del requested_periods_by_variable_name[variable_name]

    # ----- Methods to access stored values ----- #

    def get_array(self, variable_name, period):
        """
        Return the value of ``variable_name`` for ``period``, if this value is already in the cache (if it has been set as an input or previously calculated).

        Unlike :any:`calculate`, this method *does not* trigger calculations and *does not* use any formula.
        """
        if period is not None and not isinstance(period, periods.Period):
            period = periods.period(period)
        return self.get_holder(variable_name).get_array(period)

    def get_holder(self, variable_name):
        """
        Get the :any:`Holder` associated with the variable ``variable_name`` for the simulation
        """
        return self.get_variable_entity(variable_name).get_holder(variable_name)

    def get_memory_usage(self, variables = None):
        """
        Get data about the virtual memory usage of the simulation

        :param variables: forwarded to the ``get_memory_usage`` method of each entity
        :returns: a dict with the keys ``total_nb_bytes`` (sum over the entities) and
            ``by_variable`` (merge of the per-entity ``by_variable`` dicts)
        """
        result = dict(
            total_nb_bytes = 0,
            by_variable = {}
            )
        for entity in self.entities.values():
            entity_memory_usage = entity.get_memory_usage(variables = variables)
            result['total_nb_bytes'] += entity_memory_usage['total_nb_bytes']
            result['by_variable'].update(entity_memory_usage['by_variable'])
        return result

    # ----- Misc ----- #

    def set_input(self, variable, period, value):
        """
        Set a variable's value for a given period

        :param variable: the variable to be set
        :param period: the period for which the value is set
        :param value: the input value for the variable

        Example:

        >>> simulation.set_input('age', '2018-04', [12, 14])
        >>> simulation.get_array('age', '2018-04')
        >>> [12, 14]

        If a ``set_input`` property has been set for the variable, this method may accept inputs for periods not matching the ``definition_period`` of the variable. To read more about this, check the `documentation <https://openfisca.org/doc/coding-the-legislation/35_periods.html#automatically-process-variable-inputs-defined-for-periods-not-matching-the-definitionperiod>`_.
        """
        self.get_holder(variable).set_input(period, value)

    def get_variable_entity(self, variable_name):
        """
        Return the simulation entity on which the variable ``variable_name`` is defined.
        """
        variable = self.tax_benefit_system.get_variable(variable_name, check_existence = True)
        return self.get_entity(variable.entity)

    def get_entity(self, entity_type = None, plural = None):
        """
        Return a simulation entity, looked up by ``entity_type`` (through its ``key``) or by ``plural``.

        ``entity_type`` takes precedence over ``plural``. Returns ``None`` if both are falsy.

        :raises KeyError: if no entity is registered under ``entity_type.key``
        :raises IndexError: if no entity has the requested ``plural``
        """
        if entity_type:
            return self.entities[entity_type.key]
        if plural:
            return [entity for entity in self.entities.values() if entity.plural == plural][0]

    def clone(self, debug = False, trace = False):
        """
        Copy the simulation just enough to be able to run the copy without modifying the original simulation

        :param debug: when truthy, the copy runs in debug mode
        :param trace: when truthy, the copy runs in trace mode
        :returns: the new simulation, with cloned entities
        """
        new = empty_clone(self)
        new_dict = new.__dict__

        # Attributes are copied by reference; the debug and tracing state is not inherited,
        # it only depends on the arguments of this method.
        for key, value in self.__dict__.items():
            if key not in ('debug', 'trace', 'tracer'):
                new_dict[key] = value

        new.persons = self.persons.clone(new)
        setattr(new, new.persons.key, new.persons)
        new.entities = {new.persons.key: new.persons}

        for entity_class in self.tax_benefit_system.group_entities:
            entity = self.entities[entity_class.key].clone(new)
            new.entities[entity.key] = entity
            setattr(new, entity_class.key, entity)  # create shortcut simulation.household (for instance)

        if debug:
            new_dict['debug'] = True
        if trace:
            new_dict['trace'] = True
        if debug or trace:
            # Reuse what has already been traced when the original simulation has a tracer.
            if self.debug or self.trace:
                new_dict['tracer'] = self.tracer.clone()
            else:
                new_dict['tracer'] = Tracer()

        return new
def check_type(input, input_type, path = None):
    """
    Check that ``input`` is an instance of ``input_type``.

    :param input: the value to check
    :param input_type: the expected type, as accepted by ``isinstance``
    :param path: list of keys locating ``input`` in the situation, used to build the error;
        defaults to an empty list
    :raises SituationParsingError: if ``input`` is not an instance of ``input_type``
    """
    json_type_map = {
        dict: "Object",
        list: "Array",
        basestring_type: "String",
        }

    # A fresh list per call instead of a mutable default argument shared across calls.
    if path is None:
        path = []

    if not isinstance(input, input_type):
        # Types without a JSON name fall back to their Python name, so that the
        # parsing error is raised rather than a KeyError.
        type_name = json_type_map.get(input_type, getattr(input_type, '__name__', str(input_type)))
        raise SituationParsingError(path,
            "Invalid type: must be of type '{}'.".format(type_name))
class SituationParsingError(Exception):
    """
    Error raised when a situation cannot be parsed.

    The ``error`` attribute is a nested dict in which the message is stored at the location
    given by ``path``; ``code`` is stored as given.
    """

    def __init__(self, path, message, code = None):
        """
        :param path: list of keys locating the error in the situation
        :param message: description of the error; line separators are replaced by spaces
        :param code: optional value stored in ``self.code``
        """
        self.error = {}
        location = '/'.join(path)
        flat_message = to_unicode(message).strip(linesep).replace(linesep, ' ')
        dpath.util.new(self.error, location, flat_message)
        self.code = code
        super(SituationParsingError, self).__init__(str(self.error).encode('utf-8'))

    def __str__(self):
        """Return the string representation of the ``error`` dict."""
        return str(self.error)
def calculate_output_add(simulation, variable_name, period):
    """Return the result of ``simulation.calculate_add`` for ``variable_name`` and ``period``."""
    summed = simulation.calculate_add(variable_name, period)
    return summed
def calculate_output_divide(simulation, variable_name, period):
    """Return the result of ``simulation.calculate_divide`` for ``variable_name`` and ``period``."""
    divided = simulation.calculate_divide(variable_name, period)
    return divided