fix the remaining problem of cache.py

Zhenwen Dai 2014-06-20 21:51:51 +01:00
parent ca1edecce4
commit 59172435e2
3 changed files with 61 additions and 20 deletions

@@ -58,7 +58,23 @@ gpu_code = """
}
}
__global__ void psi1computations(double *psi1, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q)
__global__ void compDenom(double *log_denom1, double *log_denom2, double *l, double *S, int N, int Q)
{
int n_start, n_end;
divide_data(N, gridDim.x, blockIdx.x, &n_start, &n_end);
for(int i=n_start*Q+threadIdx.x; i<n_end*Q; i+=blockDim.x) {
int n=i/Q;
int q=i%Q;
double Snq = S[IDX_NQ(n,q)];
double lq = l[q]*l[q];
log_denom1[IDX_NQ(n,q)] = log(Snq/lq+1.);
log_denom2[IDX_NQ(n,q)] = log(2.*Snq/lq+1.);
}
}
__global__ void psi1computations(double *psi1, double *log_denom1, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q)
{
int m_start, m_end;
divide_data(M, gridDim.x, blockIdx.x, &m_start, &m_end);
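
For reference, the new compDenom kernel just precomputes, for every data point n and input dimension q, the two log-denominator terms that the psi1 and psi2 kernels previously re-evaluated inside their inner loops. A host-side NumPy sketch of the same computation (array names follow the surrounding code; this is an illustration, not the GPU implementation):

import numpy as np

def comp_denom(l, S):
    # l: (Q,) lengthscales, S: (N, Q) variational variances
    lq = l ** 2                               # squared lengthscales
    log_denom1 = np.log(S / lq + 1.)          # consumed by psi1computations
    log_denom2 = np.log(2. * S / lq + 1.)     # consumed by psi2computations
    return log_denom1, log_denom2
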
@@ -70,15 +86,14 @@ gpu_code = """
double muZ = mu[IDX_NQ(n,q)]-Z[IDX_MQ(m,q)];
double Snq = S[IDX_NQ(n,q)];
double lq = l[q]*l[q];
log_psi1 += (muZ*muZ/(Snq+lq))/(-2.);
log_psi1 += log(Snq/lq+1)/(-2.);
log_psi1 += (muZ*muZ/(Snq+lq)+log_denom1[IDX_NQ(n,q)])/(-2.);
}
psi1[IDX_NM(n,m)] = var*exp(log_psi1);
}
}
}
__global__ void psi2computations(double *psi2, double *psi2n, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q)
__global__ void psi2computations(double *psi2, double *psi2n, double *log_denom2, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q)
{
int psi2_idx_start, psi2_idx_end;
__shared__ double psi2_local[THREADNUM];
@@ -96,8 +111,8 @@ gpu_code = """
double muZhat = mu[IDX_NQ(n,q)]- (Z[IDX_MQ(m1,q)]+Z[IDX_MQ(m2,q)])/2.;
double Snq = S[IDX_NQ(n,q)];
double lq = l[q]*l[q];
log_psi2_n += dZ*dZ/(-4.*lq)-muZhat*muZhat/(2.*Snq+lq);
log_psi2_n += log(2.*Snq/lq+1)/(-2.);
log_psi2_n += dZ*dZ/(-4.*lq)-muZhat*muZhat/(2.*Snq+lq) + log_denom2[IDX_NQ(n,q)]/(-2.);
//log_psi2_n += log(2.*Snq/lq+1)/(-2.);
}
double exp_psi2_n = exp(log_psi2_n);
psi2n[IDX_NMM(n,m1,m2)] = var*var*exp_psi2_n;
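
With the denominators cached, each psi1 entry reduces to var * exp(-0.5 * sum_q [ (mu_nq - Z_mq)^2 / (S_nq + l_q^2) + log_denom1_nq ]), and psi2 uses log_denom2 analogously. A vectorised NumPy check of the psi1 expression (a sketch under the same naming assumptions as above, useful for validating the kernel output, not part of the commit):

import numpy as np

def psi1_reference(var, l, Z, mu, S, log_denom1):
    # var: scalar kernel variance; l: (Q,); Z: (M, Q); mu, S, log_denom1: (N, Q)
    lq = l ** 2
    muZ = mu[:, None, :] - Z[None, :, :]                        # (N, M, Q)
    quad = muZ ** 2 / (S[:, None, :] + lq)                      # (N, M, Q)
    log_psi1 = -0.5 * (quad + log_denom1[:, None, :]).sum(-1)   # sum over q
    return var * np.exp(log_psi1)                               # (N, M)
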
@@ -237,9 +252,15 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
self.blocknum = 15
module = SourceModule("#define THREADNUM "+str(self.threadnum)+"\n"+gpu_code)
self.g_psi1computations = module.get_function('psi1computations')
self.g_psi1computations.prepare('PPdPPPPiii')
self.g_psi2computations = module.get_function('psi2computations')
self.g_psi2computations.prepare('PPPdPPPPiii')
self.g_psi1compDer = module.get_function('psi1compDer')
self.g_psi1compDer.prepare('PPPPPPPdPPPPiii')
self.g_psi2compDer = module.get_function('psi2compDer')
self.g_psi2compDer.prepare('PPPPPPPdPPPPiii')
self.g_compDenom = module.get_function('compDenom')
self.g_compDenom.prepare('PPPPii')
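
The prepare() strings describe the kernel argument types with struct-style format characters ('P' a device pointer, 'd' a C double, 'i' a 32-bit int), so the prepared_call invocations added later in this diff avoid PyCUDA's per-call argument handling. A minimal, self-contained sketch of the prepare/prepared_call pattern (the toy 'scale' kernel is made up for illustration):

import numpy as np
import pycuda.autoinit
import pycuda.gpuarray as gpuarray
from pycuda.compiler import SourceModule

mod = SourceModule("""
__global__ void scale(double *x, double a, int n)
{
    int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) x[i] *= a;
}
""")
scale = mod.get_function('scale')
scale.prepare('Pdi')    # pointer, double, int -- mirrors the kernel signature
x_gpu = gpuarray.to_gpu(np.arange(8, dtype=np.float64))
# prepared_call takes (grid, block, *args), matching the prepare() string
scale.prepared_call((1, 1), (8, 1, 1), x_gpu.gpudata, np.float64(2.), np.int32(8))
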
def _initGPUCache(self, N, M, Q):
if self.gpuCache == None:
@@ -253,6 +274,8 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
'psi2n_gpu' :gpuarray.empty((N,M,M),np.float64,order='F'),
'dL_dpsi1_gpu' :gpuarray.empty((N,M),np.float64,order='F'),
'dL_dpsi2_gpu' :gpuarray.empty((M,M),np.float64,order='F'),
'log_denom1_gpu' :gpuarray.empty((N,Q),np.float64,order='F'),
'log_denom2_gpu' :gpuarray.empty((N,Q),np.float64,order='F'),
# derivatives
'dvar_gpu' :gpuarray.empty((self.blocknum,),np.float64, order='F'),
'dl_gpu' :gpuarray.empty((Q,self.blocknum),np.float64, order='F'),
@@ -277,6 +300,10 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
self.gpuCache['Z_gpu'].set(np.asfortranarray(Z))
self.gpuCache['mu_gpu'].set(np.asfortranarray(mu))
self.gpuCache['S_gpu'].set(np.asfortranarray(S))
N,Q = self.gpuCache['S_gpu'].shape
# t=self.g_compDenom(self.gpuCache['log_denom1_gpu'],self.gpuCache['log_denom2_gpu'],self.gpuCache['l_gpu'],self.gpuCache['S_gpu'], np.int32(N), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_compDenom '+str(t)
self.g_compDenom.prepared_call((self.blocknum,1),(self.threadnum,1,1), self.gpuCache['log_denom1_gpu'].gpudata,self.gpuCache['log_denom2_gpu'].gpudata,self.gpuCache['l_gpu'].gpudata,self.gpuCache['S_gpu'].gpudata, np.int32(N), np.int32(Q))
def reset_derivative(self):
self.gpuCache['dvar_gpu'].fill(0.)
@@ -309,11 +336,17 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
Z_gpu = self.gpuCache['Z_gpu']
mu_gpu = self.gpuCache['mu_gpu']
S_gpu = self.gpuCache['S_gpu']
log_denom1_gpu = self.gpuCache['log_denom1_gpu']
log_denom2_gpu = self.gpuCache['log_denom2_gpu']
psi0 = np.empty((N,))
psi0[:] = variance
self.g_psi1computations(psi1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1))
self.g_psi2computations(psi2_gpu, psi2n_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1))
self.g_psi1computations.prepared_call((self.blocknum,1),(self.threadnum,1,1),psi1_gpu.gpudata, log_denom1_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
self.g_psi2computations.prepared_call((self.blocknum,1),(self.threadnum,1,1),psi2_gpu.gpudata, psi2n_gpu.gpudata, log_denom2_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
# t = self.g_psi1computations(psi1_gpu, log_denom1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_psi1computations '+str(t)
# t = self.g_psi2computations(psi2_gpu, psi2n_gpu, log_denom2_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_psi2computations '+str(t)
if self.GPU_direct:
return psi0, psi1_gpu, psi2_gpu
@@ -352,8 +385,12 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
dL_dpsi0_sum = dL_dpsi0.sum()
self.reset_derivative()
self.g_psi1compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi1_gpu,psi1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1))
self.g_psi2compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi2_gpu,psi2n_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1))
# t=self.g_psi1compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi1_gpu,psi1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_psi1compDer '+str(t)
# t=self.g_psi2compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi2_gpu,psi2n_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_psi2compDer '+str(t)
self.g_psi1compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi1_gpu.gpudata,psi1_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
self.g_psi2compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi2_gpu.gpudata,psi2n_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
dL_dvar = dL_dpsi0_sum + gpuarray.sum(dvar_gpu).get()
sum_axis(grad_mu_gpu,dmu_gpu,N*Q,self.blocknum)

@@ -79,7 +79,6 @@ class Stationary(Kern):
#a convenience function, so we can cache dK_dr
return self.dK_dr(self._scaled_dist(X, X2))
@Cache_this(limit=5, ignore_args=(0,))
def _unscaled_dist(self, X, X2=None):
"""
Compute the Euclidean distance between each row of X and X2, or between

@@ -30,13 +30,13 @@ class Cacher(object):
self.cached_outputs = {} # point from cache_ids to outputs
self.inputs_changed = {} # point from cache_ids to bools
def combine_inputs(self, args, kw):
def combine_inputs(self, args, kw, ignore_args):
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
return tuple(a for i,a in enumerate(args) if i not in ignore_args) + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
def prepare_cache_id(self, combined_args_kw, ignore_args):
def prepare_cache_id(self, combined_args_kw):
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args)
return "".join(str(id(a)) for a in combined_args_kw)
def ensure_cache_length(self, cache_id):
"Ensures the cache is within its limits and has one place free"
@@ -45,6 +45,8 @@ class Cacher(object):
cache_id = self.order.popleft()
combined_args_kw = self.cached_inputs[cache_id]
for ind in combined_args_kw:
if ind is None:
continue
ind_id = id(ind)
ref, cache_ids = self.cached_input_ids[ind_id]
if len(cache_ids) == 1 and ref() is not None:
@@ -63,6 +65,8 @@ class Cacher(object):
self.order.append(cache_id)
self.cached_inputs[cache_id] = combined_args_kw
for a in combined_args_kw:
if a is None:
continue
ind_id = id(a)
v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
v[1].append(cache_id)
@@ -82,11 +86,12 @@ class Cacher(object):
return self.operation(*args, **kw)
# 2: prepare_cache_id and get the unique id string for this call
inputs = self.combine_inputs(args, kw)
cache_id = self.prepare_cache_id(inputs, self.ignore_args)
inputs = self.combine_inputs(args, kw, self.ignore_args)
cache_id = self.prepare_cache_id(inputs)
# 2: if anything is not cachable, we will just return the operation, without caching
if reduce(lambda a,b: a or (not isinstance(b, Observable)), inputs, False):
if reduce(lambda a,b: a or (not (isinstance(b, Observable) or (b is None))), inputs, False):
# print '['+self.operation.__name__+'] contain un-cachable arguments!'
return self.operation(*args, **kw)
# 3&4: check whether this cache_id has been cached, then has it changed?
try:
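
The cache.py changes also handle None arguments: the bookkeeping loops now skip None (it cannot be weakly referenced or observed), and the bypass test above treats None as cachable, so a cached call with a None argument (e.g. the common X2=None default) no longer forces recomputation, provided the remaining arguments are Observable. A small sketch of the bypass test in isolation (Observable here is a local stand-in for GPy's class, purely for illustration):

from functools import reduce

class Observable(object):
    """Stand-in for GPy's Observable base class, only for this illustration."""

def bypass_cache(inputs):
    # the updated test from __call__: any argument that is neither Observable nor None disables caching
    return reduce(lambda a, b: a or not (isinstance(b, Observable) or b is None), inputs, False)

x = Observable()
print(bypass_cache((x, None)))   # False -- None no longer triggers the un-cachable path
print(bypass_cache((x, 3.14)))   # True  -- other non-Observable arguments still do
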