2013-12-16 13:45:24 +00:00
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
2013-10-02 12:11:53 +01:00
import itertools
import numpy
2014-02-19 15:32:16 +00:00
from parameter_core import Constrainable , Gradcheckable , Indexable , Parentable , adjust_name_for_printing
2014-02-10 16:01:55 +00:00
from array_core import ObservableArray , ParamList
2013-10-02 12:11:53 +01:00
2013-10-07 07:40:48 +01:00
###### printing
# Column titles of the table printed for a Param / ParamConcatenation.
__index_name__ = "Index"
__constraints_name__ = "Constraint"
__priors_name__ = "Prior"
__tie_name__ = "Tied to"

# Param subclasses numpy.ndarray, so values are printed with numpy's own precision.
__precision__ = numpy.get_printoptions()['precision']

# NOTE(review): not referenced by the code shown here; presumably used by other modules.
__print_threshold__ = 5
######
2013-12-09 00:19:37 +00:00
2014-02-19 15:32:16 +00:00
class Param(ObservableArray, Constrainable, Gradcheckable, Indexable, Parentable):
    """
    Parameter object for GPy models.

    :param str name: name of the parameter to be printed
    :param input_array: array which this parameter handles
    :type input_array: numpy.ndarray
    :param default_constraint: The default constraint for this parameter
        (forwarded to the base classes' ``__init__``, defined outside this file)

    You can add/remove constraints by calling constrain on the parameter itself, e.g.:

        - self[:,1].constrain_positive()
        - self[0].tie_to(other)
        - self.untie()
        - self[:3,:].unconstrain()
        - self[1].fix()

    Fixing parameters will fix them to the value they are right now. If you change
    the fixed value, it will be fixed to the new value!

    See :py:class:`GPy.core.parameterized.Parameterized` for more details on constraining etc.
    """
    __array_priority__ = -1  # Never give back Param
    _fixes_ = None
    _parameters_ = []

    def __new__(cls, name, input_array, default_constraint=None):
        # Build the array part; name and default_constraint are consumed by __init__.
        obj = numpy.atleast_1d(super(Param, cls).__new__(cls, input_array=input_array))
        # NOTE(review): this renames the class object itself (subclasses included) on every
        # construction; kept unchanged because other code may rely on the printed name.
        cls.__name__ = "Param"
        # A fresh Param views its whole array: remember the "real" (unsliced) geometry so
        # that slices taken later can be mapped back onto the full parameter.
        obj._current_slice_ = (slice(obj.shape[0]),)
        obj._realshape_ = obj.shape
        obj._realsize_ = obj.size
        obj._realndim_ = obj.ndim
        obj._updated_ = False
        # NOTE(review): function-local import, presumably to avoid an import cycle.
        from index_operations import SetDict
        obj._tied_to_me_ = SetDict()
        obj._tied_to_ = []
        obj._original_ = True
        obj.gradient = None
        return obj

    def __init__(self, name, input_array, default_constraint=None):
        super(Param, self).__init__(name=name, default_constraint=default_constraint)

    def __array_finalize__(self, obj):
        # see InfoArray.__array_finalize__ for comments
        if obj is None: return
        super(Param, self).__array_finalize__(obj)
        # Carry the bookkeeping over from the array this one was derived from
        # (views, slices, ufunc results, ...).
        self._direct_parent_ = getattr(obj, '_direct_parent_', None)
        self._parent_index_ = getattr(obj, '_parent_index_', None)
        self._default_constraint_ = getattr(obj, '_default_constraint_', None)
        self._current_slice_ = getattr(obj, '_current_slice_', None)
        self._tied_to_me_ = getattr(obj, '_tied_to_me_', None)
        self._tied_to_ = getattr(obj, '_tied_to_', None)
        self._realshape_ = getattr(obj, '_realshape_', None)
        self._realsize_ = getattr(obj, '_realsize_', None)
        self._realndim_ = getattr(obj, '_realndim_', None)
        self._updated_ = getattr(obj, '_updated_', None)
        self._original_ = getattr(obj, '_original_', None)
        self._name = getattr(obj, 'name', None)
        self.gradient = getattr(obj, 'gradient', None)
        self.constraints = getattr(obj, 'constraints', None)
        self.priors = getattr(obj, 'priors', None)

    #===========================================================================
    # Pickling operations
    #===========================================================================
    def __reduce_ex__(self, protocol=None):
        """
        Pickle support: extend ndarray's reduce tuple with the Param bookkeeping.

        ``pickle`` and ``copy`` always call ``__reduce_ex__(protocol)``; the
        protocol is accepted (and ignored) so that pickling does not raise
        ``TypeError``. The extra state is restored by :py:meth:`__setstate__`.

        NOTE(review): constraints, priors and gradient are not part of the pickled state.
        """
        func, args, state = super(Param, self).__reduce__()
        return func, args, (state,
                            (self.name,
                             self._direct_parent_,
                             self._parent_index_,
                             self._default_constraint_,
                             self._current_slice_,
                             self._realshape_,
                             self._realsize_,
                             self._realndim_,
                             self._tied_to_me_,
                             self._tied_to_,
                             self._updated_,
                             )
                            )

    def __setstate__(self, state):
        # state == (ndarray state, tuple written by __reduce_ex__); popped in reverse
        # order, so the name is (deliberately) restored last.
        super(Param, self).__setstate__(state[0])
        state = list(state[1])
        self._updated_ = state.pop()
        self._tied_to_ = state.pop()
        self._tied_to_me_ = state.pop()
        self._realndim_ = state.pop()
        self._realsize_ = state.pop()
        self._realshape_ = state.pop()
        self._current_slice_ = state.pop()
        self._default_constraint_ = state.pop()
        self._parent_index_ = state.pop()
        self._direct_parent_ = state.pop()
        self.name = state.pop()

    def copy(self, *args):
        """Return a new Param with copied values, constraints and priors (extra args are ignored)."""
        constr = self.constraints.copy()
        priors = self.priors.copy()
        p = Param(self.name, self.view(numpy.ndarray).copy(), self._default_constraint_)
        p.constraints = constr
        p.priors = priors
        return p

    #===========================================================================
    # get/set parameters
    #===========================================================================
    def _set_params(self, param, update=True):
        # Write all values (flattened order) and notify observers.
        # NOTE: notification of tied parameters is currently disabled here.
        self.flat = param
        self._notify_observers()

    def _get_params(self):
        return self.flat

    def _collect_gradient(self, target):
        # Copy this parameter's gradient (flattened) into the given target buffer.
        target[:] = self.gradient.flat

    #===========================================================================
    # Array operations -> done
    #===========================================================================
    def __getitem__(self, s, *args, **kwargs):
        if not isinstance(s, tuple):
            s = (s,)
        # Make the slice explicit about the remaining dimensions, so that _current_slice_
        # always describes the view with respect to the real shape.
        if not any(b is Ellipsis for b in s) and len(s) <= self.ndim:
            s += (Ellipsis,)
        new_arr = super(Param, self).__getitem__(s, *args, **kwargs)
        try:
            new_arr._current_slice_ = s
            new_arr._original_ = self.base is new_arr.base
        except AttributeError:
            pass  # returning 0d array or float, double etc
        return new_arr

    def __setitem__(self, s, val, update=True):
        super(Param, self).__setitem__(s, val, update=update)
        # Only notify when something was actually written.
        if update and self._s_not_empty(s):
            self._notify_parameters_changed()

    #===========================================================================
    # Index Operations:
    #===========================================================================
    def _internal_offset(self):
        # Offset (in the raveled real array) of the first entry selected by _current_slice_.
        internal_offset = 0
        extended_realshape = numpy.cumprod((1,) + self._realshape_[:0:-1])[::-1]
        for i, si in enumerate(self._current_slice_[:self._realndim_]):
            if si is Ellipsis:
                continue
            if isinstance(si, slice):
                a = si.indices(self._realshape_[i])[0]
            elif isinstance(si, (list, numpy.ndarray, tuple)):
                a = si[0]
            else: a = si
            if a < 0:
                a = self._realshape_[i] + a
            internal_offset += a * extended_realshape[i]
        return internal_offset

    def _raveled_index(self, slice_index=None):
        # return an index array on the raveled array, which is formed by the current_slice
        # of this object
        extended_realshape = numpy.cumprod((1,) + self._realshape_[:0:-1])[::-1]
        ind = self._indices(slice_index)
        if ind.ndim < 2: ind = ind[:, None]
        # Row-major offset of every index tuple; a dot product (unlike apply_along_axis)
        # also copes with an empty selection.
        return numpy.asarray(ind.dot(extended_realshape), dtype=int)

    def _expand_index(self, slice_index=None):
        # Translate the slicing objects (given by __getitem__) into explicit index arrays
        # over the _real..._ geometry, one per real dimension: slices become ranges,
        # negative indices are wrapped around, Ellipsis / missing trailing dimensions
        # become full ranges. Returns a lazy iterator of index arrays.
        if slice_index is None:
            slice_index = self._current_slice_

        def f(pair):
            a, b = pair
            if a is Ellipsis:
                return numpy.r_[:b]
            if isinstance(a, slice):
                start, stop, step = a.indices(b)
                return numpy.r_[start:stop:step]
            if isinstance(a, (list, numpy.ndarray, tuple)):
                # copy: never mutate the caller's (or _current_slice_'s) index array
                a = numpy.array(a, dtype=int)
                a[a < 0] += b
            elif a < 0:
                a = b + a
            return numpy.r_[a]
        return itertools.imap(f, itertools.izip_longest(slice_index[:self._realndim_], self._realshape_, fillvalue=slice(self.size)))

    #===========================================================================
    # Convenience
    #===========================================================================
    @property
    def is_fixed(self):
        return self._highest_parent_._is_fixed(self)

    def _get_original(self, param):
        return self

    #===========================================================================
    # Printing -> done
    #===========================================================================
    @property
    def _description_str(self):
        if self.size <= 1: return ["%f" % self]
        else: return [str(self.shape)]

    def parameter_names(self, add_self=False, adjust_for_printing=False):
        if adjust_for_printing:
            return [adjust_name_for_printing(self.name)]
        return [self.name]

    @property
    def flattened_parameters(self):
        return [self]

    @property
    def parameter_shapes(self):
        return [self.shape]

    @property
    def _constraints_str(self):
        # constraints covering only part of this parameter are shown in braces
        return [' '.join(map(lambda c: str(c[0]) if c[1].size == self._realsize_ else "{" + str(c[0]) + "}", self.constraints.iteritems()))]

    @property
    def _priors_str(self):
        # priors covering only part of this parameter are shown in braces
        return [' '.join(map(lambda c: str(c[0]) if c[1].size == self._realsize_ else "{" + str(c[0]) + "}", self.priors.iteritems()))]

    @property
    def _ties_str(self):
        return [t._short() for t in self._tied_to_] or ['']

    def __repr__(self, *args, **kwargs):
        name = "\033[1m{x:s}\033[0;0m:\n".format(
                        x=self.hirarchy_name())
        return name + super(Param, self).__repr__(*args, **kwargs)

    def _ties_for(self, rav_index):
        # For each parameter we are tied to, mark which of the given raveled indices are
        # tied to it; returns, per entry, the list of tied parameters for printing.
        ties = numpy.empty(shape=(len(self._tied_to_), numpy.size(rav_index)), dtype=object)
        for i, tied_to in enumerate(self._tied_to_):
            for t, ind in tied_to._tied_to_me_.iteritems():
                if t._parent_index_ == self._parent_index_:
                    matches = numpy.where(rav_index[:, None] == t._raveled_index()[None, :])
                    tt_rav_index = tied_to._raveled_index()
                    ind_rav_matches = numpy.where(tt_rav_index == numpy.array(list(ind)))[0]
                    if len(ind) != 1: ties[i, matches[0][ind_rav_matches]] = numpy.take(tt_rav_index, matches[1], mode='wrap')[ind_rav_matches]
                    else: ties[i, matches[0]] = numpy.take(tt_rav_index, matches[1], mode='wrap')
        return map(lambda a: sum(a, []), zip(*[[[tie.flatten()] if tx is not None else [] for tx in t] for t, tie in zip(ties, self._tied_to_)]))

    def _indices(self, slice_index=None):
        # Integer index tuples (one row per selected entry, one column per real dimension).
        if slice_index is None:
            slice_index = self._current_slice_
        if isinstance(slice_index, (tuple, list)):
            clean_curr_slice = [s for s in slice_index if s is not Ellipsis]
            # Pure fancy indexing (one equally long sequence per real dimension) pairs the
            # sequences element-wise; anything else (slices, scalars, implicit trailing
            # dimensions) forms an outer product, handled by _expand_index below.
            if (len(clean_curr_slice) == self._realndim_
                    and all(isinstance(n, (numpy.ndarray, list, tuple)) for n in clean_curr_slice)
                    and len(set(map(len, clean_curr_slice))) <= 1):
                normalized = []
                for dim, idx in enumerate(clean_curr_slice):
                    idx = numpy.array(idx, dtype=int)  # copy, do not touch the caller's index
                    idx[idx < 0] += self._realshape_[dim]  # wrap negatives like _expand_index
                    normalized.append(idx)
                return numpy.fromiter(itertools.izip(*normalized),
                                      dtype=[('', int)] * self._realndim_, count=len(normalized[0])).view((int, self._realndim_))
        expanded_index = list(self._expand_index(slice_index))
        return numpy.fromiter(itertools.product(*expanded_index),
                              dtype=[('', int)] * self._realndim_, count=reduce(lambda a, b: a * b.size, expanded_index, 1)).view((int, self._realndim_))

    def _max_len_names(self, gen, header):
        gen = map(lambda x: " ".join(map(str, x)), gen)
        return reduce(lambda a, b: max(a, len(b)), gen, len(header))

    def _max_len_values(self):
        return reduce(lambda a, b: max(a, len("{x:=.{0}g}".format(__precision__, x=b))), self.flat, len(self.hirarchy_name()))

    def _max_len_index(self, ind):
        return reduce(lambda a, b: max(a, len(str(b))), ind, len(__index_name__))

    def _short(self):
        # short string to print
        name = self.hirarchy_name()
        if self._realsize_ < 2:
            return name
        ind = self._indices()
        if ind.size > 4: indstr = ','.join(map(str, ind[:2])) + "..." + ','.join(map(str, ind[-2:]))
        else: indstr = ','.join(map(str, ind))
        return name + '[' + indstr + ']'

    def __str__(self, constr_matrix=None, indices=None, prirs=None, ties=None, lc=None, lx=None, li=None, lp=None, lt=None, only_name=False):
        """
        Render this parameter as a table: one row per entry with index, value,
        constraints, priors and ties.

        Every keyword argument is an optional precomputed piece (used by
        ParamConcatenation to align several tables); ``None`` means "compute
        here". ``only_name`` replaces the column titles in the header by dashes.
        """
        filter_ = self._current_slice_
        vals = self.flat
        if indices is None: indices = self._indices(filter_)
        ravi = self._raveled_index(filter_)
        if constr_matrix is None: constr_matrix = self.constraints.properties_for(ravi)
        if prirs is None: prirs = self.priors.properties_for(ravi)
        if ties is None: ties = self._ties_for(ravi)
        ties = [' '.join(map(lambda x: x._short(), t)) for t in ties]
        if lc is None: lc = self._max_len_names(constr_matrix, __constraints_name__)
        if lx is None: lx = self._max_len_values()
        if li is None: li = self._max_len_index(indices)
        if lt is None: lt = self._max_len_names(ties, __tie_name__)
        # the priors column has to fit the "Prior" title (not the ties title)
        if lp is None: lp = self._max_len_names(prirs, __priors_name__)
        sep = '-'
        header_format = "{i:{5}^{2}s} | \033[1m{x:{5}^{1}s}\033[0;0m | {c:{5}^{0}s} | {p:{5}^{4}s} | {t:{5}^{3}s}"
        if only_name: header = header_format.format(lc, lx, li, lt, lp, ' ', x=self.hirarchy_name(), c=sep*lc, i=sep*li, t=sep*lt, p=sep*lp)  # nice header for printing
        else: header = header_format.format(lc, lx, li, lt, lp, ' ', x=self.hirarchy_name(), c=__constraints_name__, i=__index_name__, t=__tie_name__, p=__priors_name__)  # nice header for printing
        if not ties: ties = itertools.cycle([''])
        row_format = "{i!s:^{3}s} | {x:>{1}.{2}g} | {c:^{0}s} | {p:^{5}s} | {t:^{4}s}"
        rows = [row_format.format(lc, lx, __precision__, li, lt, lp, x=x, c=" ".join(map(str, c)), p=" ".join(map(str, p)), t=(t or ''), i=i)
                for i, x, c, t, p in itertools.izip(indices, vals, constr_matrix, ties, prirs)]
        return "\n".join([header] + rows)  # return all the constraints with right indices
