# vim: ft=python
import numpy as np
from qiskit_nature.second_q.drivers.electronic_structure_driver import _QCSchemaData

# Psi4 input template: runs an SCF calculation and dumps the resulting
# integrals, orbitals and molecule metadata into a _QCSchemaData object
# which is then written to the HDF5 file at _FILE_PATH.

# force c1 symmetry on molecule
core.get_active_molecule().reset_point_group('c1')

# Reference SCF calculation; everything below is read from its wavefunction.
_q_hf_energy, _q_hf_wavefn = energy('scf', return_wfn=True)
_q_mints = MintsHelper(_q_hf_wavefn.basisset())
_q_mol   = _q_hf_wavefn.molecule()
# Beta-spin quantities are only exported when alpha and beta orbitals differ.
_has_B   = not _q_hf_wavefn.same_a_b_orbs()
_q_h1 = _q_hf_wavefn.H()
# Take the beta copy *before* _q_h1 is transformed in place further down.
_q_h1b = _q_h1.clone() if _has_B else None

data = _QCSchemaData()

# --- one- and two-body integrals, AO basis ---
# clone() snapshots the AO matrix because _q_h1 is transformed in place below.
data.hij = np.asarray(_q_h1.clone())
data.hij_b = None
data.eri = np.asarray(_q_mints.ao_eri())

# --- one-body integrals, MO basis ---
_q_h1.transform(_q_hf_wavefn.Ca())
data.hij_mo = np.asarray(_q_h1)
data.hij_mo_b = None
if _has_B:
    _q_h1b.transform(_q_hf_wavefn.Cb())
    data.hij_mo_b = np.asarray(_q_h1b)

# --- two-body integrals, MO basis ---
data.eri_mo = np.asarray(_q_mints.mo_eri(_q_hf_wavefn.Ca(), _q_hf_wavefn.Ca(),
                                         _q_hf_wavefn.Ca(), _q_hf_wavefn.Ca()))
data.eri_mo_ba = None
data.eri_mo_bb = None
if _has_B:
    data.eri_mo_bb = np.asarray(_q_mints.mo_eri(_q_hf_wavefn.Cb(), _q_hf_wavefn.Cb(),
                                                _q_hf_wavefn.Cb(), _q_hf_wavefn.Cb()))
    data.eri_mo_ba = np.asarray(_q_mints.mo_eri(_q_hf_wavefn.Cb(), _q_hf_wavefn.Cb(),
                                                _q_hf_wavefn.Ca(), _q_hf_wavefn.Ca()))

# --- energies and orbitals ---
data.e_nuc = _q_mol.nuclear_repulsion_energy()
data.e_ref = _q_hf_energy
data.mo_coeff = np.asarray(_q_hf_wavefn.Ca())
data.mo_coeff_b = np.asarray(_q_hf_wavefn.Cb()) if _has_B else None
data.mo_energy = np.asarray(_q_hf_wavefn.epsilon_a())
data.mo_energy_b = np.asarray(_q_hf_wavefn.epsilon_b()) if _has_B else None
data.mo_occ = None
data.mo_occ_b = None

# --- molecule geometry ---
# coords is stored flat: [x0, y0, z0, x1, y1, z1, ...]
_q_natom = _q_mol.natom()
data.symbols = [_q_mol.symbol(_n) for _n in range(_q_natom)]
data.coords = np.array(
    [[_q_mol.x(_n), _q_mol.y(_n), _q_mol.z(_n)] for _n in range(_q_natom)],
    dtype=float,
).flatten()

# --- molecule / calculation metadata ---
data.multiplicity = _q_mol.multiplicity()
data.charge = _q_mol.molecular_charge()
data.masses = np.asarray([_q_mol.mass(_n) for _n in range(_q_natom)])
data.method = None
data.basis = None
data.creator = "Psi4"
data.version = psi4.__version__
data.routine = None
data.nbasis = None
data.nmo = _q_hf_wavefn.nmo()
data.nalpha = _q_hf_wavefn.nalpha()
data.nbeta = _q_hf_wavefn.nbeta()
data.keywords = None

# --- dipole integrals ---
_q_dipole = _q_mints.ao_dipole()
# Snapshot the AO dipole matrices via clone() (same pattern as data.hij above):
# the matrices in _q_dipole are transformed in place right below, and
# np.asarray on a Psi4 Matrix may return a view sharing its buffer, which
# would leave dip_x/y/z holding MO-basis (or stale) data.
data.dip_x = np.asarray(_q_dipole[0].clone())
data.dip_y = np.asarray(_q_dipole[1].clone())
data.dip_z = np.asarray(_q_dipole[2].clone())

for _q_dip in _q_dipole:
    _q_dip.transform(_q_hf_wavefn.Ca())

data.dip_mo_x_a = np.asarray(_q_dipole[0])
data.dip_mo_y_a = np.asarray(_q_dipole[1])
data.dip_mo_z_a = np.asarray(_q_dipole[2])

if _has_B:
    # Recompute fresh AO dipoles; the previous ones now hold the alpha MO result.
    _q_dipole = _q_mints.ao_dipole()
    for _q_dip in _q_dipole:
        _q_dip.transform(_q_hf_wavefn.Cb())

    data.dip_mo_x_b = np.asarray(_q_dipole[0])
    data.dip_mo_y_b = np.asarray(_q_dipole[1])
    data.dip_mo_z_b = np.asarray(_q_dipole[2])

_q_nd = _q_mol.nuclear_dipole()
data.dip_nuc = np.array([_q_nd[0], _q_nd[1], _q_nd[2]])
data.dip_ref = _q_hf_wavefn.variable("CURRENT DIPOLE")

# _FILE_PATH variable gets injected by the driver that calls this template
data.to_hdf5(_FILE_PATH)
