mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-10 20:42:39 +02:00
[paramz] wrapping - todo: deprecation warnings
This commit is contained in:
parent
2a0d424d33
commit
8bb7c4bf40
3 changed files with 38 additions and 1 deletions
|
|
@ -8,3 +8,38 @@ from .gp import GP
|
|||
from .svgp import SVGP
|
||||
from .sparse_gp import SparseGP
|
||||
from .mapping import *
|
||||
|
||||
|
||||
#===========================================================================
|
||||
# Handle priors, this needs to be
|
||||
# cleaned up at some point
|
||||
#===========================================================================
|
||||
def randomize(self, rand_gen=None, *args, **kwargs):
|
||||
"""
|
||||
Randomize the model.
|
||||
Make this draw from the prior if one exists, else draw from given random generator
|
||||
|
||||
:param rand_gen: np random number generator which takes args and kwargs
|
||||
:param flaot loc: loc parameter for random number generator
|
||||
:param float scale: scale parameter for random number generator
|
||||
:param args, kwargs: will be passed through to random number generator
|
||||
"""
|
||||
if rand_gen is None:
|
||||
rand_gen = np.random.normal
|
||||
# first take care of all parameters (from N(0,1))
|
||||
x = rand_gen(size=self._size_transformed(), *args, **kwargs)
|
||||
updates = self.update_model()
|
||||
self.update_model(False) # Switch off the updates
|
||||
self.optimizer_array = x # makes sure all of the tied parameters get the same init (since there's only one prior object...)
|
||||
# now draw from prior where possible
|
||||
x = self.param_array.copy()
|
||||
[np.put(x, ind, p.rvs(ind.size)) for p, ind in self.priors.items() if not p is None]
|
||||
unfixlist = np.ones((self.size,),dtype=np.bool)
|
||||
from paramz.transformations import __fixed__
|
||||
unfixlist[self.constraints[__fixed__]] = False
|
||||
self.param_array.flat[unfixlist] = x.view(np.ndarray).ravel()[unfixlist]
|
||||
self.update_model(updates)
|
||||
|
||||
Model.randomize = randomize
|
||||
Param.randomize = randomize
|
||||
Parameterized.randomize = randomize
|
||||
|
|
@ -3,6 +3,8 @@
|
|||
|
||||
from paramz import Param
|
||||
from .priorizable import Priorizable
|
||||
from paramz.transformations import __fixed__
|
||||
import logging, numpy as np
|
||||
|
||||
class Param(Param, Priorizable):
|
||||
pass
|
||||
Loading…
Add table
Add a link
Reference in a new issue