fix the remaining problem of cache.py

This commit is contained in:
Zhenwen Dai 2014-06-20 21:51:51 +01:00
parent ca1edecce4
commit 59172435e2
3 changed files with 61 additions and 20 deletions

View file

@ -58,7 +58,23 @@ gpu_code = """
} }
} }
__global__ void psi1computations(double *psi1, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q) __global__ void compDenom(double *log_denom1, double *log_denom2, double *l, double *S, int N, int Q)
{
int n_start, n_end;
divide_data(N, gridDim.x, blockIdx.x, &n_start, &n_end);
for(int i=n_start*Q+threadIdx.x; i<n_end*Q; i+=blockDim.x) {
int n=i/Q;
int q=i%Q;
double Snq = S[IDX_NQ(n,q)];
double lq = l[q]*l[q];
log_denom1[IDX_NQ(n,q)] = log(Snq/lq+1.);
log_denom2[IDX_NQ(n,q)] = log(2.*Snq/lq+1.);
}
}
__global__ void psi1computations(double *psi1, double *log_denom1, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q)
{ {
int m_start, m_end; int m_start, m_end;
divide_data(M, gridDim.x, blockIdx.x, &m_start, &m_end); divide_data(M, gridDim.x, blockIdx.x, &m_start, &m_end);
@ -70,15 +86,14 @@ gpu_code = """
double muZ = mu[IDX_NQ(n,q)]-Z[IDX_MQ(m,q)]; double muZ = mu[IDX_NQ(n,q)]-Z[IDX_MQ(m,q)];
double Snq = S[IDX_NQ(n,q)]; double Snq = S[IDX_NQ(n,q)];
double lq = l[q]*l[q]; double lq = l[q]*l[q];
log_psi1 += (muZ*muZ/(Snq+lq))/(-2.); log_psi1 += (muZ*muZ/(Snq+lq)+log_denom1[IDX_NQ(n,q)])/(-2.);
log_psi1 += log(Snq/lq+1)/(-2.);
} }
psi1[IDX_NM(n,m)] = var*exp(log_psi1); psi1[IDX_NM(n,m)] = var*exp(log_psi1);
} }
} }
} }
__global__ void psi2computations(double *psi2, double *psi2n, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q) __global__ void psi2computations(double *psi2, double *psi2n, double *log_denom2, double var, double *l, double *Z, double *mu, double *S, int N, int M, int Q)
{ {
int psi2_idx_start, psi2_idx_end; int psi2_idx_start, psi2_idx_end;
__shared__ double psi2_local[THREADNUM]; __shared__ double psi2_local[THREADNUM];
@ -96,8 +111,8 @@ gpu_code = """
double muZhat = mu[IDX_NQ(n,q)]- (Z[IDX_MQ(m1,q)]+Z[IDX_MQ(m2,q)])/2.; double muZhat = mu[IDX_NQ(n,q)]- (Z[IDX_MQ(m1,q)]+Z[IDX_MQ(m2,q)])/2.;
double Snq = S[IDX_NQ(n,q)]; double Snq = S[IDX_NQ(n,q)];
double lq = l[q]*l[q]; double lq = l[q]*l[q];
log_psi2_n += dZ*dZ/(-4.*lq)-muZhat*muZhat/(2.*Snq+lq); log_psi2_n += dZ*dZ/(-4.*lq)-muZhat*muZhat/(2.*Snq+lq) + log_denom2[IDX_NQ(n,q)]/(-2.);
log_psi2_n += log(2.*Snq/lq+1)/(-2.); //log_psi2_n += log(2.*Snq/lq+1)/(-2.);
} }
double exp_psi2_n = exp(log_psi2_n); double exp_psi2_n = exp(log_psi2_n);
psi2n[IDX_NMM(n,m1,m2)] = var*var*exp_psi2_n; psi2n[IDX_NMM(n,m1,m2)] = var*var*exp_psi2_n;
@ -237,9 +252,15 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
self.blocknum = 15 self.blocknum = 15
module = SourceModule("#define THREADNUM "+str(self.threadnum)+"\n"+gpu_code) module = SourceModule("#define THREADNUM "+str(self.threadnum)+"\n"+gpu_code)
self.g_psi1computations = module.get_function('psi1computations') self.g_psi1computations = module.get_function('psi1computations')
self.g_psi1computations.prepare('PPdPPPPiii')
self.g_psi2computations = module.get_function('psi2computations') self.g_psi2computations = module.get_function('psi2computations')
self.g_psi2computations.prepare('PPPdPPPPiii')
self.g_psi1compDer = module.get_function('psi1compDer') self.g_psi1compDer = module.get_function('psi1compDer')
self.g_psi1compDer.prepare('PPPPPPPdPPPPiii')
self.g_psi2compDer = module.get_function('psi2compDer') self.g_psi2compDer = module.get_function('psi2compDer')
self.g_psi2compDer.prepare('PPPPPPPdPPPPiii')
self.g_compDenom = module.get_function('compDenom')
self.g_compDenom.prepare('PPPPii')
def _initGPUCache(self, N, M, Q): def _initGPUCache(self, N, M, Q):
if self.gpuCache == None: if self.gpuCache == None:
@ -253,6 +274,8 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
'psi2n_gpu' :gpuarray.empty((N,M,M),np.float64,order='F'), 'psi2n_gpu' :gpuarray.empty((N,M,M),np.float64,order='F'),
'dL_dpsi1_gpu' :gpuarray.empty((N,M),np.float64,order='F'), 'dL_dpsi1_gpu' :gpuarray.empty((N,M),np.float64,order='F'),
'dL_dpsi2_gpu' :gpuarray.empty((M,M),np.float64,order='F'), 'dL_dpsi2_gpu' :gpuarray.empty((M,M),np.float64,order='F'),
'log_denom1_gpu' :gpuarray.empty((N,Q),np.float64,order='F'),
'log_denom2_gpu' :gpuarray.empty((N,Q),np.float64,order='F'),
# derivatives # derivatives
'dvar_gpu' :gpuarray.empty((self.blocknum,),np.float64, order='F'), 'dvar_gpu' :gpuarray.empty((self.blocknum,),np.float64, order='F'),
'dl_gpu' :gpuarray.empty((Q,self.blocknum),np.float64, order='F'), 'dl_gpu' :gpuarray.empty((Q,self.blocknum),np.float64, order='F'),
@ -277,6 +300,10 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
self.gpuCache['Z_gpu'].set(np.asfortranarray(Z)) self.gpuCache['Z_gpu'].set(np.asfortranarray(Z))
self.gpuCache['mu_gpu'].set(np.asfortranarray(mu)) self.gpuCache['mu_gpu'].set(np.asfortranarray(mu))
self.gpuCache['S_gpu'].set(np.asfortranarray(S)) self.gpuCache['S_gpu'].set(np.asfortranarray(S))
N,Q = self.gpuCache['S_gpu'].shape
# t=self.g_compDenom(self.gpuCache['log_denom1_gpu'],self.gpuCache['log_denom2_gpu'],self.gpuCache['l_gpu'],self.gpuCache['S_gpu'], np.int32(N), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_compDenom '+str(t)
self.g_compDenom.prepared_call((self.blocknum,1),(self.threadnum,1,1), self.gpuCache['log_denom1_gpu'].gpudata,self.gpuCache['log_denom2_gpu'].gpudata,self.gpuCache['l_gpu'].gpudata,self.gpuCache['S_gpu'].gpudata, np.int32(N), np.int32(Q))
def reset_derivative(self): def reset_derivative(self):
self.gpuCache['dvar_gpu'].fill(0.) self.gpuCache['dvar_gpu'].fill(0.)
@ -309,11 +336,17 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
Z_gpu = self.gpuCache['Z_gpu'] Z_gpu = self.gpuCache['Z_gpu']
mu_gpu = self.gpuCache['mu_gpu'] mu_gpu = self.gpuCache['mu_gpu']
S_gpu = self.gpuCache['S_gpu'] S_gpu = self.gpuCache['S_gpu']
log_denom1_gpu = self.gpuCache['log_denom1_gpu']
log_denom2_gpu = self.gpuCache['log_denom2_gpu']
psi0 = np.empty((N,)) psi0 = np.empty((N,))
psi0[:] = variance psi0[:] = variance
self.g_psi1computations(psi1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1)) self.g_psi1computations.prepared_call((self.blocknum,1),(self.threadnum,1,1),psi1_gpu.gpudata, log_denom1_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
self.g_psi2computations(psi2_gpu, psi2n_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1)) self.g_psi2computations.prepared_call((self.blocknum,1),(self.threadnum,1,1),psi2_gpu.gpudata, psi2n_gpu.gpudata, log_denom2_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
# t = self.g_psi1computations(psi1_gpu, log_denom1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_psi1computations '+str(t)
# t = self.g_psi2computations(psi2_gpu, psi2n_gpu, log_denom2_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_psi2computations '+str(t)
if self.GPU_direct: if self.GPU_direct:
return psi0, psi1_gpu, psi2_gpu return psi0, psi1_gpu, psi2_gpu
@ -352,8 +385,12 @@ class PSICOMP_RBF_GPU(PSICOMP_RBF):
dL_dpsi0_sum = dL_dpsi0.sum() dL_dpsi0_sum = dL_dpsi0.sum()
self.reset_derivative() self.reset_derivative()
self.g_psi1compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi1_gpu,psi1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1)) # t=self.g_psi1compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi1_gpu,psi1_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
self.g_psi2compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi2_gpu,psi2n_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1)) # print 'g_psi1compDer '+str(t)
# t=self.g_psi2compDer(dvar_gpu,dl_gpu,dZ_gpu,dmu_gpu,dS_gpu,dL_dpsi2_gpu,psi2n_gpu, np.float64(variance),l_gpu,Z_gpu,mu_gpu,S_gpu, np.int32(N), np.int32(M), np.int32(Q), block=(self.threadnum,1,1), grid=(self.blocknum,1),time_kernel=True)
# print 'g_psi2compDer '+str(t)
self.g_psi1compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi1_gpu.gpudata,psi1_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
self.g_psi2compDer.prepared_call((self.blocknum,1),(self.threadnum,1,1),dvar_gpu.gpudata,dl_gpu.gpudata,dZ_gpu.gpudata,dmu_gpu.gpudata,dS_gpu.gpudata,dL_dpsi2_gpu.gpudata,psi2n_gpu.gpudata, np.float64(variance),l_gpu.gpudata,Z_gpu.gpudata,mu_gpu.gpudata,S_gpu.gpudata, np.int32(N), np.int32(M), np.int32(Q))
dL_dvar = dL_dpsi0_sum + gpuarray.sum(dvar_gpu).get() dL_dvar = dL_dpsi0_sum + gpuarray.sum(dvar_gpu).get()
sum_axis(grad_mu_gpu,dmu_gpu,N*Q,self.blocknum) sum_axis(grad_mu_gpu,dmu_gpu,N*Q,self.blocknum)

View file

@ -79,7 +79,6 @@ class Stationary(Kern):
#a convenience function, so we can cache dK_dr #a convenience function, so we can cache dK_dr
return self.dK_dr(self._scaled_dist(X, X2)) return self.dK_dr(self._scaled_dist(X, X2))
@Cache_this(limit=5, ignore_args=(0,))
def _unscaled_dist(self, X, X2=None): def _unscaled_dist(self, X, X2=None):
""" """
Compute the Euclidean distance between each row of X and X2, or between Compute the Euclidean distance between each row of X and X2, or between

View file

@ -30,13 +30,13 @@ class Cacher(object):
self.cached_outputs = {} # point from cache_ids to outputs self.cached_outputs = {} # point from cache_ids to outputs
self.inputs_changed = {} # point from cache_ids to bools self.inputs_changed = {} # point from cache_ids to bools
def combine_inputs(self, args, kw): def combine_inputs(self, args, kw, ignore_args):
"Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute" "Combines the args and kw in a unique way, such that ordering of kwargs does not lead to recompute"
return args + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0])) return tuple(a for i,a in enumerate(args) if i not in ignore_args) + tuple(c[1] for c in sorted(kw.items(), key=lambda x: x[0]))
def prepare_cache_id(self, combined_args_kw, ignore_args): def prepare_cache_id(self, combined_args_kw):
"get the cacheid (conc. string of argument ids in order) ignoring ignore_args" "get the cacheid (conc. string of argument ids in order) ignoring ignore_args"
return "".join(str(id(a)) for i,a in enumerate(combined_args_kw) if i not in ignore_args) return "".join(str(id(a)) for a in combined_args_kw)
def ensure_cache_length(self, cache_id): def ensure_cache_length(self, cache_id):
"Ensures the cache is within its limits and has one place free" "Ensures the cache is within its limits and has one place free"
@ -45,6 +45,8 @@ class Cacher(object):
cache_id = self.order.popleft() cache_id = self.order.popleft()
combined_args_kw = self.cached_inputs[cache_id] combined_args_kw = self.cached_inputs[cache_id]
for ind in combined_args_kw: for ind in combined_args_kw:
if ind is None:
continue
ind_id = id(ind) ind_id = id(ind)
ref, cache_ids = self.cached_input_ids[ind_id] ref, cache_ids = self.cached_input_ids[ind_id]
if len(cache_ids) == 1 and ref() is not None: if len(cache_ids) == 1 and ref() is not None:
@ -63,6 +65,8 @@ class Cacher(object):
self.order.append(cache_id) self.order.append(cache_id)
self.cached_inputs[cache_id] = combined_args_kw self.cached_inputs[cache_id] = combined_args_kw
for a in combined_args_kw: for a in combined_args_kw:
if a is None:
continue
ind_id = id(a) ind_id = id(a)
v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []]) v = self.cached_input_ids.get(ind_id, [weakref.ref(a), []])
v[1].append(cache_id) v[1].append(cache_id)
@ -82,11 +86,12 @@ class Cacher(object):
return self.operation(*args, **kw) return self.operation(*args, **kw)
# 2: prepare_cache_id and get the unique id string for this call # 2: prepare_cache_id and get the unique id string for this call
inputs = self.combine_inputs(args, kw) inputs = self.combine_inputs(args, kw, self.ignore_args)
cache_id = self.prepare_cache_id(inputs, self.ignore_args) cache_id = self.prepare_cache_id(inputs)
# 2: if anything is not cachable, we will just return the operation, without caching # 2: if anything is not cachable, we will just return the operation, without caching
if reduce(lambda a,b: a or (not isinstance(b, Observable)), inputs, False): if reduce(lambda a,b: a or (not (isinstance(b, Observable) or (b is None))), inputs, False):
# print '['+self.operation.__name__+'] contain un-cachable arguments!'
return self.operation(*args, **kw) return self.operation(*args, **kw)
# 3&4: check whether this cache_id has been cached, then has it changed? # 3&4: check whether this cache_id has been cached, then has it changed?
try: try: