Skip to content

Instantly share code, notes, and snippets.

@hengck23
Created October 22, 2021 02:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save hengck23/b7419ed56ce9f9219a8fef56bef31403 to your computer and use it in GitHub Desktop.
Save hengck23/b7419ed56ce9f9219a8fef56bef31403 to your computer and use it in GitHub Desktop.
from common import *
from ventilator import *
import faiss
# Load both competition splits from <root_dir>/data.
# `root_dir` and `pd` are not defined in this file; presumably they come from
# the star-imports above -- verify against `common` / `ventilator`.
data_dir = root_dir + '/data'
train_df, test_df = (
    pd.read_csv(data_dir + '/' + csv_name)
    for csv_name in ('train.csv', 'test.csv')
)
# For every lung setting (R, C): index the training u_in waveforms with faiss,
# look up the nearest training waveforms for each test waveform, and overlay
# the test u_in with its neighbours' u_in and pressure for visual inspection.
# The figure advances to the next test breath on a key press.
if 1:  # check statistics
    dim = 80  # rows per breath / vector dimension
    # NOTE(review): the reshape(-1, dim) calls below assume every breath has
    # exactly `dim` consecutive rows in both CSVs -- confirm against the data.
    num_neighbour = 5  # neighbours retrieved per query
    # (line width, alpha) of the neighbours actually drawn, nearest first.
    knn_styles = [(2, 1), (1, 1), (1, 0.25)]

    # https://github.com/facebookresearch/faiss/blob/main/tutorial/python/4-GPU.py
    # The GPU resource object does not depend on (r, c), so create it once
    # instead of once per parameter combination.
    res = faiss.StandardGpuResources()  # use a single GPU

    for r in [5, 20, 50]:
        for c in [10, 20, 50]:
            t_df = train_df[(train_df.R == r) & (train_df.C == c)]
            t_num = len(t_df) // dim
            t_u_in = t_df.u_in.values.reshape(-1, dim)
            t_p = t_df.pressure.values.reshape(-1, dim)
            t_t = t_df.time_step.values.reshape(-1, dim)

            e_df = test_df[(test_df.R == r) & (test_df.C == c)]
            e_num = len(e_df) // dim
            e_u_in = e_df.u_in.values.reshape(-1, dim)
            e_t = e_df.time_step.values.reshape(-1, dim)

            # Nothing to index or nothing to query for this setting.
            if t_num == 0 or e_num == 0:
                continue

            # ---
            xb = t_u_in.astype('float32')  # database: one row per train breath
            xq = e_u_in.astype('float32')  # queries: one row per test breath

            index_flat = faiss.IndexFlatL2(dim)
            gpu_index_flat = faiss.index_cpu_to_gpu(res, 0, index_flat)
            gpu_index_flat.add(xb)  # add vectors to the index
            print(gpu_index_flat.ntotal)

            # D: distances, I: row indices into xb, both (e_num, num_neighbour).
            D, I = gpu_index_flat.search(xq, num_neighbour)  # actual search

            for i in range(e_num):
                # faiss fills missing neighbours with label -1 when the index
                # holds fewer than `num_neighbour` vectors; drop those so the
                # -1 is not used as a (wrap-around) NumPy index.
                shown = [
                    (rank + 1, I[i, rank], D[i, rank], lw, alpha)
                    for rank, (lw, alpha) in enumerate(knn_styles)
                    if I[i, rank] >= 0
                ]

                plt.clf()
                plt.plot(e_t[i], e_u_in[i], lw=2, alpha=1,
                         label='test u_in', c='black')
                # u_in curves first, then pressure, to keep the legend order.
                for rank, j, dist, lw, alpha in shown:
                    plt.plot(t_t[j], t_u_in[j], '--', lw=lw, alpha=alpha,
                             label='knn-%d u_in %4.1f' % (rank, dist),
                             c='green')
                for rank, j, dist, lw, alpha in shown:
                    plt.plot(t_t[j], t_p[j], '--', lw=lw, alpha=alpha,
                             label='knn-%d press %4.1f' % (rank, dist),
                             c='red')

                plt.ylim([-5, 65])
                plt.xlim([0, 2.5])
                plt.legend()
                # plt.show()
                # Block until a keyboard key is pressed; mouse clicks (False)
                # and timeouts (None) keep waiting.
                while not plt.waitforbuttonpress():
                    pass

    zz = 0  # unused; presumably a debugger breakpoint anchor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment