Need to fix missing data in likelihoods.

This commit is contained in:
Neil Lawrence 2014-04-20 23:41:04 +02:00
parent 38f6d6a911
commit 196732b83b
3 changed files with 196 additions and 175 deletions

View file

@ -2,15 +2,21 @@
# Licensed under the BSD 3-clause license (see LICENSE.txt) # Licensed under the BSD 3-clause license (see LICENSE.txt)
import sys import sys
import re
from ..core.parameterization import Parameterized from ..core.parameterization import Parameterized
import numpy as np import numpy as np
import sympy as sym import sympy as sym
from ..core.parameterization import Param from ..core.parameterization import Param
from sympy.utilities.lambdify import lambdastr, _imp_namespace, _get_namespace from sympy.utilities.lambdify import lambdastr, _imp_namespace, _get_namespace
from sympy.utilities.iterables import numbered_symbols from sympy.utilities.iterables import numbered_symbols
from sympy import exp from numpy import exp
from scipy.special import gammaln, gamma, erf, erfc, erfcx, polygamma from scipy.special import gammaln, gamma, erf, erfc, erfcx, polygamma
from GPy.util.functions import normcdf, normcdfln, logistic, logisticln from GPy.util.functions import normcdf, normcdfln, logistic, logisticln
def getFromDict(dataDict, mapList):
return reduce(lambda d, k: d[k], mapList, dataDict)
def setInDict(dataDict, mapList, value):
getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
class Symbolic_core(): class Symbolic_core():
""" """
@ -21,8 +27,7 @@ class Symbolic_core():
# Base class init, do some basic derivatives etc. # Base class init, do some basic derivatives etc.
# Func_modules sets up the right mapping for functions. # Func_modules sets up the right mapping for functions.
self.func_modules = func_modules func_modules += [{'gamma':gamma,
self.func_modules += [{'gamma':gamma,
'gammaln':gammaln, 'gammaln':gammaln,
'erf':erf, 'erfc':erfc, 'erf':erf, 'erfc':erfc,
'erfcx':erfcx, 'erfcx':erfcx,
@ -37,8 +42,27 @@ class Symbolic_core():
self._set_variables(cacheable) self._set_variables(cacheable)
self._set_derivatives(derivatives) self._set_derivatives(derivatives)
self._set_parameters(parameters) self._set_parameters(parameters)
self.namespace = [globals(), self.__dict__] # Convert the expressions to a list for common sub expression elimination
# We should find the following type of expressions: 'function', 'derivative', 'second_derivative', 'third_derivative'.
self.update_expression_list()
# Apply any global stabilisation operations to expressions.
self.global_stabilize()
# Helper functions to get data in and out of dictionaries.
# this code from http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys
self.extract_sub_expressions()
self._gen_code() self._gen_code()
self._set_namespace(func_modules)
def _set_namespace(self, namespaces):
"""Set the name space for use when calling eval. This needs to contain all the relvant functions for mapping from symbolic python to the numerical python. It also contains variables, cached portions etc."""
self.namespace = {}
for m in namespaces[::-1]:
buf = _get_namespace(m)
self.namespace.update(buf)
self.namespace.update(self.__dict__)
def _set_expressions(self, expressions): def _set_expressions(self, expressions):
"""Extract expressions and variables from the user provided expressions.""" """Extract expressions and variables from the user provided expressions."""
@ -79,7 +103,7 @@ class Symbolic_core():
# Do symbolic work to compute derivatives. # Do symbolic work to compute derivatives.
for key, func in self.expressions.items(): for key, func in self.expressions.items():
self.expressions[key]['derivative'] = {theta.name : sym.diff(func['function'],theta) for theta in derivative_arguments} self.expressions[key]['derivative'] = {theta.name : self.stabilize(sym.diff(func['function'],theta)) for theta in derivative_arguments}
def _set_parameters(self, parameters): def _set_parameters(self, parameters):
"""Add parameters to the model and initialize with given values.""" """Add parameters to the model and initialize with given values."""
@ -92,32 +116,45 @@ class Symbolic_core():
# Add parameter. # Add parameter.
self.add_parameters(Param(theta.name, val, None)) self.add_parameters(Param(theta.name, val, None))
#setattr(self, theta.name, ) #self._set_attribute(theta.name, )
def eval_parameters_changed(self): def eval_parameters_changed(self):
# TODO: place checks for inf/nan in here # TODO: place checks for inf/nan in here
# do all the precomputation codes. # do all the precomputation codes.
for variable, code in sorted(self.code['parameters_change'].iteritems()):
setattr(self, variable, eval(code, *self.namespace))
self.eval_update_cache() self.eval_update_cache()
def eval_update_cache(self, **kwargs): def eval_update_cache(self, **kwargs):
# TODO: place checks for inf/nan in here # TODO: place checks for inf/nan in here
# for all provided keywords # for all provided keywords
for variable, value in kwargs.items():
for var in sorted(self.code['parameters_changed'].keys(), key=lambda x: int(re.findall(r'\d+$', x)[0])):
code = self.code['parameters_changed'][var]
self._set_attribute(var, eval(code, self.namespace))
for var, value in kwargs.items():
# update their cached values # update their cached values
if value is not None: if value is not None:
if variable == 'X' or variable == 'F' or variable == 'Mu': if var == 'X' or var == 'F' or var == 'M':
for i, theta in enumerate(self.variables[variable]): value = np.atleast_2d(value)
setattr(self, theta.name, value[:, i][:, None]) for i, theta in enumerate(self.variables[var]):
elif variable.name == 'Z': self._set_attribute(theta.name, value[:, i][:, None])
for i, theta in enumerate(self.variables[variable]): elif var == 'Y':
setattr(self, theta.name, value[:, i][None, :]) # Y values can be missing.
value = np.atleast_2d(value)
for i, theta in enumerate(self.variables[var]):
self._set_attribute('missing' + str(i), np.isnan(value[:, i]))
self._set_attribute(theta.name, value[:, i][:, None])
elif var == 'Z':
value = np.atleast_2d(value)
for i, theta in enumerate(self.variables[var]):
self._set_attribute(theta.name, value[:, i][None, :])
else: else:
setattr(self, theta.name, value[:, i]) value = np.atleast_1d(value)
for i, theta in enumerate(self.variables[var]):
for variable, code in sorted(self.code['update_cache'].iteritems()): self._set_attribute(theta.name, value[i])
setattr(self, variable, eval(code, *self.namespace)) for var in sorted(self.code['update_cache'].keys(), key=lambda x: int(re.findall(r'\d+$', x)[0])):
code = self.code['update_cache'][var]
self._set_attribute(var, eval(code, self.namespace))
def eval_update_gradients(self, function, partial, **kwargs): def eval_update_gradients(self, function, partial, **kwargs):
# TODO: place checks for inf/nan in here # TODO: place checks for inf/nan in here
@ -126,7 +163,7 @@ class Symbolic_core():
code = self.code[function]['derivative'][theta.name] code = self.code[function]['derivative'][theta.name]
setattr(getattr(self, theta.name), setattr(getattr(self, theta.name),
'gradient', 'gradient',
(partial*eval(code, *self.namespace)).sum()) (partial*eval(code, self.namespace)).sum())
def eval_gradients_X(self, function, partial, **kwargs): def eval_gradients_X(self, function, partial, **kwargs):
if kwargs.has_key('X'): if kwargs.has_key('X'):
@ -134,17 +171,17 @@ class Symbolic_core():
self.eval_update_cache(**kwargs) self.eval_update_cache(**kwargs)
for i, theta in enumerate(self.variables['X']): for i, theta in enumerate(self.variables['X']):
code = self.code[function]['derivative'][theta.name] code = self.code[function]['derivative'][theta.name]
gradients_X[:, i:i+1] = partial*eval(code, *self.namespace) gradients_X[:, i:i+1] = partial*eval(code, self.namespace)
return gradients_X return gradients_X
def eval_function(self, function, **kwargs): def eval_function(self, function, **kwargs):
self.eval_update_cache(**kwargs) self.eval_update_cache(**kwargs)
return eval(self.code[function]['function'], *self.namespace) return eval(self.code[function]['function'], self.namespace)
def code_parameters_changed(self): def code_parameters_changed(self):
# do all the precomputation codes. # do all the precomputation codes.
lcode = '' lcode = ''
for variable, code in sorted(self.code['parameters_change'].iteritems()): for variable, code in sorted(self.code['parameters_changed'].iteritems()):
lcode += self._print_code(variable) + ' = ' + self._print_code(code) + '\n' lcode += self._print_code(variable) + ' = ' + self._print_code(code) + '\n'
return lcode return lcode
@ -159,6 +196,7 @@ class Symbolic_core():
else: else:
reorder = '' reorder = ''
for i, theta in enumerate(self.variables[var]): for i, theta in enumerate(self.variables[var]):
lcode+= "\t" + var + '= np.atleast_2d(' + var + ')'
lcode+= "\t" + self._print_code(theta.name) + ' = ' + var + '[:, ' + str(i) + "]" + reorder + "\n" lcode+= "\t" + self._print_code(theta.name) + ' = ' + var + '[:, ' + str(i) + "]" + reorder + "\n"
for variable, code in sorted(self.code['update_cache'].iteritems()): for variable, code in sorted(self.code['update_cache'].iteritems()):
@ -173,13 +211,15 @@ class Symbolic_core():
lcode += self._print_code(theta.name) + '.gradient = (partial*(' + self._print_code(code) + ')).sum()\n' lcode += self._print_code(theta.name) + '.gradient = (partial*(' + self._print_code(code) + ')).sum()\n'
return lcode return lcode
def code_gradients_X(self, function): def code_gradients_cacheable(self, function, variable):
lcode = 'gradients_X = np.zeros_like(X)\n' if variable not in self.cacheable:
raise RuntimeError, variable + ' must be a cacheable.'
lcode = 'gradients_' + variable + ' = np.zeros_like(' + variable + ')\n'
lcode += 'self.update_cache(' + ', '.join(self.cacheable) + ')\n' lcode += 'self.update_cache(' + ', '.join(self.cacheable) + ')\n'
for i, theta in enumerate(self.variables['X']): for i, theta in enumerate(self.variables[variable]):
code = self.code[function]['derivative'][theta.name] code = self.code[function]['derivative'][theta.name]
lcode += 'gradients_X[:, ' + str(i) + ':' + str(i) + '+1] = partial*' + self._print_code(code) + '\n' lcode += 'gradients_' + variable + '[:, ' + str(i) + ':' + str(i) + '+1] = partial*' + self._print_code(code) + '\n'
lcode += 'return gradients_X\n' lcode += 'return gradients_' + variable + '\n'
return lcode return lcode
def code_function(self, function): def code_function(self, function):
@ -187,58 +227,21 @@ class Symbolic_core():
lcode += 'return ' + self._print_code(self.code[function]['function']) lcode += 'return ' + self._print_code(self.code[function]['function'])
return lcode return lcode
def stabilise(self): def stabilize(self, expr):
"""Stabilize the code in the model.""" """Stabilize the code in the model."""
# this code is applied to all expressions in the model in an attempt to sabilize them. # this code is applied to expressions in the model in an attempt to sabilize them.
return expr
def global_stabilize(self):
"""Stabilize all code in the model."""
pass pass
def _gen_namespace(self, modules=None, use_imps=True): def _set_attribute(self, name, value):
"""Gets the relevant namespaces for the given expressions.""" """Make sure namespace gets updated when setting attributes."""
from sympy.core.symbol import Symbol setattr(self, name, value)
self.namespace.update({name: getattr(self, name)})
# If the user hasn't specified any modules, use what is available.
module_provided = True
if modules is None:
module_provided = False
# Use either numpy (if available) or python.math where possible.
# XXX: This leads to different behaviour on different systems and
# might be the reason for irreproducible errors.
modules = ["math", "mpmath", "sympy"]
try:
_import("numpy")
except ImportError:
pass
else:
modules.insert(1, "numpy")
# Get the needed namespaces.
namespaces = []
# First find any function implementations
if use_imps:
for expr in self._expression_list:
namespaces.append(_imp_namespace(expr))
# Check for dict before iterating
if isinstance(modules, (dict, str)) or not hasattr(modules, '__iter__'):
namespaces.append(modules)
else:
namespaces += list(modules)
# fill namespace with first having highest priority
namespace = {}
for m in namespaces[::-1]:
buf = _get_namespace(m)
namespace.update(buf)
for expr in self._expression_list:
if hasattr(expr, "atoms"):
#Try if you can extract symbols from the expression.
#Move on if expr.atoms in not implemented.
syms = expr.atoms(Symbol)
for term in syms:
namespace.update({str(term): term})
return namespace
def update_expression_list(self): def update_expression_list(self):
"""Extract a list of expressions from the dictionary of expressions.""" """Extract a list of expressions from the dictionary of expressions."""
self.expression_list = [] # code arrives in dictionary, but is passed in this list self.expression_list = [] # code arrives in dictionary, but is passed in this list
@ -250,123 +253,141 @@ class Symbolic_core():
self.expression_list.append(texpressions) self.expression_list.append(texpressions)
self.expression_keys.append([fname, type]) self.expression_keys.append([fname, type])
self.expression_order.append(1) self.expression_order.append(1)
self.code[fname] = {type: ''}
elif type[-10:] == 'derivative': elif type[-10:] == 'derivative':
self.code[fname] = {type:{}}
for dtype, expression in texpressions.items(): for dtype, expression in texpressions.items():
self.expression_list.append(expression) self.expression_list.append(expression)
self.expression_keys.append([fname, type, dtype]) self.expression_keys.append([fname, type, dtype])
if type[:-10] == 'first' or type[:-10] == '': if type[:-10] == 'first_' or type[:-10] == '':
self.expression_order.append(3) #sym.count_ops(self.expressions[type][dtype])) self.expression_order.append(3) #sym.count_ops(self.expressions[type][dtype]))
elif type[:-10] == 'second': elif type[:-10] == 'second_':
self.expression_order.append(4) #sym.count_ops(self.expressions[type][dtype])) self.expression_order.append(4) #sym.count_ops(self.expressions[type][dtype]))
elif type[:-10] == 'third': elif type[:-10] == 'third_':
self.expression_order.append(5) #sym.count_ops(self.expressions[type][dtype])) self.expression_order.append(5) #sym.count_ops(self.expressions[type][dtype]))
self.code[fname][type][dtype] = ''
else: else:
self.expression_list.append(fexpressions[type]) self.expression_list.append(fexpressions[type])
self.expression_keys.append([fname, type]) self.expression_keys.append([fname, type])
self.expression_order.append(2) self.expression_order.append(2)
self.code[fname][type] = ''
# This step may be unecessary. # This step may be unecessary.
# Not 100% sure if the sub expression elimination is order sensitive. This step orders the list with the 'function' code first and derivatives after. # Not 100% sure if the sub expression elimination is order sensitive. This step orders the list with the 'function' code first and derivatives after.
self.expression_order, self.expression_list, self.expression_keys = zip(*sorted(zip(self.expression_order, self.expression_list, self.expression_keys))) self.expression_order, self.expression_list, self.expression_keys = zip(*sorted(zip(self.expression_order, self.expression_list, self.expression_keys)))
def extract_sub_expressions(self, cache_prefix='cache', sub_prefix='sub', prefix='XoXoXoX'):
# Do the common sub expression elimination.
common_sub_expressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))
def _gen_code(self, cache_prefix = 'cache', sub_prefix = 'sub', prefix='XoXoXoX'): self.variables[cache_prefix] = []
"""Generate code for the list of expressions provided using the common sub-expression eliminator to separate out portions that are computed multiple times.""" self.variables[sub_prefix] = []
# This is the dictionary that stores all the generated code.
self.code = {}
# Convert the expressions to a list for common sub expression elimination # Create dictionary of new sub expressions
# We should find the following type of expressions: 'function', 'derivative', 'second_derivative', 'third_derivative'. sub_expression_dict = {}
self.update_expression_list() for var, void in common_sub_expressions:
sub_expression_dict[var.name] = var
# Apply any global stabilisation operations to expressions.
self.stabilise()
# Helper functions to get data in and out of dictionaries.
# this code from http://stackoverflow.com/questions/14692690/access-python-nested-dictionary-items-via-a-list-of-keys
def getFromDict(dataDict, mapList):
return reduce(lambda d, k: d[k], mapList, dataDict)
def setInDict(dataDict, mapList, value):
getFromDict(dataDict, mapList[:-1])[mapList[-1]] = value
# Do the common sub expression elimination
subexpressions, expression_substituted_list = sym.cse(self.expression_list, numbered_symbols(prefix=prefix))
cacheable_list = []
# Sort out any expression that's dependent on something that scales with data size (these are listed in cacheable). # Sort out any expression that's dependent on something that scales with data size (these are listed in cacheable).
self.expressions['parameters_change'] = [] cacheable_list = []
self.expressions['update_cache'] = [] params_change_list = []
cache_expressions_list = [] # common_sube_expressions contains a list of paired tuples with the new variable and what it equals
sub_expression_list = [] for var, expr in common_sub_expressions:
for expr in subexpressions: arg_list = [e for e in expr.atoms() if e.is_Symbol]
arg_list = [e for e in expr[1].atoms() if e.is_Symbol] # List any cacheable dependencies of the sub-expression
cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars] cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
if cacheable_symbols: if cacheable_symbols:
self.expressions['update_cache'].append((expr[0].name, self._expr2code(arg_list, expr[1])))
# list which ensures dependencies are cacheable. # list which ensures dependencies are cacheable.
cacheable_list.append(expr[0]) cacheable_list.append(var)
cache_expressions_list.append(expr[0].name)
else: else:
self.expressions['parameters_change'].append((expr[0].name, self._expr2code(arg_list, expr[1]))) params_change_list.append(var)
sub_expression_list.append(expr[0].name)
replace_dict = {}
for i, expr in enumerate(cacheable_list):
sym_var = sym.var(cache_prefix + str(i))
self.variables[cache_prefix].append(sym_var)
replace_dict[expr.name] = sym_var
for i, expr in enumerate(params_change_list):
sym_var = sym.var(sub_prefix + str(i))
self.variables[sub_prefix].append(sym_var)
replace_dict[expr.name] = sym_var
for replace, void in common_sub_expressions:
for expr, keys in zip(expression_substituted_list, self.expression_keys):
setInDict(self.expressions, keys, expr.subs(replace, replace_dict[replace.name]))
for void, expr in common_sub_expressions:
expr = expr.subs(replace, replace_dict[replace.name])
# Replace original code with code including subexpressions. # Replace original code with code including subexpressions.
for expr, keys in zip(expression_substituted_list, self.expression_keys): for keys in self.expression_keys:
for replace, void in common_sub_expressions:
setInDict(self.expressions, keys, getFromDict(self.expressions, keys).subs(replace, replace_dict[replace.name]))
self.expressions['parameters_changed'] = {}
self.expressions['update_cache'] = {}
for var, expr in common_sub_expressions:
for replace, void in common_sub_expressions:
expr = expr.subs(replace, replace_dict[replace.name])
if var in cacheable_list:
self.expressions['update_cache'][replace_dict[var.name].name] = expr
else:
self.expressions['parameters_changed'][replace_dict[var.name].name] = expr
# for var, expr in common_sub_expressions:
# if var in list(cacheable_list):
# self.expressions['update_cache'].append({var.name: expr.subarg_list = [e for e in sub_expr_pair[1].atoms() if e.is_Symbol]
# cacheable_symbols = [e for e in arg_list if e in cacheable_list or e in self.cacheable_vars]
# if cacheable_symbols:
# self.expressions['update_cache'].append((sub_expr_pair[0].name, self._expr2code(arg_list, sub_expr_pair[1])))
# # list which ensures dependencies are cacheable.
# cacheable_list.append(sub_expr_pair[0])
# cache_expressions_list.append(sub_expr_pair[0].name)
# else:
# self.expressions['parameters_change'].append((sub_expr_pair[0].name, self._expr2code(arg_list, sub_expr_pair[1])))
# sub_expression_list.append(sub_expr_pair[0].name)
def _gen_code(self):
"""Generate code for the list of expressions provided using the common sub-expression eliminator to separate out portions that are computed multiple times."""
# This is the dictionary that stores all the generated code.
self.code = {}
def match_key(expr):
if type(expr) is dict:
code = {}
for key in expr.keys():
code[key] = match_key(expr[key])
else:
arg_list = [e for e in expr.atoms() if e.is_Symbol] arg_list = [e for e in expr.atoms() if e.is_Symbol]
setInDict(self.code, keys, self._expr2code(arg_list, expr)) code = self._expr2code(arg_list, expr)
setInDict(self.expressions, keys, expr) return code
# Create variable names for cache and sub expression portions self.code = match_key(self.expressions)
cache_dict = {}
self.variables[cache_prefix] = []
for i, sub in enumerate(cache_expressions_list):
name = cache_prefix + str(i)
cache_dict[sub] = name
self.variables[cache_prefix].append(sym.var(name))
sub_dict = {}
self.variables[sub_prefix] = []
for i, sub in enumerate(sub_expression_list):
name = sub_prefix + str(i)
sub_dict[sub] = name
self.variables[sub_prefix].append(sym.var(name))
# Replace sub expressions in main code with either cacheN or subN. # for keys in self.expression_keys:
for key, val in cache_dict.iteritems(): # expr = getFromDict(self.expressions, keys)
for keys in self.expression_keys: # arg_list = [e for e in expr.atoms() if e.is_Symbol]
setInDict(self.code, keys, # setInDict(self.code, keys, self._expr2code(arg_list, expr))
getFromDict(self.code,keys).replace(key, val))
for key, val in sub_dict.iteritems(): # # Set up precompute code as either cacheN or subN.
for keys in self.expression_keys: # self.code['update_cache'] = {}
setInDict(self.code, keys, # for key, val in self.expressions['update_cache']:
getFromDict(self.code,keys).replace(key, val)) # expr = val
# for key2, val2 in cache_dict.iteritems():
# expr = expr.replace(key2, val2.name)
# for key2, val2 in sub_dict.iteritems():
# expr = expr.replace(key2, val2.name)
# self.code['update_cache'][cache_dict[key]] = expr
# Set up precompute code as either cacheN or subN. # self.expressions['update_cache'] = dict(self.expressions['update_cache'])
self.code['update_cache'] = {} # self.code['parameters_change'] = {}
for key, val in self.expressions['update_cache']: # for key, val in self.expressions['parameters_change']:
expr = val # expr = val
for key2, val2 in cache_dict.iteritems(): # for key2, val2 in cache_dict.iteritems():
expr = expr.replace(key2, val2) # expr = expr.replace(key2, val2.name)
for key2, val2 in sub_dict.iteritems(): # for key2, val2 in sub_dict.iteritems():
expr = expr.replace(key2, val2) # expr = expr.replace(key2, val2.name)
self.code['update_cache'][cache_dict[key]] = expr # self.code['parameters_change'][sub_dict[key]] = expr
# self.expressions['parameters_change'] = dict(self.expressions['parameters_change'])
self.expressions['update_cache'] = dict(self.expressions['update_cache'])
self.code['parameters_change'] = {}
for key, val in self.expressions['parameters_change']:
expr = val
for key2, val2 in cache_dict.iteritems():
expr = expr.replace(key2, val2)
for key2, val2 in sub_dict.iteritems():
expr = expr.replace(key2, val2)
self.code['parameters_change'][sub_dict[key]] = expr
self.expressions['parameters_change'] = dict(self.expressions['parameters_change'])
def _expr2code(self, arg_list, expr): def _expr2code(self, arg_list, expr):
"""Convert the given symbolic expression into code.""" """Convert the given symbolic expression into code."""

View file

@ -14,9 +14,9 @@ except ImportError:
sympy_available=False sympy_available=False
if sympy_available: if sympy_available:
# These are likelihoods that rely on symbolic. # These are likelihoods that rely on symbolic.
from symbolic import Symbolic from symbolic2 import Symbolic
#from sstudent_t import SstudentT #from sstudent_t import SstudentT
from negative_binomial import Negative_binomial #from negative_binomial import Negative_binomial
#from skew_normal import Skew_normal from skew_normal2 import Skew_normal
#from skew_exponential import Skew_exponential #from skew_exponential import Skew_exponential
#from null_category import Null_category #from null_category import Null_category

View file

@ -26,7 +26,7 @@ class Symbolic(Mapping, Symbolic_core):
self.parameters_changed() self.parameters_changed()
def _initialize_cache(self): def _initialize_cache(self):
self.x_0 = np.random.normal(size=(3, self.input_dim)) self._set_attribute('x_0', np.random.normal(size=(3, self.input_dim)))
def parameters_changed(self): def parameters_changed(self):