mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
logexpNeg transformation
This commit is contained in:
parent
b928044f40
commit
a9a4841e5f
1 changed files with 30 additions and 2 deletions
|
|
@ -4,10 +4,10 @@
|
|||
|
||||
import numpy as np
|
||||
from domains import _POSITIVE,_NEGATIVE, _BOUNDED
|
||||
import sys
|
||||
import weakref
|
||||
|
||||
_lim_val = -np.log(sys.float_info.epsilon)
|
||||
_exp_lim_val = np.finfo(np.float64).max
|
||||
_lim_val = np.log(_exp_lim_val)#-np.log(sys.float_info.epsilon)
|
||||
|
||||
#===============================================================================
|
||||
# Fixing constants
|
||||
|
|
@ -34,6 +34,16 @@ class Transformation(object):
|
|||
def initialize(self, f):
|
||||
""" produce a sensible initial value for f(x)"""
|
||||
raise NotImplementedError
|
||||
def plot(self, xlabel=r'transformed $\theta$', ylabel=r'$\theta$', axes=None, *args,**kw):
|
||||
import sys
|
||||
assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
|
||||
import matplotlib.pyplot as plt
|
||||
from ...plotting.matplot_dep import base_plots
|
||||
x = np.linspace(-8,8)
|
||||
base_plots.meanplot(x, self.f(x),axes=axes*args,**kw)
|
||||
axes = plt.gca()
|
||||
axes.set_xlabel(xlabel)
|
||||
axes.set_ylabel(ylabel)
|
||||
def __str__(self):
|
||||
raise NotImplementedError
|
||||
def __repr__(self):
|
||||
|
|
@ -54,6 +64,24 @@ class Logexp(Transformation):
|
|||
return np.abs(f)
|
||||
def __str__(self):
|
||||
return '+ve'
|
||||
|
||||
|
||||
class LogexpNeg(Transformation):
|
||||
domain = _POSITIVE
|
||||
def f(self, x):
|
||||
return np.where(x>_lim_val, -x, -np.log(1. + np.exp(np.clip(x, -np.inf, _lim_val))))
|
||||
#raises overflow warning: return np.where(x>_lim_val, x, np.log(1. + np.exp(x)))
|
||||
def finv(self, f):
|
||||
return np.where(f>_lim_val, 0, np.log(np.exp(-f) - 1.))
|
||||
def gradfactor(self, f):
|
||||
return np.where(f>_lim_val, -1, -1 + np.exp(-f))
|
||||
def initialize(self, f):
|
||||
if np.any(f < 0.):
|
||||
print "Warning: changing parameters to satisfy constraints"
|
||||
return np.abs(f)
|
||||
def __str__(self):
|
||||
return '+ve'
|
||||
|
||||
|
||||
class NegativeLogexp(Transformation):
|
||||
domain = _NEGATIVE
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue