2013-11-07 13:55:24 +00:00
'''
Created on 6 Nov 2013
@author : maxz
'''
2014-02-24 09:49:29 +00:00
import numpy as np
2015-10-15 15:13:16 +01:00
from . parameterized import Parameterized
from . param import Param
2015-10-15 14:59:57 +01:00
from paramz . transformations import Logexp , Logistic , __fixed__
2013-11-07 13:55:24 +00:00
2014-02-25 16:09:26 +00:00
class VariationalPrior(Parameterized):
    """
    Abstract prior over the latent space of a variational model.

    Concrete subclasses supply the KL divergence between a variational
    posterior and this prior, together with the matching in-place update
    of the posterior's gradients.
    """

    def __init__(self, name='latent space', **kwargs):
        super(VariationalPrior, self).__init__(name=name, **kwargs)

    def KL_divergence(self, variational_posterior):
        """Return the KL divergence of `variational_posterior` from this prior.

        Must be overridden by subclasses.
        """
        raise NotImplementedError("override this for variational inference of latent space")

    def update_gradients_KL(self, variational_posterior):
        """
        Update the gradients of the posterior's mean and variance **in place**.

        Must be overridden by subclasses.
        """
        raise NotImplementedError("override this for variational inference of latent space")
class NormalPrior(VariationalPrior):
    """
    Standard normal prior N(0, 1) on every latent entry.

    The KL and gradients below are those of a factorised Gaussian posterior
    with parameters ``mean`` and ``variance`` against N(0, 1).
    """

    def KL_divergence(self, variational_posterior):
        """Return KL(q || N(0, 1)) summed over all data points and dimensions."""
        q = variational_posterior
        squared_means = np.square(q.mean).sum()
        variance_part = (q.variance - np.log(q.variance)).sum()
        # The constant term is -1/2 for each of the num_data * input_dim entries.
        return 0.5 * (squared_means + variance_part) - 0.5 * q.input_dim * q.num_data

    def update_gradients_KL(self, variational_posterior):
        """Subtract d KL / d(mean, variance) from the posterior's gradients in place."""
        q = variational_posterior
        # d KL / d mean = mean
        q.mean.gradient -= q.mean
        # d KL / d variance = (1 - 1/variance) / 2
        q.variance.gradient -= (1. - (1. / (q.variance))) * 0.5
