mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 12:32:40 +02:00
First draft of base symbolic object, compiling with symbolic mapping.
This commit is contained in:
parent
c2d3c82944
commit
583f3bef0a
6 changed files with 410 additions and 99 deletions
349
GPy/core/symbolic.py
Normal file
349
GPy/core/symbolic.py
Normal file
|
|
@ -0,0 +1,349 @@
|
|||
# Copyright (c) 2014, GPy authors (see AUTHORS.txt).
|
||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import sys
|
||||
from ..core.parameterization import Parameterized
|
||||
import numpy as np
|
||||
import sympy as sym
|
||||
from ..core.parameterization import Param
|
||||
from sympy.utilities.lambdify import lambdastr, _imp_namespace, _get_namespace
|
||||
from sympy.utilities.iterables import numbered_symbols
|
||||
from sympy import exp
|
||||
from scipy.special import gammaln, gamma, erf, erfc, erfcx, polygamma
|
||||
from GPy.util.functions import normcdf, normcdfln, logistic, logisticln
|
||||
|
||||
class Symbolic_core():
|
||||
"""
|
||||
Base model symbolic class.
|
||||
"""
|
||||
|
||||
def __init__(self, expression, cacheable, derivatives=None, param=None, func_modules=[]):
|
||||
# Base class init, do some basic derivatives etc.
|
||||
|
||||
# Func_modules sets up the right mapping for functions.
|
||||
self.func_modules = func_modules
|
||||
self.func_modules += [{'gamma':gamma,
|
||||
'gammaln':gammaln,
|
||||
'erf':erf, 'erfc':erfc,
|
||||
'erfcx':erfcx,
|
||||
'polygamma':polygamma,
|
||||
'normcdf':normcdf,
|
||||
'normcdfln':normcdfln,
|
||||
'logistic':logistic,
|
||||
'logisticln':logisticln},
|
||||
'numpy']
|
||||
|
||||
self.expressions = {}
|
||||
self.expressions['function'] = expression
|
||||
self.cacheable = cacheable
|
||||
|
||||
# pull the parameters and inputs out of the symbolic pdf
|
||||
self.variables = {}
|
||||
vars = [e for e in expression.atoms() if e.is_Symbol]
|
||||
|
||||
# inputs are assumed to be those things that are
|
||||
# cacheable. I.e. those things that aren't stored within the
|
||||
# object except as cached. For covariance functions this is X
|
||||
# and Z, for likelihoods F and for mapping functions X.
|
||||
self.cacheable_vars = [] # list of everything that's cacheable
|
||||
for var in cacheable:
|
||||
self.variables[var] = [e for e in vars if e.name.split('_')[0]==var.lower()]
|
||||
self.cacheable_vars += self.variables[var]
|
||||
for var in cacheable:
|
||||
if not self.variables[var]:
|
||||
raise ValueError('Variable ' + var + ' was specified as cacheable but is not in expression. Expected to find symbols of the form ' + var.lower() + '_0 to represent ' + var)
|
||||
|
||||
# things that aren't cacheable are assumed to be parameters.
|
||||
self.variables['theta'] = sorted([e for e in vars if not e in self.cacheable_vars],key=lambda e:e.name)
|
||||
|
||||
# these are arguments for computing derivatives.
|
||||
derivative_arguments = []
|
||||
if derivatives is not None:
|
||||
for derivative in derivatives:
|
||||
derivative_arguments += self.variables[derivative]
|
||||
|
||||
# Do symbolic work to compute derivatives.
|
||||
self.expressions['derivative'] = {theta.name : sym.diff(self.expressions['function'],theta) for theta in derivative_arguments}
|
||||
# Add parameters to the model.
|
||||
for theta in self.variables['theta']:
|
||||
val = 1.0
|
||||
# TODO: need to decide how to handle user passing values for the se parameter vectors.
|
||||
if param is not None:
|
||||
if param.has_key(theta.name):
|
||||
val = param[theta.name]
|
||||
# Add parameter.
|
||||
setattr(self, theta.name, Param(theta.name, val, None))
|
||||
self.add_parameters(getattr(self, theta.name))
|
||||
|
||||
self.namespace = [globals(), self.__dict__]
|
||||
self._gen_code()
|
||||
|
||||
def eval_parameters_changed(self):
|
||||
# TODO: place checks for inf/nan in here
|
||||
# do all the precomputation codes.
|
||||
for variable, code in sorted(self.code['params_change'].iteritems()):
|
||||
setattr(self, variable, eval(code, *self.namespace))
|
||||
self.eval_update_cache()
|
||||
|
||||
def eval_update_cache(self, X=None):
|
||||
# TODO: place checks for inf/nan in here
|
||||
if X is not None:
|
||||
for i, theta in enumerate(self.variables['X']):
|
||||
setattr(self, theta.name, X[:, i][:, None])
|
||||
|
||||
for variable, code in sorted(self.code['update_cache'].iteritems()):
|
||||
setattr(self, variable, eval(code, *self.namespace))
|
||||
|
||||
def eval_update_gradients(self, partial, X):
|
||||
# TODO: place checks for inf/nan in here
|
||||
for theta in self.variables['theta']:
|
||||
code = self.code['derivative'][theta.name]
|
||||
setattr(getattr(self, theta.name),
|
||||
'gradient',
|
||||
(partial*eval(code, *self.namespace)).sum())
|
||||
|
||||
def eval_gradients_X(self, partial, X):
|
||||
gradients_X = np.zeros_like(X)
|
||||
self.eval_update_cache(X)
|
||||
for i, theta in enumerate(self.variables['X']):
|
||||
code = self.code['derivative'][theta.name]
|
||||
gradients_X[:, i:i+1] = partial*eval(code, *self.namespace)
|
||||
return gradients_X
|
||||
|
||||
def eval_f(self, X):
|
||||
self.eval_update_cache(X)
|
||||
return eval(self.code['function'], *self.namespace)
|
||||
|
||||
def code_parameters_changed(self):
|
||||
# do all the precomputation codes.
|
||||
lcode = ''
|
||||
for variable, code in sorted(self.code['params_change'].iteritems()):
|
||||
lcode += variable + ' = ' + self._print_code(code) + '\n'
|
||||
return lcode
|
||||
|
||||
def code_update_cache(self):
|
||||
lcode = 'if X is not None:\n'
|
||||
for i, theta in enumerate(self.variables['X']):
|
||||
lcode+= "\t" + self._print_code(theta.name) + ' = X[:, ' + str(i) + "][:, None]\n"
|
||||
|
||||
for variable, code in sorted(self.code['update_cache'].iteritems()):
|
||||
lcode+= self._print_code(variable) + ' = ' + self._print_code(code) + "\n"
|
||||
|
||||
return lcode
|
||||
|
||||
def code_update_gradients(self):
|
||||
lcode = ''
|
||||
for theta in self.variables['theta']:
|
||||
code = self.code['derivative'][theta.name]
|
||||
lcode += self._print_code(theta.name) + '.gradient = (partial*(' + self._print_code(code) + ')).sum()\n'
|
||||
return lcode
|
||||
|
||||
def code_gradients_X(self):
|
||||
lcode = 'gradients_X = np.zeros_like(X)\n'
|
||||
lcode += 'self.update_cache(X)\n'
|
||||
for i, theta in enumerate(self.variables['X']):
|
||||
code = self.code['derivative'][theta.name]
|
||||
lcode += 'gradients_X[:, ' + str(i) + ':' + str(i) + '+1] = partial*' + self._print_code(code) + '\n'
|
||||
lcode += 'return gradients_X\n'
|
||||
return lcode
|
||||
|
||||
def code_f(self):
|
||||
lcode = 'self.update_cache(X)\n'
|
||||
lcode += 'return ' + self._print_code(self.code['function'])
|
||||
return lcode
|
||||
|
||||
def stabilise(self):
|
||||
"""Stabilize the code in the model."""
|
||||
# this code is applied to all expressions in the model in an attempt to sabilize them.
|
||||
pass
|
||||
|
||||
def _gen_namespace(self, modules=None, use_imps=True):
|
||||
"""Gets the relevant namespaces for the given expressions."""
|
||||
from sympy.core.symbol import Symbol
|
||||
|
||||
# If the user hasn't specified any modules, use what is available.
|
||||
module_provided = True
|
||||
if modules is None:
|
||||
module_provided = False
|
||||
# Use either numpy (if available) or python.math where possible.
|
||||
# XXX: This leads to different behaviour on different systems and
|
||||
# might be the reason for irreproducible errors.
|
||||
modules = ["math", "mpmath", "sympy"]
|
||||
|
||||
try:
|
||||
_import("numpy")
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
modules.insert(1, "numpy")
|
||||
|
||||
|
||||
# Get the needed namespaces.
|
||||
namespaces = []
|
||||
# First find any function implementations
|
||||
if use_imps:
|
||||
for expr in self._expression_list:
|
||||
namespaces.append(_imp_namespace(expr))
|
||||
# Check for dict before iterating
|
||||
if isinstance(modules, (dict, str)) or not hasattr(modules, '__iter__'):
|
||||
namespaces.append(modules)
|
||||
else:
|
||||
namespaces += list(modules)
|
||||
# fill namespace with first having highest priority
|
||||
namespace = {}
|
||||
for m in namespaces[::-1]:
|
||||
buf = _get_namespace(m)
|
||||
namespace.update(buf)
|
||||
for expr in self._expression_list:
|
||||
if hasattr(expr, "atoms"):
|
||||
#Try if you can extract symbols from the expression.
|
||||
#Move on if expr.atoms in not implemented.
|
||||
syms = expr.atoms(Symbol)
|
||||
for term in syms:
|
||||
namespace.update({str(term): term})
|
||||
|
||||
|
||||
return namespace
|
||||
def update_expression_list(self):
|
||||
"""Extract a list of expressions from the dictionary of expressions."""
|
||||
self.expression_list = [] # code arrives in dictionary, but is passed in this list
|
||||
self.expression_keys = [] # Keep track of the dictionary keys.
|
||||
self.expression_order = [] # This may be unecessary. It's to give ordering for cse
|
||||
for key in self.expressions.keys():
|
||||
if key == 'function':
|
||||
self.expression_list.append(self.expressions[key])
|
||||
self.expression_keys.append([key])
|
||||
self.expression_order.append(1)
|
||||
self.code[key] = ''
|
||||
elif key[-10:] == 'derivative':
|
||||
self.code[key] = {}
|
||||
for dkey in self.expressions[key].keys():
|
||||
self.expression_list.append(self.expressions[key][dkey])
|
||||
self.expression_keys.append([key, dkey])
|
||||
if key[:-10] == 'first' or key[:-10] == '':
|
||||
self.expression_order.append(3) #sym.count_ops(self.expressions[key][dkey]))
|
||||
elif key[:-10] == 'second':
|
||||
self.expression_order.append(4) #sym.count_ops(self.expressions[key][dkey]))
|
||||
elif key[:-10] == 'third':
|
||||
self.expression_order.append(5) #sym.count_ops(self.expressions[key][dkey]))
|
||||
self.code[key][dkey] = ''
|
||||
else:
|
||||
self.expression_list.append(self.expressions[key])
|
||||
self.expression_keys.append([key])
|
||||
self.expression_order.append(2)
|
||||
self.code[key] = ''
|
||||
|
||||
# This step may be unecessary.
|
||||
# Not 100% sure if the sub expression elimination is order sensitive. This step orders the list with the 'function' code first and derivatives after.
|
||||
self.expression_order, self.expression_list, self.expression_keys = zip(*sorted(zip(self.expression_order, self.expression_list, self.expression_keys)))
|
||||
|
||||
|
||||
def _gen_code(self, cache_prefix = 'cache', sub_prefix = 'sub', prefix='XoXoXoX'):
|
||||
"""Generate code for the list of expressions provided using the common sub-expression eliminator to separate out portions that are computed multiple times."""
|
||||
# This is the dictionary that stores all the generated code.
|
||||
self.code = {}
|
||||
|
||||
# Convert the expressions to a list for common sub expression elimination
|
||||
# We should find the following type of expressions: 'function', 'derivative', 'second_derivative', 'third_derivative'.
|
||||
self.update_expression_list()
|
||||
|
||||
# Apply any global stabilisation operations to expressions.
|
||||
self.stabilise()
|
||||
|
||||
# Helper functions to get data in and out of dictionaries.
|
||||
# this code from http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys
|
||||
def getFromDict(dataDict, mapList):
|
||||
return reduce(lambda d, k: d[k], mapList, dataDict)
|
||||
def setInDict(dataDict, mapList, value):
|
||||
getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
|
||||
|
||||
|
||||
# Do the common sub expression elimination
|
||||
subexpressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))
|
||||
cacheable_list = []
|
||||
|
||||
# Sort out any expression that's dependent on something that scales with data size (these are listed in cacheable).
|
||||
self.expressions['params_change'] = []
|
||||
self.expressions['update_cache'] = []
|
||||
cache_expressions_list = []
|
||||
sub_expression_list = []
|
||||
for expr in subexpressions:
|
||||
arg_list = [e for e in expr[1].atoms() if e.is_Symbol]
|
||||
cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
|
||||
if cacheable_symbols:
|
||||
self.expressions['update_cache'].append((expr[0].name, self._expr2code(arg_list, expr[1])))
|
||||
# list which ensures dependencies are cacheable.
|
||||
cacheable_list.append(expr[0])
|
||||
cache_expressions_list.append(expr[0].name)
|
||||
else:
|
||||
self.expressions['params_change'].append((expr[0].name, self._expr2code(arg_list, expr[1])))
|
||||
sub_expression_list.append(expr[0].name)
|
||||
|
||||
# Replace original code with code including subexpressions.
|
||||
for expr, keys in zip(expression_substituted_list, self.expression_keys):
|
||||
arg_list = [e for e in expr.atoms() if e.is_Symbol]
|
||||
setInDict(self.code, keys, self._expr2code(arg_list, expr))
|
||||
setInDict(self.expressions, keys, expr)
|
||||
|
||||
# Create variable names for cache and sub expression portions
|
||||
cache_dict = {}
|
||||
self.variables[cache_prefix] = []
|
||||
for i, sub in enumerate(cache_expressions_list):
|
||||
name = cache_prefix + str(i)
|
||||
cache_dict[sub] = name
|
||||
self.variables[cache_prefix].append(sym.var(name))
|
||||
|
||||
sub_dict = {}
|
||||
self.variables[sub_prefix] = []
|
||||
for i, sub in enumerate(sub_expression_list):
|
||||
name = sub_prefix + str(i)
|
||||
sub_dict[sub] = name
|
||||
self.variables[sub_prefix].append(sym.var(name))
|
||||
|
||||
# Replace sub expressions in main code with either cacheN or subN.
|
||||
for key, val in cache_dict.iteritems():
|
||||
for keys in self.expression_keys:
|
||||
setInDict(self.code, keys,
|
||||
getFromDict(self.code,keys).replace(key, val))
|
||||
|
||||
for key, val in sub_dict.iteritems():
|
||||
for keys in self.expression_keys:
|
||||
setInDict(self.code, keys,
|
||||
getFromDict(self.code,keys).replace(key, val))
|
||||
|
||||
# Set up precompute code as either cacheN or subN.
|
||||
self.code['update_cache'] = {}
|
||||
for key, val in self.expressions['update_cache']:
|
||||
expr = val
|
||||
for key2, val2 in cache_dict.iteritems():
|
||||
expr = expr.replace(key2, val2)
|
||||
for key2, val2 in sub_dict.iteritems():
|
||||
expr = expr.replace(key2, val2)
|
||||
self.code['update_cache'][cache_dict[key]] = expr
|
||||
|
||||
self.expressions['update_cache'] = dict(self.expressions['update_cache'])
|
||||
self.code['params_change'] = {}
|
||||
for key, val in self.expressions['params_change']:
|
||||
expr = val
|
||||
for key2, val2 in cache_dict.iteritems():
|
||||
expr = expr.replace(key2, val2)
|
||||
for key2, val2 in sub_dict.iteritems():
|
||||
expr = expr.replace(key2, val2)
|
||||
self.code['params_change'][sub_dict[key]] = expr
|
||||
self.expressions['params_change'] = dict(self.expressions['params_change'])
|
||||
|
||||
def _expr2code(self, arg_list, expr):
|
||||
"""Convert the given symbolic expression into code."""
|
||||
code = lambdastr(arg_list, expr)
|
||||
function_code = code.split(':')[1].strip()
|
||||
#for arg in arg_list:
|
||||
# function_code = function_code.replace(arg.name, 'self.'+arg.name)
|
||||
|
||||
return function_code
|
||||
|
||||
def _print_code(self, code):
|
||||
"""Prepare code for string writing."""
|
||||
for key in self.variables.keys():
|
||||
for arg in self.variables[key]:
|
||||
code = code.replace(arg.name, 'self.'+arg.name)
|
||||
return code
|
||||
|
|
@ -169,7 +169,7 @@ class Symbolic(Kern):
|
|||
def gradients_X(self, dL_dK, X, X2=None):
|
||||
#if self._X is None or X.base is not self._X.base or X2 is not None:
|
||||
self._K_computations(X, X2)
|
||||
gradients_X = np.zeros((X.shape[0], X.shape[1]))
|
||||
gradients_X = np.zeros_like(X)
|
||||
for i, x in enumerate(self._sym_x):
|
||||
gf = self._K_derivatives_code[x.name]
|
||||
gradients_X[:, i] = (gf(**self._arguments)*dL_dK).sum(1)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ try:
|
|||
import sympy as sym
|
||||
sympy_available=True
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from GPy.util.symbolic import stabilise
|
||||
except ImportError:
|
||||
sympy_available=False
|
||||
|
||||
|
|
|
|||
56
GPy/mappings/symbolic.py
Normal file
56
GPy/mappings/symbolic.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) 2014 GPy Authors
|
||||
# Licensed under the BSD 3-clause license (see LICENSE.txt)
|
||||
|
||||
import sympy as sym
|
||||
from ..core.mapping import Mapping, Bijective_mapping
|
||||
from ..core.symbolic import Symbolic_core
|
||||
import numpy as np
|
||||
|
||||
class Symbolic(Mapping, Symbolic_core):
|
||||
"""
|
||||
Symbolic mapping
|
||||
|
||||
Mapping where the form of the mapping is provided by a sympy expression.
|
||||
|
||||
"""
|
||||
def __init__(self, input_dim, output_dim, f=None, name='symbolic', param=None, func_modules=[]):
|
||||
|
||||
|
||||
if f is None:
|
||||
raise ValueError, "You must provide an argument for the function."
|
||||
|
||||
Mapping.__init__(self, input_dim, output_dim, name=name)
|
||||
Symbolic_core.__init__(self, f, ['X'], derivatives = ['X', 'theta'], param=param, func_modules=func_modules)
|
||||
|
||||
self._initialize_cache()
|
||||
self.parameters_changed()
|
||||
|
||||
def _initialize_cache(self):
|
||||
self.x_0 = np.random.normal(size=(3, self.input_dim))
|
||||
|
||||
|
||||
def parameters_changed(self):
|
||||
self.eval_parameters_changed()
|
||||
|
||||
def update_cache(self, X):
|
||||
self.eval_update_cache(X)
|
||||
|
||||
def update_gradients(self, partial, X):
|
||||
self.eval_update_gradients(partial, X)
|
||||
|
||||
def gradients_X(self, partial, X):
|
||||
return self.eval_gradients_X(partial, X)
|
||||
|
||||
def f(self, X):
|
||||
"""
|
||||
"""
|
||||
return self.eval_f(X)
|
||||
|
||||
|
||||
def df_dX(self, X):
|
||||
"""
|
||||
"""
|
||||
pass
|
||||
|
||||
def df_dtheta(self, X):
|
||||
pass
|
||||
|
|
@ -304,8 +304,8 @@ def football_data(season='1314', data_set='football_data'):
|
|||
data_set_season = data_set + '_' + season
|
||||
data_resources[data_set_season] = copy.deepcopy(data_resources[data_set])
|
||||
data_resources[data_set_season]['urls'][0]+=season + '/'
|
||||
start_year = int(year[0:2])
|
||||
end_year = int(year[2:4])
|
||||
start_year = int(season[0:2])
|
||||
end_year = int(season[2:4])
|
||||
files = ['E0.csv', 'E1.csv', 'E2.csv', 'E3.csv']
|
||||
if start_year>4 and start_year < 93:
|
||||
files += ['EC.csv']
|
||||
|
|
|
|||
|
|
@ -1,100 +1,7 @@
|
|||
import sys
|
||||
import numpy as np
|
||||
import sympy as sym
|
||||
from sympy import Function, S, oo, I, cos, sin, asin, log, erf, pi, exp, sqrt, sign, gamma, polygamma
|
||||
from sympy.utilities.lambdify import lambdastr
|
||||
from sympy.utilities.iterables import numbered_symbols
|
||||
def stabilise(e):
|
||||
"""Attempt to find the most numerically stable form of an expression"""
|
||||
return e #sym.expand(e)
|
||||
|
||||
|
||||
def gen_code(expressions, cache_prefix = 'cache', sub_prefix = 'sub', prefix='XoXoXoX', cacheable=[]):
|
||||
"""Generate code for the list of expressions provided using the common sub-expression eliminator to separate out portions that are computed multiple times."""
|
||||
# First convert the expressions to a list.
|
||||
# We should find the following type of expressions: 'function', 'derivative', 'second_derivative', 'third_derivative'.
|
||||
|
||||
# Helper functions to get data in and out of dictionaries.
|
||||
# this code from http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys
|
||||
def getFromDict(dataDict, mapList):
|
||||
return reduce(lambda d, k: d[k], mapList, dataDict)
|
||||
def setInDict(dataDict, mapList, value):
|
||||
getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
|
||||
|
||||
# This is the return dictionary that stores all the generated code.
|
||||
code = {}
|
||||
expression_list = []
|
||||
key_list = []
|
||||
order_list = []
|
||||
code['main'] = {}
|
||||
for key in expressions.keys():
|
||||
if key == 'function':
|
||||
expression_list.append(expressions[key])
|
||||
key_list.append([key])
|
||||
order_list.append(1)
|
||||
code['main'][key] = ''
|
||||
elif key[-10:] == 'derivative':
|
||||
code['main'][key] = {}
|
||||
for dkey in expressions[key].keys():
|
||||
expression_list.append(expressions[key][dkey])
|
||||
key_list.append([key, dkey])
|
||||
if key[:-10] == 'first' or key[:-10] == '':
|
||||
order_list.append(3) #sym.count_ops(expressions[key][dkey]))
|
||||
elif key[:-10] == 'second':
|
||||
order_list.append(4) #sym.count_ops(expressions[key][dkey]))
|
||||
elif key[:-10] == 'third':
|
||||
order_list.append(5) #sym.count_ops(expressions[key][dkey]))
|
||||
code['main'][key][dkey] = ''
|
||||
else:
|
||||
expression_list.append(expressions[key])
|
||||
key_list.append([key])
|
||||
order_list.append(2)
|
||||
code['main'][key] = ''
|
||||
|
||||
# This step may be unecessary.
|
||||
# Not 100% sure if the sub expression elimination is order sensitive. This step orders the list with the 'function' code first and derivatives after.
|
||||
order_list, expression_list, key_list = zip(*sorted(zip(order_list, expression_list, key_list)))
|
||||
|
||||
print expression_list
|
||||
|
||||
subexpressions, expression_substituted_list = sym.cse(expression_list, numbered_symbols(prefix=prefix))
|
||||
cacheable_list = []
|
||||
# Create strings that lambda strings from the expressions.
|
||||
code['params_change'] = []
|
||||
code['cache'] = []
|
||||
for expr in subexpressions:
|
||||
arg_list = [e for e in expr[1].atoms() if e.is_Symbol]
|
||||
cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in cacheable]
|
||||
if cacheable_symbols:
|
||||
code['cacheable'].append((expr[0],expr2code(arg_list, expr[1])))
|
||||
# list which ensures dependencies are cacheable.
|
||||
cacheable_list.append(expr[0])
|
||||
code['cacheexpressions'].append(expr[0])
|
||||
else:
|
||||
code['params_change'].append((expr[0],expr2code(arg_list, expr[1])))
|
||||
code['subexpressions'].append(expr[0])
|
||||
|
||||
for expr, keys in zip(expression_substituted_list, key_list):
|
||||
arg_list = [e for e in expr.atoms() if e.is_Symbol]
|
||||
setInDict(code['main'], keys, expr2code(arg_list, expr))
|
||||
setInDict(expressions, keys, expr)
|
||||
|
||||
sub_dict = {}
|
||||
for i, sub in enumerate(code['cacheexpressions']):
|
||||
sub_dict[sub.name] = cache_prefix + str(i)
|
||||
for i, sub in enumerate(code['subexpressions']):
|
||||
sub_dict[sub.name] = sub_prefix + str(i)
|
||||
|
||||
|
||||
|
||||
return code
|
||||
|
||||
def expr2code(arg_list, expr):
|
||||
"""Convert the given symbolic expression into code."""
|
||||
code = lambdastr(arg_list, expr)
|
||||
function_code = code.split(':')[1]
|
||||
for arg in arg_list:
|
||||
function_code = function_code.replace(arg.name, 'self.'+arg.name)
|
||||
|
||||
return function_code
|
||||
|
||||
class logistic(Function):
|
||||
"""The logistic function as a symbolic function."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue