Allow the default constraint of a Param object to be 'fixed'

This commit is contained in:
Zhenwen Dai 2014-09-02 11:52:09 +01:00
parent e332c1e246
commit 808cfb0501
3 changed files with 12 additions and 7 deletions

View file

@ -303,6 +303,7 @@ class Model(Parameterized):
denominator = (2 * np.dot(dx, gradient)) denominator = (2 * np.dot(dx, gradient))
global_ratio = (f1 - f2) / np.where(denominator == 0., 1e-32, denominator) global_ratio = (f1 - f2) / np.where(denominator == 0., 1e-32, denominator)
global_diff = np.abs(f1 - f2) < tolerance and np.allclose(gradient, 0, atol=tolerance) global_diff = np.abs(f1 - f2) < tolerance and np.allclose(gradient, 0, atol=tolerance)
print self.mpi_comm.rank,global_ratio,global_diff
if global_ratio is np.nan: if global_ratio is np.nan:
global_ratio = 0 global_ratio = 0
return np.abs(1. - global_ratio) < tolerance or global_diff return np.abs(1. - global_ratio) < tolerance or global_diff

View file

@ -13,7 +13,7 @@ Observable Pattern for patameterization
""" """
from transformations import Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED from transformations import Transformation,Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
import numpy as np import numpy as np
import re import re
import logging import logging
@ -541,7 +541,8 @@ class Indexable(Nameable, Observable):
Constrain the parameter to the given Constrain the parameter to the given
:py:class:`GPy.core.transformations.Transformation`. :py:class:`GPy.core.transformations.Transformation`.
""" """
self.param_array[...] = transform.initialize(self.param_array) if isinstance(transform, Transformation):
self.param_array[...] = transform.initialize(self.param_array)
reconstrained = self.unconstrain() reconstrained = self.unconstrain()
added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning) added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
self.notify_observers(self, None if trigger_parent else -np.inf) self.notify_observers(self, None if trigger_parent else -np.inf)
@ -617,7 +618,7 @@ class Indexable(Nameable, Observable):
""" """
Helper preventing copy code. Helper preventing copy code.
This adds the given what (transformation, prior etc) to parameter index operations which. This adds the given what (transformation, prior etc) to parameter index operations which.
revonstrained are reconstrained indices. reconstrained are reconstrained indices.
warn when reconstraining parameters if warning is True. warn when reconstraining parameters if warning is True.
TODO: find out which parameters have changed specifically TODO: find out which parameters have changed specifically
""" """

View file

@ -7,7 +7,7 @@ Created on 6 Nov 2013
import numpy as np import numpy as np
from parameterized import Parameterized from parameterized import Parameterized
from param import Param from param import Param
from transformations import Logexp, Logistic from transformations import Logexp, Logistic,__fixed__
class VariationalPrior(Parameterized): class VariationalPrior(Parameterized):
def __init__(self, name='latent space', **kw): def __init__(self, name='latent space', **kw):
@ -35,12 +35,15 @@ class NormalPrior(VariationalPrior):
class SpikeAndSlabPrior(VariationalPrior): class SpikeAndSlabPrior(VariationalPrior):
def __init__(self, pi=None, learnPi=False, variance = 1.0, name='SpikeAndSlabPrior', **kw): def __init__(self, pi=None, learnPi=False, variance = 1.0, name='SpikeAndSlabPrior', **kw):
super(VariationalPrior, self).__init__(name=name, **kw) super(SpikeAndSlabPrior, self).__init__(name=name, **kw)
self.pi = Param('pi', pi, Logistic(1e-10,1.-1e-10))
self.variance = Param('variance',variance) self.variance = Param('variance',variance)
self.learnPi = learnPi self.learnPi = learnPi
if learnPi: if learnPi:
self.add_parameters(self.pi) self.pi = Param('Pi', pi, Logistic(1e-10,1.-1e-10))
else:
self.pi = Param('Pi', pi, __fixed__)
self.add_parameter(self.pi)
def KL_divergence(self, variational_posterior): def KL_divergence(self, variational_posterior):
mu = variational_posterior.mean mu = variational_posterior.mean