import os
import sys
import ase.io
import numpy as np
from ase import Atoms, Atom
from ase.io import Trajectory
from sklearn.metrics import mean_squared_error
from dscribe.descriptors import SOAP
from scipy.interpolate import UnivariateSpline
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
from sklearn.model_selection import train_test_split, learning_curve, validation_curve, GridSearchCV
from joblib import Parallel, delayed
from ase import neighborlist 
from sklearn.preprocessing import StandardScaler
from create_descriptor import *
from tool import *
from zopt import *

if __name__ == '__main__':
    # Driver script: fit a Gaussian-process model on descriptor features of
    # the training structures, then dispatch one `run_para` job per candidate
    # position returned by `pos_to_relax`.

    # --- Initialisation -------------------------------------------------
    # Reference values; none of the four names below is read again in this
    # block.
    # NOTE(review): confirm whether the star-imported helpers rely on them.
    E = -808.18103420  # slab without O
    mu = -9.85391418 / 2  # O2 value divided by 2
    atomseul = 'H'  # deposited atom
    title = 'GP'

    # Input locations, given relative to the working directory set below.
    traj = 'data/final.traj'
    data_folder = 'data'
    poscar_file = 'POSCAR'  # slab with no adsorbate
    train_pos_indices = [1, 25, 36]

    # All relative paths are resolved from this directory, so it has to be
    # entered before any file is touched.
    # NOTE(review): hard-coded, machine-specific absolute path.
    os.chdir('/home/boulang31/Documents/codes/zopt')

    atoms_train, y_train, species = get_train_data(
        traj, data_folder, train_pos_indices
    )
    xx, yy, zz = pos_to_relax(poscar_file)
    print(zz)

    # --- Descriptor -----------------------------------------------------
    params = {
        'species': species,
        'l_max': 2,
        'n_max': 2,
        'r_cut': 7,
    }
    desc = create_descriptor(method='soap', params=params, ats=0)
    # NOTE(review): `load=True` / `save_file` presumably reuse features cached
    # under `data_folder` -- confirm against create_descriptor.
    X_train = desc.create(atoms_train, load=True, save_file=data_folder)

    # --- Target centring and feature scaling ----------------------------
    # NOTE(review): `y_mean` is not handed to `run_para`; verify that the
    # offset is restored wherever predictions are consumed.
    y_mean = np.mean(y_train)
    y_train = y_train - y_mean
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)

    # --- Model ----------------------------------------------------------
    # Length-scale bounds are written as 1/sqrt(2*g) for g = 1e-0 and 1e-8
    # (presumably g is an RBF gamma -- confirm).
    ls_bounds = (1 / np.sqrt(2 * 1e-0), 1 / np.sqrt(2 * 1e-8))
    kernel = 1**2 * RBF(length_scale=200, length_scale_bounds=ls_bounds)
    GP = GaussianProcessRegressor(kernel=kernel)

    # The grid holds a single `alpha` candidate, so the search only
    # cross-validates that one setting before the final fit.
    parameters = {'alpha': [1e-3]}
    g_search = GridSearchCV(GP, parameters)
    g_search.fit(X_train, y_train)

    best_gp = g_search.best_estimator_
    print('best estimator', best_gp.kernel_)
    # `alpha_` is the fitted attribute of the regressor, distinct from the
    # `alpha` hyper-parameter searched above.
    np.save('GP_alpha.npy', best_gp.alpha_)

    # --- Minimum search -------------------------------------------------
    # Disabled serial variant, kept for reference:
    # for i in range(len(xx)):
    #     x_soap,zetap,ind,zlim=searchEminzgrid(x=xx[i],y=yy[i],ind=i,z=17,posDFT=posDFT,atomseul='O')
    #     x_soap=scaler.transform(x_soap)
    #     Eprd,Estd=g_search.best_estimator_.predict(x_soap,return_std=True)
    #     pos_reduc(np.array(zetap),Eprd,Estd,ind,zlim,xx[i],yy[i])

    # One job per index of `xx`; each receives the [x, y, z] triple, its
    # index, the descriptor, the fitted scaler and the fitted search object.
    jobs = (
        delayed(run_para)([xx[i], yy[i], zz[i]], i, desc, scaler, g_search)
        for i in range(len(xx))
    )
    Parallel(n_jobs=10)(jobs)