mirror of
https://github.com/SheffieldML/GPy.git
synced 2026-05-09 12:02:38 +02:00
slicing support for kernel input dimension
This commit is contained in:
parent
5f3524e7da
commit
db5fd17609
10 changed files with 178 additions and 65 deletions
|
|
@ -9,24 +9,27 @@ class Cacher(object):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, operation, limit=5, ignore_args=()):
|
||||
def __init__(self, operation, limit=5, ignore_args=(), force_kwargs=()):
|
||||
self.limit = int(limit)
|
||||
self.ignore_args = ignore_args
|
||||
self.force_kwargs = force_kwargs
|
||||
self.operation=operation
|
||||
self.cached_inputs = []
|
||||
self.cached_outputs = []
|
||||
self.inputs_changed = []
|
||||
|
||||
def __call__(self, *args):
|
||||
def __call__(self, *args, **kw):
|
||||
"""
|
||||
A wrapper function for self.operation,
|
||||
"""
|
||||
|
||||
#ensure that specified arguments are ignored
|
||||
items = sorted(kw.items(), key=lambda x: x[0])
|
||||
oa_all = args + tuple(a for _,a in items)
|
||||
if len(self.ignore_args) != 0:
|
||||
oa = [a for i,a in enumerate(args) if i not in self.ignore_args]
|
||||
oa = [a for i,a in itertools.chain(enumerate(args), items) if i not in self.ignore_args and i not in self.force_kwargs]
|
||||
else:
|
||||
oa = args
|
||||
oa = oa_all
|
||||
|
||||
# this makes sure we only add an observer once, and that None can be in args
|
||||
observable_args = []
|
||||
|
|
@ -37,8 +40,13 @@ class Cacher(object):
|
|||
#make sure that all the found argument really are observable:
|
||||
#otherswise don't cache anything, pass args straight though
|
||||
if not all([isinstance(arg, Observable) for arg in observable_args]):
|
||||
return self.operation(*args)
|
||||
return self.operation(*args, **kw)
|
||||
|
||||
if len(self.force_kwargs) != 0:
|
||||
# check if there are force args, which force reloading
|
||||
for k in self.force_kwargs:
|
||||
if k in kw and kw[k] is not None:
|
||||
return self.operation(*args, **kw)
|
||||
# TODO: WARNING !!! Cache OFFSWITCH !!! WARNING
|
||||
# return self.operation(*args)
|
||||
|
||||
|
|
@ -48,7 +56,7 @@ class Cacher(object):
|
|||
i = state.index(True)
|
||||
if self.inputs_changed[i]:
|
||||
#(elements of) the args have changed since we last computed: update
|
||||
self.cached_outputs[i] = self.operation(*args)
|
||||
self.cached_outputs[i] = self.operation(*args, **kw)
|
||||
self.inputs_changed[i] = False
|
||||
return self.cached_outputs[i]
|
||||
else:
|
||||
|
|
@ -62,11 +70,11 @@ class Cacher(object):
|
|||
self.cached_outputs.pop(0)
|
||||
|
||||
#compute
|
||||
self.cached_inputs.append(args)
|
||||
self.cached_outputs.append(self.operation(*args))
|
||||
self.cached_inputs.append(oa_all)
|
||||
self.cached_outputs.append(self.operation(*args, **kw))
|
||||
self.inputs_changed.append(False)
|
||||
[a.add_observer(self, self.on_cache_changed) for a in observable_args]
|
||||
return self.cached_outputs[-1]#Max says return.
|
||||
return self.cached_outputs[-1]#return
|
||||
|
||||
def on_cache_changed(self, arg):
|
||||
"""
|
||||
|
|
@ -90,15 +98,16 @@ class Cache_this(object):
|
|||
"""
|
||||
A decorator which can be applied to bound methods in order to cache them
|
||||
"""
|
||||
def __init__(self, limit=5, ignore_args=()):
|
||||
def __init__(self, limit=5, ignore_args=(), force_kwargs=()):
|
||||
self.limit = limit
|
||||
self.ignore_args = ignore_args
|
||||
self.force_args = force_kwargs
|
||||
self.c = None
|
||||
def __call__(self, f):
|
||||
def f_wrap(*args):
|
||||
def f_wrap(*args, **kw):
|
||||
if self.c is None:
|
||||
self.c = Cacher(f, self.limit, ignore_args=self.ignore_args)
|
||||
return self.c(*args)
|
||||
self.c = Cacher(f, self.limit, ignore_args=self.ignore_args, force_kwargs=self.force_args)
|
||||
return self.c(*args, **kw)
|
||||
f_wrap._cacher = self
|
||||
f_wrap.__doc__ = "**cached**\n\n" + (f.__doc__ or "")
|
||||
return f_wrap
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue