import numpy as np
import torch
import matplotlib.pyplot as plt
import pandas as pd
from WGAN import Generator
from torch.autograd import Variable
import os


#####################
class Object(object):
    """Empty class used as a plain attribute container for the settings below."""


# Sizes handed to the Generator constructor.
args = Object()
args.feature_dim, args.latent_dim, args.label_dim = 39, 50, 161

# Rebuild the generator, restore its saved weights onto the CPU and put it
# into evaluation mode.
G = Generator(
    feature_dim=args.feature_dim,
    latent_dim=args.latent_dim,
    label_dim=args.label_dim,
)
G.load_state_dict(
    torch.load('generator_MP2020.pt', map_location=torch.device('cpu'))
)
G.eval()

def sample_generator(G, num_samples, feature, num_draws=100):
    """Return the generator output averaged over repeated latent draws.

    Every draw asks ``G.sample_latent(num_samples)`` for a fresh latent batch,
    concatenates it to ``feature`` along dim 1 and passes the result through
    ``G``. The outputs of all draws are summed and divided by ``num_draws``.

    Args:
        G: Callable generator that also exposes ``sample_latent(n)``.
        num_samples: Value forwarded to ``G.sample_latent``. Presumably the
            batch size, in which case it must equal ``feature.shape[0]`` for
            the concatenation to succeed -- confirm against ``Generator``.
        feature: 2-D tensor of conditioning inputs placed in front of the
            latent columns.
        num_draws: Number of latent draws to average over. Defaults to 100,
            the value that was previously hard-coded.

    Returns:
        The element-wise mean of the ``num_draws`` generator outputs. The
        computation runs under ``torch.no_grad()``, so the result carries no
        autograd history.

    Raises:
        ValueError: If ``num_draws`` is smaller than 1.
    """
    if num_draws < 1:
        raise ValueError("num_draws must be a positive integer")

    # Inference only: skip building an autograd graph for every forward pass.
    with torch.no_grad():
        total = None
        for _ in range(num_draws):
            latent = G.sample_latent(num_samples)
            output = G(torch.cat((feature, latent), dim=1))
            # Out-of-place add so the generator's own output is never mutated.
            total = output if total is None else total + output
    return total / num_draws
#####################

# Read the composition table. The first CSV column becomes the index, and
# keep_default_na=False stops pandas from turning empty/"NA"-like cells into NaN.
# NOTE: an earlier revision read 'Predict_3cation_testtrain_69ternaries.csv'.
data_pd = pd.read_csv('new_ternary_compositions.csv', index_col=0, keep_default_na=False)
columns_name = data_pd.columns
data = data_pd.to_numpy()
N = data.shape[0]  # number of rows in the table

# Column layout used by this script (positions after the index column):
#   1..39  -> per-element composition values, one column per element
#   40..49 -> the ten "fom" values (presumably figures of merit -- confirm)
# Only the text before the first '.' of each header is kept as the element
# name (presumably to drop a '.N' suffix pandas adds to repeated headers).
ele_names = np.array([header.split('.')[0] for header in columns_name[1:40]])
element_comps = data[:, 1:40].astype(np.float32)
foms = data[:, 40:50].astype(np.float32)

# Filled by get_dict() with the row numbers that have more than one non-zero
# composition entry.
test_idx = []

def get_dict():
    """Build one record per table row and write the results to disk.

    For every row in ``range(N)`` the composition vector is passed through
    ``sample_generator`` (batch of one) to obtain an averaged generator
    output, the composition is checked to sum to 1 within 1e-5, and a record
    with the following keys is stored under the row number:

    * ``fom`` -- the row of ``foms``.
    * ``composition_nonzero`` -- the non-zero composition entries, renormalised.
    * ``composition_nonzero_idx`` -- the tuple returned by ``np.nonzero``.
    * ``nonzero_element_name`` -- element names at the non-zero positions.
    * ``gen_dos_fea`` -- the averaged generator output as a NumPy array.
    * ``composition`` -- the full composition vector, renormalised.

    Rows with more than one non-zero entry are appended to the module-level
    ``test_idx`` list and their row number is printed.

    Side effects: writes ``uvis_dict.chkpt`` (via ``torch.save``) and
    ``test_idx.npy`` (via ``np.save``). Returns ``None``.

    Raises:
        AssertionError: If a composition does not sum to 1 within 1e-5.
    """
    records = {}
    for row in range(N):
        fractions = element_comps[row]

        # Generator input is the composition as a (1, n_elements) tensor.
        generator_input = torch.from_numpy(fractions).unsqueeze(0)
        gen_feat = sample_generator(G, 1, generator_input).detach().squeeze().numpy()

        total = np.sum(fractions)
        assert np.abs(1 - total) < 1e-5

        present = np.nonzero(fractions)
        if len(present[0]) > 1:
            test_idx.append(row)
            print(row)

        present_fractions = fractions[present]
        present_total = np.sum(present_fractions)
        assert np.abs(1 - present_total) < 1e-5

        records[row] = {
            'fom': foms[row],
            'composition_nonzero': present_fractions / present_total,
            'composition_nonzero_idx': present,
            'nonzero_element_name': ele_names[present],
            'gen_dos_fea': gen_feat,
            'composition': fractions / total,
        }

    torch.save(records, 'uvis_dict.chkpt')
    np.save('test_idx.npy', np.array(test_idx))

# Script entry point: build the per-row records and write 'uvis_dict.chkpt'
# and 'test_idx.npy' as soon as this file is executed.
get_dict()