class SpikeAndSlabPrior(VariationalPrior):
    """
    Spike-and-slab prior over the latent space.

    Each latent entry is "on" with probability ``pi``. When on, it follows a
    zero-mean Gaussian slab with variance ``variance``; this is the form
    implied by the KL divergence computed below.

    Parameters
    ----------
    pi : array-like
        Prior probability of the slab. It may be a single value, one value
        per latent dimension (1-D), or one value per data point and
        dimension (2-D). It is presumably required in practice despite the
        ``None`` default (TODO confirm).
    learnPi : bool
        If True, ``pi`` is optimised under a Logistic(1e-10, 1-1e-10)
        constraint. Otherwise it is fixed.
    variance : float
        Variance of the slab Gaussian.
    group_spike : bool
        If True, only the first row of the posterior's ``gamma`` is used, so
        one slab probability is shared across data points.
    """

    def __init__(self, pi=None, learnPi=False, variance=1.0, group_spike=False,
                 name='SpikeAndSlabPrior', **kw):
        super(SpikeAndSlabPrior, self).__init__(name=name, **kw)
        self.group_spike = group_spike
        # Stored as a Param but not linked into this Parameterized
        # (only `pi` is linked below).
        self.variance = Param('variance', variance)
        self.learnPi = learnPi
        if learnPi:
            # Keep pi strictly inside (0, 1) so log(pi) and log(1 - pi) stay finite.
            self.pi = Param('Pi', pi, Logistic(1e-10, 1. - 1e-10))
        else:
            self.pi = Param('Pi', pi, __fixed__)
        self.link_parameter(self.pi)

    def _gamma_and_pi(self, variational_posterior):
        """Return ``(gamma, pi, idx)`` matched to the rows held by `variational_posterior`.

        ``idx`` holds the selected row indices of a 2-D ``pi``, or None when
        ``pi`` is not 2-D.
        """
        if self.group_spike:
            gamma = variational_posterior.gamma.values[0]
        else:
            gamma = variational_posterior.gamma.values
        idx = None
        if len(self.pi.shape) == 2:
            # Map the flat indices of the (possibly sliced) gamma back to row
            # indices. Floor division is required: true division yields float
            # indices, which NumPy rejects when indexing self.pi.
            idx = np.unique(variational_posterior.gamma._raveled_index() // gamma.shape[-1])
            pi = self.pi[idx]
        else:
            pi = self.pi
        return gamma, pi, idx

    def KL_divergence(self, variational_posterior):
        """Return the KL divergence of the spike-and-slab posterior from this prior."""
        mu = variational_posterior.mean
        S = variational_posterior.variance
        gamma, pi, _ = self._gamma_and_pi(variational_posterior)

        var_mean = np.square(mu) / self.variance
        var_S = (S / self.variance - np.log(S))
        # Bernoulli KL between the posterior slab probability gamma and the prior pi.
        var_gamma = ((gamma * np.log(gamma / pi)).sum()
                     + ((1 - gamma) * np.log((1 - gamma) / (1 - pi))).sum())
        # Gaussian KL to N(0, variance), weighted by gamma.
        return var_gamma + (gamma * (np.log(self.variance) - 1. + var_mean + var_S)).sum() / 2.

    def update_gradients_KL(self, variational_posterior):
        """
        Subtract d KL / d(binary_prob, mean, variance) from the posterior's
        gradients in place.

        If ``learnPi`` is set, ``pi.gradient`` is also *assigned* (not
        accumulated) the gradient of -KL with respect to pi.
        """
        mu = variational_posterior.mean
        S = variational_posterior.variance
        gamma, pi, idx = self._gamma_and_pi(variational_posterior)

        if self.group_spike:
            # The shared gamma's gradient is later summed over all rows
            # (see the posterior's gradient collation), hence the 1/num_data.
            dgamma = np.log((1 - pi) / pi * gamma / (1. - gamma)) / variational_posterior.num_data
        else:
            dgamma = np.log((1 - pi) / pi * gamma / (1. - gamma))
        variational_posterior.binary_prob.gradient -= dgamma + (
            (np.square(mu) + S) / self.variance - np.log(S) + np.log(self.variance) - 1.) / 2.
        mu.gradient -= gamma * mu / self.variance
        S.gradient -= (1. / self.variance - 1. / S) * gamma / 2.

        if self.learnPi:
            if len(self.pi) == 1:
                self.pi.gradient = (gamma / self.pi - (1. - gamma) / (1. - self.pi)).sum()
            elif len(self.pi.shape) == 1:
                self.pi.gradient = (gamma / self.pi - (1. - gamma) / (1. - self.pi)).sum(axis=0)
            else:
                # 2-D pi: idx was computed by _gamma_and_pi.
                self.pi[idx].gradient = (gamma / self.pi[idx] - (1. - gamma) / (1. - self.pi[idx]))
class VariationalPosterior(Parameterized):
    """
    Base class for factorised variational posteriors over the latent space.

    Holds a ``mean`` Param and a Logexp-constrained ``variance`` Param. The
    mean is unpacked as 2-D (num_data, input_dim).
    """

    def __init__(self, means=None, variances=None, name='latent space', *a, **kw):
        super(VariationalPosterior, self).__init__(name=name, *a, **kw)
        self.mean = Param("mean", means)
        self.variance = Param("variance", variances, Logexp())
        self.ndim = self.mean.ndim
        self.shape = self.mean.shape
        self.num_data, self.input_dim = self.mean.shape
        self.link_parameters(self.mean, self.variance)
        if self.has_uncertain_inputs():
            assert self.variance.shape == self.mean.shape, "need one variance per sample and dimension"

    def set_gradients(self, grad):
        """Set mean and variance gradients from a ``(dmean, dvariance)`` pair."""
        self.mean.gradient, self.variance.gradient = grad

    def _raveled_index(self):
        """Return the flat indices of all linked parameters.

        Each parameter's own indices are offset by the total size of the
        parameters that precede it.
        """
        # Start with an empty int array so the result is an int array even
        # when there are no parameters.
        chunks = [np.empty(dtype=int, shape=0)]
        offset = 0
        for p in self.parameters:
            chunks.append(p._raveled_index() + offset)
            offset += p._realsize_ if hasattr(p, '_realsize_') else p.size
        # Stack once instead of re-stacking inside the loop.
        return np.hstack(chunks)

    def has_uncertain_inputs(self):
        """Return True if a variance Param is present."""
        return self.variance is not None

    def __getitem__(self, s):
        """
        Index the posterior.

        For int, slice, tuple, list or ndarray indices, return a shallow copy
        of this object. Its ``mean`` and ``variance`` are the indexed
        parameters, and its shape bookkeeping is updated to match. Any other
        key is delegated to ``Parameterized.__getitem__``.
        """
        if isinstance(s, (int, slice, tuple, list, np.ndarray)):
            import copy
            n = self.__new__(self.__class__, self.name)
            dc = self.__dict__.copy()
            dc['mean'] = self.mean[s]
            dc['variance'] = self.variance[s]
            # Copy the list so replacing entries does not touch the original.
            dc['parameters'] = copy.copy(self.parameters)
            n.__dict__.update(dc)
            n.parameters[dc['mean']._parent_index_] = dc['mean']
            n.parameters[dc['variance']._parent_index_] = dc['variance']
            n._gradient_array_ = None
            # Keep counting any parameters other than mean/variance
            # (e.g. from subclasses) in the new size.
            oversize = self.size - self.mean.size - self.variance.size
            n.size = n.mean.size + n.variance.size + oversize
            n.ndim = n.mean.ndim
            n.shape = n.mean.shape
            n.num_data = n.mean.shape[0]
            n.input_dim = n.mean.shape[1] if n.ndim != 1 else 1
            return n
        else:
            return super(VariationalPosterior, self).__getitem__(s)
class NormalPosterior(VariationalPosterior):
    '''
    NormalPosterior distribution for variational approximations.

    Holds the means and variances of a factorising multivariate normal
    distribution.
    '''

    def plot(self, *args, **kwargs):
        """
        Plot the latent space X in 1D.

        See GPy.plotting.matplot_dep.variational_plots. matplotlib must
        already be imported.
        """
        import sys
        loaded_modules = sys.modules
        assert "matplotlib" in loaded_modules, "matplotlib package has not been imported."
        from ...plotting.matplot_dep import variational_plots
        return variational_plots.plot(self, *args, **kwargs)

    def KL(self, other):
        """
        Return KL(self || other) between two NormalPosterior objects.

        Both posteriors must have the same shape. The diagonal-Gaussian KL is
        evaluated element-wise and summed.
        """
        mean_gap = other.mean - self.mean
        trace_term = np.sum(self.variance / other.variance)
        mahalanobis_term = (mean_gap ** 2 / other.variance).sum()
        dimension_term = self.num_data * self.input_dim
        log_det_other = np.sum(np.log(other.variance))
        log_det_self = np.sum(np.log(self.variance))
        return .5 * (trace_term + mahalanobis_term - dimension_term
                     + log_det_other - log_det_self)
class SpikeAndSlabPosterior(VariationalPosterior):
    '''
    The SpikeAndSlab distribution for variational approximations.
    '''

    def __init__(self, means, variances, binary_prob, group_spike=False, sharedX=False, name='latent space'):
        """
        binary_prob : the probability of the distribution on the slab part.
        group_spike : if True, a per-dimension ``binary_prob_group`` Param is
            the free parameter. It is initialised to the column means of
            binary_prob, and the per-entry ``binary_prob`` is kept fixed.
        sharedX : if True, ``mean`` and ``variance`` are fixed.
        """
        super(SpikeAndSlabPosterior, self).__init__(means, variances, name)
        self.group_spike = group_spike
        self.sharedX = sharedX
        if sharedX:
            self.mean.fix(warning=False)
            self.variance.fix(warning=False)
        if group_spike:
            self.gamma_group = Param("binary_prob_group", binary_prob.mean(axis=0), Logistic(1e-10, 1. - 1e-10))
            self.gamma = Param("binary_prob", binary_prob, __fixed__)
            self.link_parameters(self.gamma_group, self.gamma)
        else:
            self.gamma = Param("binary_prob", binary_prob, Logistic(1e-10, 1. - 1e-10))
            self.link_parameter(self.gamma)

    def propogate_val(self):
        """In group_spike mode, broadcast the group values into every row of gamma."""
        if self.group_spike:
            self.gamma.values[:] = self.gamma_group.values

    def collate_gradient(self):
        """In group_spike mode, sum gamma's gradient over rows into gamma_group's gradient."""
        if self.group_spike:
            self.gamma_group.gradient = self.gamma.gradient.reshape(self.gamma.shape).sum(axis=0)

    def set_gradients(self, grad):
        """Set gradients from a ``(dmean, dvariance, dbinary_prob)`` triple."""
        self.mean.gradient, self.variance.gradient, self.gamma.gradient = grad

    def __getitem__(self, s):
        """
        Index the posterior.

        For int, slice, tuple, list or ndarray indices, return a shallow copy
        whose ``mean``, ``variance`` and ``binary_prob``/``gamma`` are the
        indexed parameters. Any other key is delegated to
        ``VariationalPosterior.__getitem__``.
        """
        if isinstance(s, (int, slice, tuple, list, np.ndarray)):
            import copy
            n = self.__new__(self.__class__, self.name)
            dc = self.__dict__.copy()
            dc['mean'] = self.mean[s]
            dc['variance'] = self.variance[s]
            dc['binary_prob'] = self.binary_prob[s]
            # `gamma` is a second attribute name for the binary_prob Param.
            # It must point at the slice too, or the copy's gamma would still
            # be the full, unsliced parameter.
            dc['gamma'] = dc['binary_prob']
            dc['parameters'] = copy.copy(self.parameters)
            n.__dict__.update(dc)
            n.parameters[dc['mean']._parent_index_] = dc['mean']
            n.parameters[dc['variance']._parent_index_] = dc['variance']
            n.parameters[dc['binary_prob']._parent_index_] = dc['binary_prob']
            n._gradient_array_ = None
            # Keep counting any other parameters (e.g. gamma_group) in the new size.
            oversize = self.size - self.mean.size - self.variance.size - self.gamma.size
            n.size = n.mean.size + n.variance.size + n.gamma.size + oversize
            n.ndim = n.mean.ndim
            n.shape = n.mean.shape
            n.num_data = n.mean.shape[0]
            n.input_dim = n.mean.shape[1] if n.ndim != 1 else 1
            return n
        else:
            return super(SpikeAndSlabPosterior, self).__getitem__(s)

    def plot(self, *args, **kwargs):
        """
        Plot the latent space X in 1D.

        See GPy.plotting.matplot_dep.variational_plots. matplotlib must
        already be imported.
        """
        import sys
        assert "matplotlib" in sys.modules, "matplotlib package has not been imported."
        from ...plotting.matplot_dep import variational_plots
        return variational_plots.plot_SpikeSlab(self, *args, **kwargs)