mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-08 11:32:39 +02:00
Need to fix missing data in likelihoods.
This commit is contained in:
parent
38f6d6a911
commit
196732b83b
3 changed files with 196 additions and 175 deletions
|
|
@ -2,15 +2,21 @@
|
||||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
import re
|
||||||
from ..core.parameterization import Parameterized
|
from ..core.parameterization import Parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import sympy as sym
|
import sympy as sym
|
||||||
from ..core.parameterization import Param
|
from ..core.parameterization import Param
|
||||||
from sympy.utilities.lambdify import lambdastr, _imp_namespace, _get_namespace
|
from sympy.utilities.lambdify import lambdastr, _imp_namespace, _get_namespace
|
||||||
from sympy.utilities.iterables import numbered_symbols
|
from sympy.utilities.iterables import numbered_symbols
|
||||||
from sympy import exp
|
from numpy import exp
|
||||||
from scipy.special import gammaln, gamma, erf, erfc, erfcx, polygamma
|
from scipy.special import gammaln, gamma, erf, erfc, erfcx, polygamma
|
||||||
from GPy.util.functions import normcdf, normcdfln, logistic, logisticln
|
from GPy.util.functions import normcdf, normcdfln, logistic, logisticln
|
||||||
|
def getFromDict(dataDict, mapList):
|
||||||
|
return reduce(lambda d, k: d[k], mapList, dataDict)
|
||||||
|
|
||||||
|
def setInDict(dataDict, mapList, value):
|
||||||
|
getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
|
||||||
|
|
||||||
class Symbolic_core():
|
class Symbolic_core():
|
||||||
"""
|
"""
|
||||||
|
|
@ -21,24 +27,42 @@ class Symbolic_core():
|
||||||
# Base class init, do some basic derivatives etc.
|
# Base class init, do some basic derivatives etc.
|
||||||
|
|
||||||
# Func_modules sets up the right mapping for functions.
|
# Func_modules sets up the right mapping for functions.
|
||||||
self.func_modules = func_modules
|
func_modules += [{'gamma':gamma,
|
||||||
self.func_modules += [{'gamma':gamma,
|
'gammaln':gammaln,
|
||||||
'gammaln':gammaln,
|
'erf':erf, 'erfc':erfc,
|
||||||
'erf':erf, 'erfc':erfc,
|
'erfcx':erfcx,
|
||||||
'erfcx':erfcx,
|
'polygamma':polygamma,
|
||||||
'polygamma':polygamma,
|
'normcdf':normcdf,
|
||||||
'normcdf':normcdf,
|
'normcdfln':normcdfln,
|
||||||
'normcdfln':normcdfln,
|
'logistic':logistic,
|
||||||
'logistic':logistic,
|
'logisticln':logisticln},
|
||||||
'logisticln':logisticln},
|
'numpy']
|
||||||
'numpy']
|
|
||||||
|
|
||||||
self._set_expressions(expressions)
|
self._set_expressions(expressions)
|
||||||
self._set_variables(cacheable)
|
self._set_variables(cacheable)
|
||||||
self._set_derivatives(derivatives)
|
self._set_derivatives(derivatives)
|
||||||
self._set_parameters(parameters)
|
self._set_parameters(parameters)
|
||||||
self.namespace = [globals(), self.__dict__]
|
# Convert the expressions to a list for common sub expression elimination
|
||||||
|
# We should find the following type of expressions: 'function', 'derivative', 'second_derivative', 'third_derivative'.
|
||||||
|
self.update_expression_list()
|
||||||
|
|
||||||
|
# Apply any global stabilisation operations to expressions.
|
||||||
|
self.global_stabilize()
|
||||||
|
|
||||||
|
# Helper functions to get data in and out of dictionaries.
|
||||||
|
# this code from http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys
|
||||||
|
|
||||||
|
self.extract_sub_expressions()
|
||||||
self._gen_code()
|
self._gen_code()
|
||||||
|
self._set_namespace(func_modules)
|
||||||
|
|
||||||
|
def _set_namespace(self, namespaces):
|
||||||
|
"""Set the name space for use when calling eval. This needs to contain all the relvant functions for mapping from symbolic python to the numerical python. It also contains variables, cached portions etc."""
|
||||||
|
self.namespace = {}
|
||||||
|
for m in namespaces[::-1]:
|
||||||
|
buf = _get_namespace(m)
|
||||||
|
self.namespace.update(buf)
|
||||||
|
self.namespace.update(self.__dict__)
|
||||||
|
|
||||||
def _set_expressions(self, expressions):
|
def _set_expressions(self, expressions):
|
||||||
"""Extract expressions and variables from the user provided expressions."""
|
"""Extract expressions and variables from the user provided expressions."""
|
||||||
|
|
@ -79,7 +103,7 @@ class Symbolic_core():
|
||||||
|
|
||||||
# Do symbolic work to compute derivatives.
|
# Do symbolic work to compute derivatives.
|
||||||
for key, func in self.expressions.items():
|
for key, func in self.expressions.items():
|
||||||
self.expressions[key]['derivative'] = {theta.name : sym.diff(func['function'],theta) for theta in derivative_arguments}
|
self.expressions[key]['derivative'] = {theta.name : self.stabilize(sym.diff(func['function'],theta)) for theta in derivative_arguments}
|
||||||
|
|
||||||
def _set_parameters(self, parameters):
|
def _set_parameters(self, parameters):
|
||||||
"""Add parameters to the model and initialize with given values."""
|
"""Add parameters to the model and initialize with given values."""
|
||||||
|
|
@ -92,32 +116,45 @@ class Symbolic_core():
|
||||||
# Add parameter.
|
# Add parameter.
|
||||||
|
|
||||||
self.add_parameters(Param(theta.name, val, None))
|
self.add_parameters(Param(theta.name, val, None))
|
||||||
#setattr(self, theta.name, )
|
#self._set_attribute(theta.name, )
|
||||||
|
|
||||||
def eval_parameters_changed(self):
|
def eval_parameters_changed(self):
|
||||||
# TODO: place checks for inf/nan in here
|
# TODO: place checks for inf/nan in here
|
||||||
# do all the precomputation codes.
|
# do all the precomputation codes.
|
||||||
for variable, code in sorted(self.code['parameters_change'].iteritems()):
|
|
||||||
setattr(self, variable, eval(code, *self.namespace))
|
|
||||||
self.eval_update_cache()
|
self.eval_update_cache()
|
||||||
|
|
||||||
def eval_update_cache(self, **kwargs):
|
def eval_update_cache(self, **kwargs):
|
||||||
# TODO: place checks for inf/nan in here
|
# TODO: place checks for inf/nan in here
|
||||||
# for all provided keywords
|
# for all provided keywords
|
||||||
for variable, value in kwargs.items():
|
|
||||||
|
for var in sorted(self.code['parameters_changed'].keys(), key=lambda x: int(re.findall(r'\d+$', x)[0])):
|
||||||
|
code = self.code['parameters_changed'][var]
|
||||||
|
self._set_attribute(var, eval(code, self.namespace))
|
||||||
|
|
||||||
|
for var, value in kwargs.items():
|
||||||
# update their cached values
|
# update their cached values
|
||||||
if value is not None:
|
if value is not None:
|
||||||
if variable == 'X' or variable == 'F' or variable == 'Mu':
|
if var == 'X' or var == 'F' or var == 'M':
|
||||||
for i, theta in enumerate(self.variables[variable]):
|
value = np.atleast_2d(value)
|
||||||
setattr(self, theta.name, value[:, i][:, None])
|
for i, theta in enumerate(self.variables[var]):
|
||||||
elif variable.name == 'Z':
|
self._set_attribute(theta.name, value[:, i][:, None])
|
||||||
for i, theta in enumerate(self.variables[variable]):
|
elif var == 'Y':
|
||||||
setattr(self, theta.name, value[:, i][None, :])
|
# Y values can be missing.
|
||||||
|
value = np.atleast_2d(value)
|
||||||
|
for i, theta in enumerate(self.variables[var]):
|
||||||
|
self._set_attribute('missing' + str(i), np.isnan(value[:, i]))
|
||||||
|
self._set_attribute(theta.name, value[:, i][:, None])
|
||||||
|
elif var == 'Z':
|
||||||
|
value = np.atleast_2d(value)
|
||||||
|
for i, theta in enumerate(self.variables[var]):
|
||||||
|
self._set_attribute(theta.name, value[:, i][None, :])
|
||||||
else:
|
else:
|
||||||
setattr(self, theta.name, value[:, i])
|
value = np.atleast_1d(value)
|
||||||
|
for i, theta in enumerate(self.variables[var]):
|
||||||
for variable, code in sorted(self.code['update_cache'].iteritems()):
|
self._set_attribute(theta.name, value[i])
|
||||||
setattr(self, variable, eval(code, *self.namespace))
|
for var in sorted(self.code['update_cache'].keys(), key=lambda x: int(re.findall(r'\d+$', x)[0])):
|
||||||
|
code = self.code['update_cache'][var]
|
||||||
|
self._set_attribute(var, eval(code, self.namespace))
|
||||||
|
|
||||||
def eval_update_gradients(self, function, partial, **kwargs):
|
def eval_update_gradients(self, function, partial, **kwargs):
|
||||||
# TODO: place checks for inf/nan in here
|
# TODO: place checks for inf/nan in here
|
||||||
|
|
@ -126,7 +163,7 @@ class Symbolic_core():
|
||||||
code = self.code[function]['derivative'][theta.name]
|
code = self.code[function]['derivative'][theta.name]
|
||||||
setattr(getattr(self, theta.name),
|
setattr(getattr(self, theta.name),
|
||||||
'gradient',
|
'gradient',
|
||||||
(partial*eval(code, *self.namespace)).sum())
|
(partial*eval(code, self.namespace)).sum())
|
||||||
|
|
||||||
def eval_gradients_X(self, function, partial, **kwargs):
|
def eval_gradients_X(self, function, partial, **kwargs):
|
||||||
if kwargs.has_key('X'):
|
if kwargs.has_key('X'):
|
||||||
|
|
@ -134,17 +171,17 @@ class Symbolic_core():
|
||||||
self.eval_update_cache(**kwargs)
|
self.eval_update_cache(**kwargs)
|
||||||
for i, theta in enumerate(self.variables['X']):
|
for i, theta in enumerate(self.variables['X']):
|
||||||
code = self.code[function]['derivative'][theta.name]
|
code = self.code[function]['derivative'][theta.name]
|
||||||
gradients_X[:, i:i+1] = partial*eval(code, *self.namespace)
|
gradients_X[:, i:i+1] = partial*eval(code, self.namespace)
|
||||||
return gradients_X
|
return gradients_X
|
||||||
|
|
||||||
def eval_function(self, function, **kwargs):
|
def eval_function(self, function, **kwargs):
|
||||||
self.eval_update_cache(**kwargs)
|
self.eval_update_cache(**kwargs)
|
||||||
return eval(self.code[function]['function'], *self.namespace)
|
return eval(self.code[function]['function'], self.namespace)
|
||||||
|
|
||||||
def code_parameters_changed(self):
|
def code_parameters_changed(self):
|
||||||
# do all the precomputation codes.
|
# do all the precomputation codes.
|
||||||
lcode = ''
|
lcode = ''
|
||||||
for variable, code in sorted(self.code['parameters_change'].iteritems()):
|
for variable, code in sorted(self.code['parameters_changed'].iteritems()):
|
||||||
lcode += self._print_code(variable) + ' = ' + self._print_code(code) + '\n'
|
lcode += self._print_code(variable) + ' = ' + self._print_code(code) + '\n'
|
||||||
return lcode
|
return lcode
|
||||||
|
|
||||||
|
|
@ -159,6 +196,7 @@ class Symbolic_core():
|
||||||
else:
|
else:
|
||||||
reorder = ''
|
reorder = ''
|
||||||
for i, theta in enumerate(self.variables[var]):
|
for i, theta in enumerate(self.variables[var]):
|
||||||
|
lcode+= "\t" + var + '= np.atleast_2d(' + var + ')'
|
||||||
lcode+= "\t" + self._print_code(theta.name) + ' = ' + var + '[:, ' + str(i) + "]" + reorder + "\n"
|
lcode+= "\t" + self._print_code(theta.name) + ' = ' + var + '[:, ' + str(i) + "]" + reorder + "\n"
|
||||||
|
|
||||||
for variable, code in sorted(self.code['update_cache'].iteritems()):
|
for variable, code in sorted(self.code['update_cache'].iteritems()):
|
||||||
|
|
@ -173,13 +211,15 @@ class Symbolic_core():
|
||||||
lcode += self._print_code(theta.name) + '.gradient = (partial*(' + self._print_code(code) + ')).sum()\n'
|
lcode += self._print_code(theta.name) + '.gradient = (partial*(' + self._print_code(code) + ')).sum()\n'
|
||||||
return lcode
|
return lcode
|
||||||
|
|
||||||
def code_gradients_X(self, function):
|
def code_gradients_cacheable(self, function, variable):
|
||||||
lcode = 'gradients_X = np.zeros_like(X)\n'
|
if variable not in self.cacheable:
|
||||||
|
raise RuntimeError, variable + ' must be a cacheable.'
|
||||||
|
lcode = 'gradients_' + variable + ' = np.zeros_like(' + variable + ')\n'
|
||||||
lcode += 'self.update_cache(' + ', '.join(self.cacheable) + ')\n'
|
lcode += 'self.update_cache(' + ', '.join(self.cacheable) + ')\n'
|
||||||
for i, theta in enumerate(self.variables['X']):
|
for i, theta in enumerate(self.variables[variable]):
|
||||||
code = self.code[function]['derivative'][theta.name]
|
code = self.code[function]['derivative'][theta.name]
|
||||||
lcode += 'gradients_X[:, ' + str(i) + ':' + str(i) + '+1] = partial*' + self._print_code(code) + '\n'
|
lcode += 'gradients_' + variable + '[:, ' + str(i) + ':' + str(i) + '+1] = partial*' + self._print_code(code) + '\n'
|
||||||
lcode += 'return gradients_X\n'
|
lcode += 'return gradients_' + variable + '\n'
|
||||||
return lcode
|
return lcode
|
||||||
|
|
||||||
def code_function(self, function):
|
def code_function(self, function):
|
||||||
|
|
@ -187,58 +227,21 @@ class Symbolic_core():
|
||||||
lcode += 'return ' + self._print_code(self.code[function]['function'])
|
lcode += 'return ' + self._print_code(self.code[function]['function'])
|
||||||
return lcode
|
return lcode
|
||||||
|
|
||||||
def stabilise(self):
|
def stabilize(self, expr):
|
||||||
"""Stabilize the code in the model."""
|
"""Stabilize the code in the model."""
|
||||||
# this code is applied to all expressions in the model in an attempt to sabilize them.
|
# this code is applied to expressions in the model in an attempt to sabilize them.
|
||||||
|
return expr
|
||||||
|
|
||||||
|
def global_stabilize(self):
|
||||||
|
"""Stabilize all code in the model."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _gen_namespace(self, modules=None, use_imps=True):
|
def _set_attribute(self, name, value):
|
||||||
"""Gets the relevant namespaces for the given expressions."""
|
"""Make sure namespace gets updated when setting attributes."""
|
||||||
from sympy.core.symbol import Symbol
|
setattr(self, name, value)
|
||||||
|
self.namespace.update({name: getattr(self, name)})
|
||||||
# If the user hasn't specified any modules, use what is available.
|
|
||||||
module_provided = True
|
|
||||||
if modules is None:
|
|
||||||
module_provided = False
|
|
||||||
# Use either numpy (if available) or python.math where possible.
|
|
||||||
# XXX: This leads to different behaviour on different systems and
|
|
||||||
# might be the reason for irreproducible errors.
|
|
||||||
modules = ["math", "mpmath", "sympy"]
|
|
||||||
|
|
||||||
try:
|
|
||||||
_import("numpy")
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
modules.insert(1, "numpy")
|
|
||||||
|
|
||||||
|
|
||||||
# Get the needed namespaces.
|
|
||||||
namespaces = []
|
|
||||||
# First find any function implementations
|
|
||||||
if use_imps:
|
|
||||||
for expr in self._expression_list:
|
|
||||||
namespaces.append(_imp_namespace(expr))
|
|
||||||
# Check for dict before iterating
|
|
||||||
if isinstance(modules, (dict, str)) or not hasattr(modules, '__iter__'):
|
|
||||||
namespaces.append(modules)
|
|
||||||
else:
|
|
||||||
namespaces += list(modules)
|
|
||||||
# fill namespace with first having highest priority
|
|
||||||
namespace = {}
|
|
||||||
for m in namespaces[::-1]:
|
|
||||||
buf = _get_namespace(m)
|
|
||||||
namespace.update(buf)
|
|
||||||
for expr in self._expression_list:
|
|
||||||
if hasattr(expr, "atoms"):
|
|
||||||
#Try if you can extract symbols from the expression.
|
|
||||||
#Move on if expr.atoms in not implemented.
|
|
||||||
syms = expr.atoms(Symbol)
|
|
||||||
for term in syms:
|
|
||||||
namespace.update({str(term): term})
|
|
||||||
|
|
||||||
|
|
||||||
return namespace
|
|
||||||
def update_expression_list(self):
|
def update_expression_list(self):
|
||||||
"""Extract a list of expressions from the dictionary of expressions."""
|
"""Extract a list of expressions from the dictionary of expressions."""
|
||||||
self.expression_list = [] # code arrives in dictionary, but is passed in this list
|
self.expression_list = [] # code arrives in dictionary, but is passed in this list
|
||||||
|
|
@ -250,123 +253,141 @@ class Symbolic_core():
|
||||||
self.expression_list.append(texpressions)
|
self.expression_list.append(texpressions)
|
||||||
self.expression_keys.append([fname, type])
|
self.expression_keys.append([fname, type])
|
||||||
self.expression_order.append(1)
|
self.expression_order.append(1)
|
||||||
self.code[fname] = {type: ''}
|
|
||||||
elif type[-10:] == 'derivative':
|
elif type[-10:] == 'derivative':
|
||||||
self.code[fname] = {type:{}}
|
|
||||||
for dtype, expression in texpressions.items():
|
for dtype, expression in texpressions.items():
|
||||||
self.expression_list.append(expression)
|
self.expression_list.append(expression)
|
||||||
self.expression_keys.append([fname, type, dtype])
|
self.expression_keys.append([fname, type, dtype])
|
||||||
if type[:-10] == 'first' or type[:-10] == '':
|
if type[:-10] == 'first_' or type[:-10] == '':
|
||||||
self.expression_order.append(3) #sym.count_ops(self.expressions[type][dtype]))
|
self.expression_order.append(3) #sym.count_ops(self.expressions[type][dtype]))
|
||||||
elif type[:-10] == 'second':
|
elif type[:-10] == 'second_':
|
||||||
self.expression_order.append(4) #sym.count_ops(self.expressions[type][dtype]))
|
self.expression_order.append(4) #sym.count_ops(self.expressions[type][dtype]))
|
||||||
elif type[:-10] == 'third':
|
elif type[:-10] == 'third_':
|
||||||
self.expression_order.append(5) #sym.count_ops(self.expressions[type][dtype]))
|
self.expression_order.append(5) #sym.count_ops(self.expressions[type][dtype]))
|
||||||
self.code[fname][type][dtype] = ''
|
|
||||||
else:
|
else:
|
||||||
self.expression_list.append(fexpressions[type])
|
self.expression_list.append(fexpressions[type])
|
||||||
self.expression_keys.append([fname, type])
|
self.expression_keys.append([fname, type])
|
||||||
self.expression_order.append(2)
|
self.expression_order.append(2)
|
||||||
self.code[fname][type] = ''
|
|
||||||
|
|
||||||
# This step may be unecessary.
|
# This step may be unecessary.
|
||||||
# Not 100% sure if the sub expression elimination is order sensitive. This step orders the list with the 'function' code first and derivatives after.
|
# Not 100% sure if the sub expression elimination is order sensitive. This step orders the list with the 'function' code first and derivatives after.
|
||||||
self.expression_order, self.expression_list, self.expression_keys = zip(*sorted(zip(self.expression_order, self.expression_list, self.expression_keys)))
|
self.expression_order, self.expression_list, self.expression_keys = zip(*sorted(zip(self.expression_order, self.expression_list, self.expression_keys)))
|
||||||
|
|
||||||
|
def extract_sub_expressions(self, cache_prefix='cache', sub_prefix='sub', prefix='XoXoXoX'):
|
||||||
|
# Do the common sub expression elimination.
|
||||||
|
common_sub_expressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))
|
||||||
|
|
||||||
def _gen_code(self, cache_prefix = 'cache', sub_prefix = 'sub', prefix='XoXoXoX'):
|
self.variables[cache_prefix] = []
|
||||||
"""Generate code for the list of expressions provided using the common sub-expression eliminator to separate out portions that are computed multiple times."""
|
self.variables[sub_prefix] = []
|
||||||
# This is the dictionary that stores all the generated code.
|
|
||||||
self.code = {}
|
|
||||||
|
|
||||||
# Convert the expressions to a list for common sub expression elimination
|
# Create dictionary of new sub expressions
|
||||||
# We should find the following type of expressions: 'function', 'derivative', 'second_derivative', 'third_derivative'.
|
sub_expression_dict = {}
|
||||||
self.update_expression_list()
|
for var, void in common_sub_expressions:
|
||||||
|
sub_expression_dict[var.name] = var
|
||||||
# Apply any global stabilisation operations to expressions.
|
|
||||||
self.stabilise()
|
|
||||||
|
|
||||||
# Helper functions to get data in and out of dictionaries.
|
|
||||||
# this code from http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys
|
|
||||||
def getFromDict(dataDict, mapList):
|
|
||||||
return reduce(lambda d, k: d[k], mapList, dataDict)
|
|
||||||
def setInDict(dataDict, mapList, value):
|
|
||||||
getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
|
|
||||||
|
|
||||||
|
|
||||||
# Do the common sub expression elimination
|
|
||||||
subexpressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))
|
|
||||||
cacheable_list = []
|
|
||||||
|
|
||||||
# Sort out any expression that's dependent on something that scales with data size (these are listed in cacheable).
|
# Sort out any expression that's dependent on something that scales with data size (these are listed in cacheable).
|
||||||
self.expressions['parameters_change'] = []
|
cacheable_list = []
|
||||||
self.expressions['update_cache'] = []
|
params_change_list = []
|
||||||
cache_expressions_list = []
|
# common_sube_expressions contains a list of paired tuples with the new variable and what it equals
|
||||||
sub_expression_list = []
|
for var, expr in common_sub_expressions:
|
||||||
for expr in subexpressions:
|
arg_list = [e for e in expr.atoms() if e.is_Symbol]
|
||||||
arg_list = [e for e in expr[1].atoms() if e.is_Symbol]
|
# List any cacheable dependencies of the sub-expression
|
||||||
cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
|
cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
|
||||||
if cacheable_symbols:
|
if cacheable_symbols:
|
||||||
self.expressions['update_cache'].append((expr[0].name, self._expr2code(arg_list, expr[1])))
|
|
||||||
# list which ensures dependencies are cacheable.
|
# list which ensures dependencies are cacheable.
|
||||||
cacheable_list.append(expr[0])
|
cacheable_list.append(var)
|
||||||
cache_expressions_list.append(expr[0].name)
|
|
||||||
else:
|
else:
|
||||||
self.expressions['parameters_change'].append((expr[0].name, self._expr2code(arg_list, expr[1])))
|
params_change_list.append(var)
|
||||||
sub_expression_list.append(expr[0].name)
|
|
||||||
|
replace_dict = {}
|
||||||
|
for i, expr in enumerate(cacheable_list):
|
||||||
|
sym_var = sym.var(cache_prefix + str(i))
|
||||||
|
self.variables[cache_prefix].append(sym_var)
|
||||||
|
replace_dict[expr.name] = sym_var
|
||||||
|
|
||||||
|
for i, expr in enumerate(params_change_list):
|
||||||
|
sym_var = sym.var(sub_prefix + str(i))
|
||||||
|
self.variables[sub_prefix].append(sym_var)
|
||||||
|
replace_dict[expr.name] = sym_var
|
||||||
|
|
||||||
|
for replace, void in common_sub_expressions:
|
||||||
|
for expr, keys in zip(expression_substituted_list, self.expression_keys):
|
||||||
|
setInDict(self.expressions, keys, expr.subs(replace, replace_dict[replace.name]))
|
||||||
|
for void, expr in common_sub_expressions:
|
||||||
|
expr = expr.subs(replace, replace_dict[replace.name])
|
||||||
|
|
||||||
# Replace original code with code including subexpressions.
|
# Replace original code with code including subexpressions.
|
||||||
for expr, keys in zip(expression_substituted_list, self.expression_keys):
|
for keys in self.expression_keys:
|
||||||
arg_list = [e for e in expr.atoms() if e.is_Symbol]
|
for replace, void in common_sub_expressions:
|
||||||
setInDict(self.code, keys, self._expr2code(arg_list, expr))
|
setInDict(self.expressions, keys, getFromDict(self.expressions, keys).subs(replace, replace_dict[replace.name]))
|
||||||
setInDict(self.expressions, keys, expr)
|
|
||||||
|
|
||||||
# Create variable names for cache and sub expression portions
|
self.expressions['parameters_changed'] = {}
|
||||||
cache_dict = {}
|
self.expressions['update_cache'] = {}
|
||||||
self.variables[cache_prefix] = []
|
for var, expr in common_sub_expressions:
|
||||||
for i, sub in enumerate(cache_expressions_list):
|
for replace, void in common_sub_expressions:
|
||||||
name = cache_prefix + str(i)
|
expr = expr.subs(replace, replace_dict[replace.name])
|
||||||
cache_dict[sub] = name
|
if var in cacheable_list:
|
||||||
self.variables[cache_prefix].append(sym.var(name))
|
self.expressions['update_cache'][replace_dict[var.name].name] = expr
|
||||||
|
else:
|
||||||
|
self.expressions['parameters_changed'][replace_dict[var.name].name] = expr
|
||||||
|
|
||||||
sub_dict = {}
|
|
||||||
self.variables[sub_prefix] = []
|
|
||||||
for i, sub in enumerate(sub_expression_list):
|
|
||||||
name = sub_prefix + str(i)
|
|
||||||
sub_dict[sub] = name
|
|
||||||
self.variables[sub_prefix].append(sym.var(name))
|
|
||||||
|
|
||||||
# Replace sub expressions in main code with either cacheN or subN.
|
|
||||||
for key, val in cache_dict.iteritems():
|
|
||||||
for keys in self.expression_keys:
|
|
||||||
setInDict(self.code, keys,
|
|
||||||
getFromDict(self.code,keys).replace(key, val))
|
|
||||||
|
|
||||||
for key, val in sub_dict.iteritems():
|
# for var, expr in common_sub_expressions:
|
||||||
for keys in self.expression_keys:
|
# if var in list(cacheable_list):
|
||||||
setInDict(self.code, keys,
|
# self.expressions['update_cache'].append({var.name: expr.subarg_list = [e for e in sub_expr_pair[1].atoms() if e.is_Symbol]
|
||||||
getFromDict(self.code,keys).replace(key, val))
|
# cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
|
||||||
|
# if cacheable_symbols:
|
||||||
|
# self.expressions['update_cache'].append((sub_expr_pair[0].name, self._expr2code(arg_list, sub_expr_pair[1])))
|
||||||
|
# # list which ensures dependencies are cacheable.
|
||||||
|
# cacheable_list.append(sub_expr_pair[0])
|
||||||
|
# cache_expressions_list.append(sub_expr_pair[0].name)
|
||||||
|
# else:
|
||||||
|
# self.expressions['parameters_change'].append((sub_expr_pair[0].name, self._expr2code(arg_list, sub_expr_pair[1])))
|
||||||
|
# sub_expression_list.append(sub_expr_pair[0].name)
|
||||||
|
|
||||||
# Set up precompute code as either cacheN or subN.
|
def _gen_code(self):
|
||||||
self.code['update_cache'] = {}
|
"""Generate code for the list of expressions provided using the common sub-expression eliminator to separate out portions that are computed multiple times."""
|
||||||
for key, val in self.expressions['update_cache']:
|
# This is the dictionary that stores all the generated code.
|
||||||
expr = val
|
|
||||||
for key2, val2 in cache_dict.iteritems():
|
|
||||||
expr = expr.replace(key2, val2)
|
|
||||||
for key2, val2 in sub_dict.iteritems():
|
|
||||||
expr = expr.replace(key2, val2)
|
|
||||||
self.code['update_cache'][cache_dict[key]] = expr
|
|
||||||
|
|
||||||
self.expressions['update_cache'] = dict(self.expressions['update_cache'])
|
self.code = {}
|
||||||
self.code['parameters_change'] = {}
|
def match_key(expr):
|
||||||
for key, val in self.expressions['parameters_change']:
|
if type(expr) is dict:
|
||||||
expr = val
|
code = {}
|
||||||
for key2, val2 in cache_dict.iteritems():
|
for key in expr.keys():
|
||||||
expr = expr.replace(key2, val2)
|
code[key] = match_key(expr[key])
|
||||||
for key2, val2 in sub_dict.iteritems():
|
else:
|
||||||
expr = expr.replace(key2, val2)
|
arg_list = [e for e in expr.atoms() if e.is_Symbol]
|
||||||
self.code['parameters_change'][sub_dict[key]] = expr
|
code = self._expr2code(arg_list, expr)
|
||||||
self.expressions['parameters_change'] = dict(self.expressions['parameters_change'])
|
return code
|
||||||
|
|
||||||
|
self.code = match_key(self.expressions)
|
||||||
|
|
||||||
|
|
||||||
|
# for keys in self.expression_keys:
|
||||||
|
# expr = getFromDict(self.expressions, keys)
|
||||||
|
# arg_list = [e for e in expr.atoms() if e.is_Symbol]
|
||||||
|
# setInDict(self.code, keys, self._expr2code(arg_list, expr))
|
||||||
|
|
||||||
|
# # Set up precompute code as either cacheN or subN.
|
||||||
|
# self.code['update_cache'] = {}
|
||||||
|
# for key, val in self.expressions['update_cache']:
|
||||||
|
# expr = val
|
||||||
|
# for key2, val2 in cache_dict.iteritems():
|
||||||
|
# expr = expr.replace(key2, val2.name)
|
||||||
|
# for key2, val2 in sub_dict.iteritems():
|
||||||
|
# expr = expr.replace(key2, val2.name)
|
||||||
|
# self.code['update_cache'][cache_dict[key]] = expr
|
||||||
|
|
||||||
|
# self.expressions['update_cache'] = dict(self.expressions['update_cache'])
|
||||||
|
# self.code['parameters_change'] = {}
|
||||||
|
# for key, val in self.expressions['parameters_change']:
|
||||||
|
# expr = val
|
||||||
|
# for key2, val2 in cache_dict.iteritems():
|
||||||
|
# expr = expr.replace(key2, val2.name)
|
||||||
|
# for key2, val2 in sub_dict.iteritems():
|
||||||
|
# expr = expr.replace(key2, val2.name)
|
||||||
|
# self.code['parameters_change'][sub_dict[key]] = expr
|
||||||
|
# self.expressions['parameters_change'] = dict(self.expressions['parameters_change'])
|
||||||
|
|
||||||
def _expr2code(self, arg_list, expr):
|
def _expr2code(self, arg_list, expr):
|
||||||
"""Convert the given symbolic expression into code."""
|
"""Convert the given symbolic expression into code."""
|
||||||
|
|
|
||||||
|
|
@ -14,9 +14,9 @@ except ImportError:
|
||||||
sympy_available=False
|
sympy_available=False
|
||||||
if sympy_available:
|
if sympy_available:
|
||||||
# These are likelihoods that rely on symbolic.
|
# These are likelihoods that rely on symbolic.
|
||||||
from symbolic import Symbolic
|
from symbolic2 import Symbolic
|
||||||
#from sstudent_t import SstudentT
|
#from sstudent_t import SstudentT
|
||||||
from negative_binomial import Negative_binomial
|
#from negative_binomial import Negative_binomial
|
||||||
#from skew_normal import Skew_normal
|
from skew_normal2 import Skew_normal
|
||||||
#from skew_exponential import Skew_exponential
|
#from skew_exponential import Skew_exponential
|
||||||
#from null_category import Null_category
|
#from null_category import Null_category
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ class Symbolic(Mapping, Symbolic_core):
|
||||||
self.parameters_changed()
|
self.parameters_changed()
|
||||||
|
|
||||||
def _initialize_cache(self):
|
def _initialize_cache(self):
|
||||||
self.x_0 = np.random.normal(size=(3, self.input_dim))
|
self._set_attribute('x_0', np.random.normal(size=(3, self.input_dim)))
|
||||||
|
|
||||||
|
|
||||||
def parameters_changed(self):
|
def parameters_changed(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue