mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-30 14:35:15 +02:00
Allow the default constraint of a Param object to be 'fixed'
This commit is contained in:
parent
e332c1e246
commit
808cfb0501
3 changed files with 12 additions and 7 deletions
|
|
@ -13,7 +13,7 @@ Observable Pattern for patameterization
|
|||
|
||||
"""
|
||||
|
||||
from transformations import Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
||||
from transformations import Transformation,Logexp, NegativeLogexp, Logistic, __fixed__, FIXED, UNFIXED
|
||||
import numpy as np
|
||||
import re
|
||||
import logging
|
||||
|
|
@ -541,7 +541,8 @@ class Indexable(Nameable, Observable):
|
|||
Constrain the parameter to the given
|
||||
:py:class:`GPy.core.transformations.Transformation`.
|
||||
"""
|
||||
self.param_array[...] = transform.initialize(self.param_array)
|
||||
if isinstance(transform, Transformation):
|
||||
self.param_array[...] = transform.initialize(self.param_array)
|
||||
reconstrained = self.unconstrain()
|
||||
added = self._add_to_index_operations(self.constraints, reconstrained, transform, warning)
|
||||
self.notify_observers(self, None if trigger_parent else -np.inf)
|
||||
|
|
@ -617,7 +618,7 @@ class Indexable(Nameable, Observable):
|
|||
"""
|
||||
Helper preventing copy code.
|
||||
This adds the given what (transformation, prior etc) to parameter index operations which.
|
||||
revonstrained are reconstrained indices.
|
||||
reconstrained are reconstrained indices.
|
||||
warn when reconstraining parameters if warning is True.
|
||||
TODO: find out which parameters have changed specifically
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ Created on 6 Nov 2013
|
|||
import numpy as np
|
||||
from parameterized import Parameterized
|
||||
from param import Param
|
||||
from transformations import Logexp, Logistic
|
||||
from transformations import Logexp, Logistic,__fixed__
|
||||
|
||||
class VariationalPrior(Parameterized):
|
||||
def __init__(self, name='latent space', **kw):
|
||||
|
|
@ -35,12 +35,15 @@ class NormalPrior(VariationalPrior):
|
|||
|
||||
class SpikeAndSlabPrior(VariationalPrior):
|
||||
def __init__(self, pi=None, learnPi=False, variance = 1.0, name='SpikeAndSlabPrior', **kw):
|
||||
super(VariationalPrior, self).__init__(name=name, **kw)
|
||||
self.pi = Param('pi', pi, Logistic(1e-10,1.-1e-10))
|
||||
super(SpikeAndSlabPrior, self).__init__(name=name, **kw)
|
||||
self.variance = Param('variance',variance)
|
||||
self.learnPi = learnPi
|
||||
if learnPi:
|
||||
self.add_parameters(self.pi)
|
||||
self.pi = Param('Pi', pi, Logistic(1e-10,1.-1e-10))
|
||||
else:
|
||||
self.pi = Param('Pi', pi, __fixed__)
|
||||
self.add_parameter(self.pi)
|
||||
|
||||
|
||||
def KL_divergence(self, variational_posterior):
|
||||
mu = variational_posterior.mean
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue