Copy and paste observable_array from repository to try and resolve bizzare merge request.

This commit is contained in:
Neil Lawrence 2014-04-14 10:00:07 +01:00
commit c2d3c82944
30 changed files with 1646 additions and 184 deletions

View file

@ -5,6 +5,7 @@ Created on 27 Feb 2014
'''
from collections import defaultdict
import weakref
def intarray_default_factory():
import numpy as np
@ -41,60 +42,66 @@ class ObservablesList(object):
def __init__(self):
self._poc = []
def remove(self, value):
return self._poc.remove(value)
def __delitem__(self, ind):
return self._poc.__delitem__(ind)
def __setitem__(self, ind, item):
return self._poc.__setitem__(ind, item)
def __getitem__(self, ind):
return self._poc.__getitem__(ind)
p,o,c = self._poc[ind]
return p, o(), c
def remove(self, priority, observable, callble):
"""
"""
self.flush()
for i in range(len(self) - 1, -1, -1):
p,o,c = self[i]
if priority==p and observable==o and callble==c:
del self._poc[i]
def __repr__(self):
return self._poc.__repr__()
def append(self, obj):
return self._poc.append(obj)
def index(self, value):
return self._poc.index(value)
def extend(self, iterable):
return self._poc.extend(iterable)
def add(self, priority, observable, callble):
ins = 0
for pr, _, _ in self:
if priority > pr:
break
ins += 1
self._poc.insert(ins, (priority, weakref.ref(observable), callble))
def __str__(self):
return self._poc.__str__()
ret = []
curr_p = None
for p, o, c in self:
curr = ''
if curr_p != p:
pre = "{!s}: ".format(p)
curr_pre = pre
else: curr_pre = " "*len(pre)
curr_p = p
curr += curr_pre
ret.append(curr + ", ".join(map(repr, [o,c])))
return '\n'.join(ret)
def flush(self):
self._poc = [(p,o,c) for p,o,c in self._poc if o() is not None]
def __iter__(self):
return self._poc.__iter__()
def insert(self, index, obj):
return self._poc.insert(index, obj)
self.flush()
for p, o, c in self._poc:
if o() is not None:
yield p, o(), c
def __len__(self):
self.flush()
return self._poc.__len__()
def __deepcopy__(self, memo):
self.flush()
s = ObservablesList()
import copy
s._poc = copy.deepcopy(self._poc, memo)
return s
def __getstate__(self):
self.flush()
from ...util.caching import Cacher
obs = []
for p, o, c in self:
@ -106,6 +113,6 @@ class ObservablesList(object):
def __setstate__(self, state):
self._poc = []
for p, o, c in state:
self._poc.append((p,o,getattr(o, c)))
self.add(p,o,getattr(o, c))
pass

View file

@ -0,0 +1,137 @@
# Copyright (c) 2012, GPy authors (see AUTHORS.txt).
# Licensed under the BSD 3-clause license (see LICENSE.txt)
__updated__ = '2014-03-31'
import numpy as np
from parameter_core import Observable, Pickleable
class ObsAr(np.ndarray, Pickleable, Observable):
"""
An ndarray which reports changes to its observers.
The observers can add themselves with a callable, which
will be called every time this array changes. The callable
takes exactly one argument, which is this array itself.
"""
__array_priority__ = -1 # Never give back ObsAr
def __new__(cls, input_array, *a, **kw):
if not isinstance(input_array, ObsAr):
obj = np.atleast_1d(np.require(input_array, dtype=np.float64, requirements=['W', 'C'])).view(cls)
else: obj = input_array
#cls.__name__ = "ObsAr" # because of fixed printing of `array` in np printing
super(ObsAr, obj).__init__(*a, **kw)
return obj
def __array_finalize__(self, obj):
# see InfoArray.__array_finalize__ for comments
if obj is None: return
self.observers = getattr(obj, 'observers', None)
def __array_wrap__(self, out_arr, context=None):
return out_arr.view(np.ndarray)
def copy(self):
memo = {}
memo[id(self)] = self
return self.__deepcopy__(memo)
def __deepcopy__(self, memo):
s = self.__new__(self.__class__, input_array=self.view(np.ndarray).copy())
memo[id(self)] = s
import copy
s.__dict__.update(copy.deepcopy(self.__dict__, memo))
return s
def __reduce__(self):
func, args, state = super(ObsAr, self).__reduce__()
return func, args, (state, Pickleable.__getstate__(self))
def __setstate__(self, state):
np.ndarray.__setstate__(self, state[0])
Pickleable.__setstate__(self, state[1])
def __setitem__(self, s, val):
super(ObsAr, self).__setitem__(s, val)
self.notify_observers()
def __getslice__(self, start, stop):
return self.__getitem__(slice(start, stop))
def __setslice__(self, start, stop, val):
return self.__setitem__(slice(start, stop), val)
def __ilshift__(self, *args, **kwargs):
r = np.ndarray.__ilshift__(self, *args, **kwargs)
self.notify_observers()
return r
def __irshift__(self, *args, **kwargs):
r = np.ndarray.__irshift__(self, *args, **kwargs)
self.notify_observers()
return r
def __ixor__(self, *args, **kwargs):
r = np.ndarray.__ixor__(self, *args, **kwargs)
self.notify_observers()
return r
def __ipow__(self, *args, **kwargs):
r = np.ndarray.__ipow__(self, *args, **kwargs)
self.notify_observers()
return r
def __ifloordiv__(self, *args, **kwargs):
r = np.ndarray.__ifloordiv__(self, *args, **kwargs)
self.notify_observers()
return r
def __isub__(self, *args, **kwargs):
r = np.ndarray.__isub__(self, *args, **kwargs)
self.notify_observers()
return r
def __ior__(self, *args, **kwargs):
r = np.ndarray.__ior__(self, *args, **kwargs)
self.notify_observers()
return r
def __itruediv__(self, *args, **kwargs):
r = np.ndarray.__itruediv__(self, *args, **kwargs)
self.notify_observers()
return r
def __idiv__(self, *args, **kwargs):
r = np.ndarray.__idiv__(self, *args, **kwargs)
self.notify_observers()
return r
def __iand__(self, *args, **kwargs):
r = np.ndarray.__iand__(self, *args, **kwargs)
self.notify_observers()
return r
def __imod__(self, *args, **kwargs):
r = np.ndarray.__imod__(self, *args, **kwargs)
self.notify_observers()
return r
def __iadd__(self, *args, **kwargs):
r = np.ndarray.__iadd__(self, *args, **kwargs)
self.notify_observers()
return r
def __imul__(self, *args, **kwargs):
r = np.ndarray.__imul__(self, *args, **kwargs)
self.notify_observers()
return r

View file

@ -59,7 +59,7 @@ class Param(OptimizationHandlable, ObsAr):
import pydot
node = pydot.Node(id(self), shape='record', label=self.name)
G.add_node(node)
for o in self._observer_callables_.keys():
for o in self.observers.keys():
label = o.name if hasattr(o, 'name') else str(o)
observed_node = pydot.Node(id(o), label=label)
G.add_node(observed_node)
@ -324,7 +324,7 @@ class ParamConcatenation(object):
if update:
self.update_all_params()
def values(self):
return numpy.hstack([p.param_array for p in self.params])
return numpy.hstack([p.param_array.flat for p in self.params])
#===========================================================================
# parameter operations:
#===========================================================================

View file

@ -44,22 +44,23 @@ class Observable(object):
def __init__(self, *args, **kwargs):
super(Observable, self).__init__()
from lists_and_dicts import ObservablesList
self._observer_callables_ = ObservablesList()
self.observers = ObservablesList()
def add_observer(self, observer, callble, priority=0):
self._insert_sorted(priority, observer, callble)
self.observers.add(priority, observer, callble)
def remove_observer(self, observer, callble=None):
to_remove = []
for p, obs, clble in self._observer_callables_:
for poc in self.observers:
_, obs, clble = poc
if callble is not None:
if (obs == observer) and (callble == clble):
to_remove.append((p, obs, clble))
to_remove.append(poc)
else:
if obs is observer:
to_remove.append((p, obs, clble))
to_remove.append(poc)
for r in to_remove:
self._observer_callables_.remove(r)
self.observers.remove(*r)
def notify_observers(self, which=None, min_priority=None):
"""
@ -74,21 +75,13 @@ class Observable(object):
if which is None:
which = self
if min_priority is None:
[callble(self, which=which) for _, _, callble in self._observer_callables_]
[callble(self, which=which) for _, _, callble in self.observers]
else:
for p, _, callble in self._observer_callables_:
for p, _, callble in self.observers:
if p <= min_priority:
break
callble(self, which=which)
def _insert_sorted(self, p, o, c):
ins = 0
for pr, _, _ in self._observer_callables_:
if p > pr:
break
ins += 1
self._observer_callables_.insert(ins, (p, o, c))
#===============================================================================
# Foundation framework for parameterized and param objects:
#===============================================================================
@ -192,7 +185,7 @@ class Pickleable(object):
def __getstate__(self):
ignore_list = ([#'_parent_', '_parent_index_',
#'_observer_callables_',
#'observers',
'_param_array_', '_gradient_array_', '_fixes_',
'_Cacher_wrap__cachers']
#+ self.parameter_names(recursive=False)

View file

@ -90,7 +90,7 @@ class Parameterized(Parameterizable):
child_node = child.build_pydot(G)
G.add_edge(pydot.Edge(node, child_node))
for o in self._observer_callables_.keys():
for o in self.observers.keys():
label = o.name if hasattr(o, 'name') else str(o)
observed_node = pydot.Node(id(o), label=label)
G.add_node(observed_node)

View file

@ -40,6 +40,7 @@ class SpikeAndSlabPrior(VariationalPrior):
self.pi = Param('pi', pi, Logistic(1e-10,1.-1e-10))
self.variance = Param('variance',variance)
self.add_parameters(self.pi)
self.group_spike_prob = False
def KL_divergence(self, variational_posterior):
mu = variational_posterior.mean
@ -55,7 +56,11 @@ class SpikeAndSlabPrior(VariationalPrior):
S = variational_posterior.variance
gamma = variational_posterior.binary_prob
gamma.gradient -= np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+(np.square(mu)+S-np.log(S)-1.)/2.
if self.group_spike_prob:
gamma_grad = np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+(np.square(mu)+S-np.log(S)-1.)/2.
gamma.gradient -= gamma_grad.mean(axis=0)
else:
gamma.gradient -= np.log((1-self.pi)/self.pi*gamma/(1.-gamma))+(np.square(mu)+S-np.log(S)-1.)/2.
mu.gradient -= gamma*mu
S.gradient -= (1. - (1. / (S))) * gamma /2.
self.pi.gradient = (gamma/self.pi - (1.-gamma)/(1.-self.pi)).sum(axis=0)