Added factorize_space function which returns a segmentation to shared and private dims

This commit is contained in:
Andreas 2015-10-01 13:07:57 +01:00
parent 858a355396
commit bc5e9a8ad3

View file

@ -342,3 +342,60 @@ class MRD(BayesianGPLVMMiniBatch):
self.kern = self.bgplvms[0].kern
self.likelihood = self.bgplvms[0].likelihood
self.parameters_changed()
def factorize_space(self, threshold=0.005, printOut=False, views=None):
"""
Given a trained MRD model, this function looks at the optimized ARD weights (lengthscales)
and decides which part of the latent space is shared across views or private, according to a threshold.
The threshold is applied after all weights are normalized so that the maximum value is 1.
"""
M = len(self.bgplvms)
if views is None:
# There are some small modifications needed to make this work for M > 2 (currently the code
# takes account of this, but it's not right there)
if M is not 2:
raise NotImplementedError("Not implemented for M > 2")
obsMod = [0]
infMod = 1
else:
obsMod = views[0]
infMod = views[1]
scObs = [None] * len(obsMod)
for i in range(0,len(obsMod)):
# WARNING: the [0] in the end assumes that the ARD kernel (if there's addition) is the 1st one
scObs[i] = np.atleast_2d(self.bgplvms[obsMod[i]].kern.input_sensitivity(summarize=False))[0]
# Normalise to have max 1
scObs[i] /= np.max(scObs[i])
scInf = np.atleast_2d(self.bgplvms[infMod].kern.input_sensitivity(summarize=False))[0]
scInf /= np.max(scInf)
retainedScales = [None]*(len(obsMod)+1)
for i in range(0,len(obsMod)):
retainedScales[obsMod[i]] = np.where(scObs[i] > threshold)[0]
retainedScales[infMod] = np.where(scInf > threshold)[0]
for i in range(len(retainedScales)):
retainedScales[i] = [k for k in retainedScales[i]] # Transform array to list
sharedDims = set(retainedScales[obsMod[0]]).intersection(set(retainedScales[infMod]))
for i in range(1,len(obsMod)):
sharedDims = sharedDims.intersection(set(retainedScales[obsMod[i]]))
privateDims = [None]*M
for i in range(0,len(retainedScales)):
privateDims[i] = set(retainedScales[i]).difference(sharedDims)
privateDims[i] = [k for k in privateDims[i]] # Transform set to list
sharedDims = [k for k in sharedDims] # Transform set to list
sharedDims.sort()
for i in range(len(privateDims)):
privateDims[i].sort()
if printOut:
print '# Shared dimensions: ' + str(sharedDims)
for i in range(len(retainedScales)):
print '# Private dimensions model ' + str(i) + ':' + str(privateDims[i])
return sharedDims, privateDims