2013-10-15 09:01:03 +01:00
2013-10-07 07:40:48 +01:00
class ParamConcatenation(object):
    def __init__(self, params):
        """
        Parameter concatenation for convenience of printing regular expression matched arrays.
        You can index this concatenation as if it was the flattened concatenation
        of all the parameters it contains, same for setting parameters (Broadcasting enabled).

        See :py:class:`GPy.core.parameter.Param` for more details on constraining.
        """
        self.params = ParamList([])
        # Flatten every given parameter (collection) into its leaf parameters, without duplicates.
        for param in params:
            for leaf in param.flattened_parameters:
                if leaf not in self.params:
                    self.params.append(leaf)
        self._param_sizes = [p.size for p in self.params]
        startstops = numpy.cumsum([0] + self._param_sizes)
        # the slice of the flattened concatenation that belongs to each parameter
        self._param_slices_ = [slice(start, stop) for start, stop in zip(startstops, startstops[1:])]

    #===========================================================================
    # Get/set items, enable broadcasting
    #===========================================================================
    def _selection_mask(self, s):
        """Boolean mask over the flattened concatenation, True where ``s`` selects."""
        ind = numpy.zeros(sum(self._param_sizes), dtype=bool)
        ind[s] = True
        return ind

    def __getitem__(self, s):
        """
        Select entries as if indexing the flattened concatenation.

        Returns the selected values of the single parameter touched by ``s``, or
        a new ParamConcatenation when several parameters are touched.
        """
        ind = self._selection_mask(s)
        # keep every parameter with at least one selected entry, whatever its values are
        params = [p._get_params()[ind[ps]] for p, ps in zip(self.params, self._param_slices_) if numpy.any(ind[ps])]
        if len(params) == 1: return params[0]
        return ParamConcatenation(params)

    def __setitem__(self, s, val, update=True):
        """
        Assign ``val`` (broadcast as in numpy) to the entries selected by ``s`` in
        the flattened concatenation. With ``update``, every parameter that received
        new values is notified via ``_notify_parameters_changed``.
        """
        if isinstance(val, ParamConcatenation):
            val = val._vals()
        ind = self._selection_mask(s)
        vals = self._vals()
        vals[s] = val
        for p, ps in zip(self.params, self._param_slices_):
            selected = ind[ps]
            if not numpy.any(selected):
                continue
            # numpy.place consumes its values sequentially, so pass exactly the new
            # values at the selected positions, not the parameter's whole segment.
            numpy.place(p, selected, vals[ps][selected])
            if update:
                p._notify_parameters_changed()

    def _vals(self):
        # current values of all parameters, flattened and concatenated
        return numpy.hstack([p._get_params() for p in self.params])

    #===========================================================================
    # parameter operations:
    #===========================================================================
    def update_all_params(self):
        for p in self.params:
            p._notify_parameters_changed()

    def constrain(self, constraint, warning=True):
        for param in self.params:
            param.constrain(constraint, warning, update=False)
        self.update_all_params()
    constrain.__doc__ = Param.constrain.__doc__

    def constrain_positive(self, warning=True):
        for param in self.params:
            param.constrain_positive(warning, update=False)
        self.update_all_params()
    constrain_positive.__doc__ = Param.constrain_positive.__doc__

    def constrain_fixed(self, warning=True):
        for param in self.params:
            param.constrain_fixed(warning)
    constrain_fixed.__doc__ = Param.constrain_fixed.__doc__
    fix = constrain_fixed

    def constrain_negative(self, warning=True):
        for param in self.params:
            param.constrain_negative(warning, update=False)
        self.update_all_params()
    constrain_negative.__doc__ = Param.constrain_negative.__doc__

    def constrain_bounded(self, lower, upper, warning=True):
        for param in self.params:
            param.constrain_bounded(lower, upper, warning, update=False)
        self.update_all_params()
    constrain_bounded.__doc__ = Param.constrain_bounded.__doc__

    def unconstrain(self, *constraints):
        for param in self.params:
            param.unconstrain(*constraints)
    unconstrain.__doc__ = Param.unconstrain.__doc__

    def unconstrain_negative(self):
        for param in self.params:
            param.unconstrain_negative()
    unconstrain_negative.__doc__ = Param.unconstrain_negative.__doc__

    def unconstrain_positive(self):
        for param in self.params:
            param.unconstrain_positive()
    unconstrain_positive.__doc__ = Param.unconstrain_positive.__doc__

    def unconstrain_fixed(self):
        for param in self.params:
            param.unconstrain_fixed()
    unconstrain_fixed.__doc__ = Param.unconstrain_fixed.__doc__
    unfix = unconstrain_fixed

    def unconstrain_bounded(self, lower, upper):
        for param in self.params:
            param.unconstrain_bounded(lower, upper)
    unconstrain_bounded.__doc__ = Param.unconstrain_bounded.__doc__

    def untie(self, *ties):
        for param in self.params:
            param.untie(*ties)

    def checkgrad(self, verbose=0, step=1e-6, tolerance=1e-3):
        # delegate to the model that owns the (first) parameter
        return self.params[0]._highest_parent_._checkgrad(self, verbose, step, tolerance)

    # comparisons act element-wise on the concatenated values
    __lt__ = lambda self, val: self._vals() < val
    __le__ = lambda self, val: self._vals() <= val
    __eq__ = lambda self, val: self._vals() == val
    __ne__ = lambda self, val: self._vals() != val
    __gt__ = lambda self, val: self._vals() > val
    __ge__ = lambda self, val: self._vals() >= val

    def __str__(self, *args, **kwargs):
        def f(p):
            ind = p._raveled_index()
            return p.constraints.properties_for(ind), p._ties_for(ind), p.priors.properties_for(ind)
        params = self.params
        constr_matrices, ties_matrices, prior_matrices = zip(*map(f, params))
        indices = [p._indices() for p in params]
        # shared column widths, so that the tables of all parameters line up
        lc = max([p._max_len_names(cm, __constraints_name__) for p, cm in itertools.izip(params, constr_matrices)])
        lx = max([p._max_len_values() for p in params])
        li = max([p._max_len_index(i) for p, i in itertools.izip(params, indices)])
        lt = max([p._max_len_names(tm, __tie_name__) for p, tm in itertools.izip(params, ties_matrices)])
        lp = max([p._max_len_names(pm, __priors_name__) for p, pm in itertools.izip(params, prior_matrices)])
        strings = []
        start = True
        for p, cm, i, tm, pm in itertools.izip(params, constr_matrices, indices, ties_matrices, prior_matrices):
            # only the first table shows the column titles
            strings.append(p.__str__(constr_matrix=cm, indices=i, prirs=pm, ties=tm, lc=lc, lx=lx, li=li, lp=lp, lt=lt, only_name=(1 - start)))
            start = False
        return "\n".join(strings)

    def __repr__(self):
        return "\n".join(map(repr, self.params))