More changes to symbolic

This commit is contained in:
Neil Lawrence 2014-04-11 16:35:41 +01:00
parent 41ef7f4c72
commit 9277695a6d
3 changed files with 87 additions and 30 deletions

View file

@ -34,7 +34,7 @@ class Mapping(Parameterized):
raise NotImplementedError
def df_dtheta(self, dL_df, X):
"""The gradient of the outputs of the multi-layer perceptron with respect to each of the parameters.
"""The gradient of the outputs of the mapping with respect to each of the parameters.
:param dL_df: gradient of the objective with respect to the function.
:type dL_df: ndarray (num_data x output_dim)
@ -50,7 +50,7 @@ class Mapping(Parameterized):
"""
Plots the mapping associated with the model.
- In one dimension, the function is plotted.
- In two dimsensions, a contour-plot shows the function
- In two dimensions, a contour-plot shows the function
- In higher dimensions, we've not implemented this yet !TODO!
Can plot only part of the data and part of the posterior functions
@ -65,6 +65,14 @@ class Mapping(Parameterized):
else:
raise NameError, "matplotlib package has not been imported."
class Bijective_mapping(Mapping):
"""This is a mapping that is bijective, i.e. you can go from X to f and also back from f to X. The inverse mapping is called g()."""
def __init__(self, input_dim, output_dim, name='bijective_mapping'):
super(Bijective_apping, self).__init__(name=name)
def g(self, f):
"""Inverse mapping from output domain of the function to the inputs."""
raise NotImplementedError
from model import Model

View file

@ -4,7 +4,7 @@
import itertools
import numpy
from parameter_core import OptimizationHandlable, adjust_name_for_printing
from array_core import ObsAr
from observable_array import ObsAr
###### printing
__constraints_name__ = "Constraint"

View file

@ -7,45 +7,94 @@ def stabilise(e):
return e #sym.expand(e)
def gen_code(expressions, prefix='sub'):
"""Generate code for the list of expressions provided using the common sub-expression eliminator."""
def gen_code(expressions, cache_prefix = 'cache', sub_prefix = 'sub', prefix='XoXoXoX', cacheable=[]):
"""Generate code for the list of expressions provided using the common sub-expression eliminator to separate out portions that are computed multiple times."""
# First convert the expressions to a list.
# We should find the following type of expressions: 'function', 'derivative', 'second_derivative', 'third_derivative'.
# Helper functions to get data in and out of dictionaries.
# this code from http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys
def getFromDict(dataDict, mapList):
return reduce(lambda d, k: d[k], mapList, dataDict)
def setInDict(dataDict, mapList, value):
getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
# This is the return dictionary that stores all the generated code.
code = {}
expression_list = []
expression_keys = []
function_code = {}
for key in expressions.key():
key_list = []
order_list = []
code['main'] = {}
for key in expressions.keys():
if key == 'function':
expression_list.append(expressions[key])
expression_keys.append(key)
function_code[key] = ''
elif key[-9:] == 'derivative':
function_code[key] = {}
for dkey in expressions[key]:
expression_list.append(expressions[key])
key_list.append([key])
order_list.append(1)
code['main'][key] = ''
elif key[-10:] == 'derivative':
code['main'][key] = {}
for dkey in expressions[key].keys():
expression_list.append(expressions[key][dkey])
expression_keys.append([key, dkey])
function_code[key][dkey] = ''
key_list.append([key, dkey])
if key[:-10] == 'first' or key[:-10] == '':
order_list.append(3) #sym.count_ops(expressions[key][dkey]))
elif key[:-10] == 'second':
order_list.append(4) #sym.count_ops(expressions[key][dkey]))
elif key[:-10] == 'third':
order_list.append(5) #sym.count_ops(expressions[key][dkey]))
code['main'][key][dkey] = ''
else:
expression_list.append(expressions[key])
key_list.append([key])
order_list.append(2)
code['main'][key] = ''
symbols, functions = sym.cse(expression_list, numbered_symbols(prefix=prefix))
# This step may be unecessary.
# Not 100% sure if the sub expression elimination is order sensitive. This step orders the list with the 'function' code first and derivatives after.
order_list, expression_list, key_list = zip(*sorted(zip(order_list, expression_list, key_list)))
print expression_list
subexpressions, expression_substituted_list = sym.cse(expression_list, numbered_symbols(prefix=prefix))
cacheable_list = []
# Create strings that lambda strings from the expressions.
sub_expressions = []
for sub, expr in symbols:
code['params_change'] = []
code['cache'] = []
for expr in subexpressions:
arg_list = [e for e in expr[1].atoms() if e.is_Symbol]
cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in cacheable]
if cacheable_symbols:
code['cacheable'].append((expr[0],expr2code(arg_list, expr[1])))
# list which ensures dependencies are cacheable.
cacheable_list.append(expr[0])
code['cacheexpressions'].append(expr[0])
else:
code['params_change'].append((expr[0],expr2code(arg_list, expr[1])))
code['subexpressions'].append(expr[0])
for expr, keys in zip(expression_substituted_list, key_list):
arg_list = [e for e in expr.atoms() if e.is_Symbol]
sub_code += [lambdastr(sorted(arg_list), expr)]
setInDict(code['main'], keys, expr2code(arg_list, expr))
setInDict(expressions, keys, expr)
function_expressions = []
for expr, keys in zip(functions, expression_keys):
arg_list = [e for e in expr.atoms() if e.is_Symbol]
function_code += [lambdastr(sorted(arg_list), expr)]
sub_dict = {}
for i, sub in enumerate(code['cacheexpressions']):
sub_dict[sub.name] = cache_prefix + str(i)
for i, sub in enumerate(code['subexpressions']):
sub_dict[sub.name] = sub_prefix + str(i)
for key in function_code.key():
if key == 'function':
function_code[key] =
return sub_code, func_code
return code
def expr2code(arg_list, expr):
"""Convert the given symbolic expression into code."""
code = lambdastr(arg_list, expr)
function_code = code.split(':')[1]
for arg in arg_list:
function_code = function_code.replace(arg.name, 'self.'+arg.name)
return function_code
class logistic(Function):
"""The logistic function as a symbolic function."""