Created
October 22, 2021 02:22
-
-
Save hengck23/b7419ed56ce9f9219a8fef56bef31403 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Load the ventilator train/test csvs used by the kNN sanity check below.
# NOTE(review): `pd` and `root_dir` are not defined in this file; presumably
# they come from the `common` star import -- confirm.
from common import *
from ventilator import *
import faiss

data_dir = root_dir + '/data'
train_df = pd.read_csv(data_dir + '/train.csv')
test_df = pd.read_csv(data_dir + '/test.csv')
# Sanity check: for every (R, C) lung setting, find each test breath's nearest
# training breaths by u_in waveform (exact L2 kNN with faiss), then plot the
# neighbours' u_in and pressure traces against the test breath's u_in.
# The viewer is interactive: press a key to advance to the next test breath.
if 1: #check statistics
    seq_len = 80  # time steps per breath; also the kNN vector dimension
    num_neighbour = 5  # neighbours searched; must be >= len(knn_styles)
    # (line width, alpha) for the 1st, 2nd and 3rd nearest-neighbour curves.
    knn_styles = [(2, 1), (1, 1), (1, 0.25)]

    # Allocate GPU resources once and reuse them for all (R, C) indexes instead
    # of creating a new StandardGpuResources per group. On faiss builds without
    # GPU support, fall back to the (equally exact) CPU flat index.
    res = faiss.StandardGpuResources() if hasattr(faiss, 'StandardGpuResources') else None

    for r in [5, 20, 50]:
        for c in [10, 20, 50]:
            # NOTE(review): reshape(-1, seq_len) assumes each breath occupies
            # exactly seq_len consecutive rows in time order -- confirm vs csv.
            t_df = train_df[(train_df.R == r) & (train_df.C == c)]
            t_num = len(t_df) // seq_len
            t_u_out = t_df.u_out.values.reshape(-1, seq_len)
            t_u_in = t_df.u_in.values.reshape(-1, seq_len)
            t_p = t_df.pressure.values.reshape(-1, seq_len)
            t_t = t_df.time_step.values.reshape(-1, seq_len)

            # No pressure is reshaped for the test set (presumably test.csv
            # has no pressure column).
            e_df = test_df[(test_df.R == r) & (test_df.C == c)]
            e_num = len(e_df) // seq_len
            e_u_out = e_df.u_out.values.reshape(-1, seq_len)
            e_u_in = e_df.u_in.values.reshape(-1, seq_len)
            e_t = e_df.time_step.values.reshape(-1, seq_len)

            if t_num == 0 or e_num == 0:
                continue  # nothing to index, or nothing to query, for this (R, C)

            # https://github.com/facebookresearch/faiss/blob/main/tutorial/python/4-GPU.py
            index = faiss.IndexFlatL2(seq_len)
            if res is not None:
                index = faiss.index_cpu_to_gpu(res, 0, index)  # use a single GPU
            index.add(t_u_in.astype('float32'))  # faiss requires float32 vectors
            print('R=%d C=%d: %d train / %d test breaths' % (r, c, index.ntotal, e_num))

            # D holds squared L2 distances; I holds the matching train-breath rows.
            D, I = index.search(e_u_in.astype('float32'), num_neighbour)

            for i in range(e_num):
                plt.clf()
                plt.plot(e_t[i], e_u_in[i], lw=2, alpha=1, label='test u_in', c='black')
                # All u_in neighbour curves first, then all pressure curves,
                # which fixes the legend order.
                for series, name, colour in ((t_u_in, 'u_in', 'green'), (t_p, 'press', 'red')):
                    for k, (lw, alpha) in enumerate(knn_styles):
                        j = I[i, k]
                        # faiss pads with -1 when fewer than num_neighbour train
                        # breaths exist; t_t[-1] would silently plot the last breath.
                        if j < 0:
                            continue
                        plt.plot(t_t[j], series[j], '--', lw=lw, alpha=alpha,
                                 label='knn-%d %s %4.1f' % (k + 1, name, D[i, k]), c=colour)
                plt.ylim([-5, 65])
                plt.xlim([0, 2.5])
                plt.legend()
                # plt.show()
                # waitforbuttonpress() returns True for a key press and False for
                # a mouse click, so only a key press advances to the next breath.
                while not plt.waitforbuttonpress():
                    pass
zz = 0  # no-op; convenient debugger breakpoint anchor
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment