import argparse,math,time,warnings,copy, numpy as np, os.path as path
from os import listdir
import pandas as pd
from os.path import isfile, join
import os
import torch
import random

# Per-sample records; later looked up by integer key (uvis[i]).
uvis = torch.load('../uvis_dict.chkpt')
# NOTE(review): one entry is left out of the sample count — presumably a
# non-sample entry in the checkpoint; confirm against whatever wrote it.
N = len(uvis) - 1

train_path = r'train/'
test_path = r'test/'

# Each file under train/ is expected to have a same-named partner under test/;
# both are loaded with np.load and flattened into Python lists.
terns = listdir(train_path)
all_train_idx, all_test_idx = [], []
for tern in terns:
    train_idx = np.load(join(train_path, tern))
    test_idx = np.load(join(test_path, tern))
    all_train_idx.extend(train_idx.tolist())
    all_test_idx.extend(test_idx.tolist())

# De-duplicate each side (the same index may come from several files).
set_all_train_idx = list(set(all_train_idx))
set_all_test_idx = list(set(all_test_idx))

# Raw vs. unique counts for each side.
print(len(all_train_idx), len(set_all_train_idx))
print(len(all_test_idx), len(set_all_test_idx))

# Train-side indices that never occur on the test side. Indices shared with the
# test files are handled only by the test-side split below, so the saved
# val/test sets stay disjoint from the saved train set.
test_idx_lookup = set(set_all_test_idx)  # O(1) membership; scanning the list was O(n*m)
set_all_train_idx_sep = [i for i in set_all_train_idx if i not in test_idx_lookup]

# Shuffle the test-side pool and split it 70/15/15; the 70% part is merged
# into training.
# NOTE(review): np.random is not seeded, so every run produces a different
# split — seed it (e.g. np.random.seed(...)) if the split must be reproducible.
set_all_test_idx = np.array(set_all_test_idx)
np.random.shuffle(set_all_test_idx)
M = len(set_all_test_idx)
train_from_test = set_all_test_idx[0:int(0.7 * M)]
val_from_test = set_all_test_idx[int(0.7 * M):int(0.85 * M)]
test_from_test = set_all_test_idx[int(0.85 * M):]

train = np.array(set_all_train_idx_sep + train_from_test.tolist())

# np.save does not create missing directories.
os.makedirs('./rd_idx_jh', exist_ok=True)
np.save('./rd_idx_jh/train_jh.npy', train)
np.save('./rd_idx_jh/val_jh.npy', val_from_test)
np.save('./rd_idx_jh/test_jh.npy', test_from_test)

print(len(train), len(val_from_test), len(test_from_test))


# Keep only samples whose 'nonzero_element_name' has more than one entry.
# NOTE(review): keys 0..N-1 are scanned with N = len(uvis) - 1 — confirm the
# excluded entry is intentional.
idx = np.array([i for i in range(N) if len(uvis[i]['nonzero_element_name']) > 1])
np.random.shuffle(idx)

# Split 80/10/10 over the *filtered* samples. The boundaries must come from
# len(idx), not N: with N, any filtering left val/test undersized, and once more
# than 20% of samples were dropped, train took everything and val/test were empty.
n_kept = len(idx)
train_idx = idx[0:int(0.8 * n_kept)]
val_idx = idx[int(0.8 * n_kept):int(0.9 * n_kept)]
test_idx = idx[int(0.9 * n_kept):]

# np.save does not create missing directories.
os.makedirs('./rd_idx', exist_ok=True)
np.save('./rd_idx/train_idx.npy', train_idx)
np.save('./rd_idx/val_idx.npy', val_idx)
np.save('./rd_idx/test_idx.npy', test_idx)





