# vim: ft=python
import numpy

from qiskit_nature.constants import BOHR
from qiskit_nature.drivers import Molecule
from qiskit_nature.hdf5 import save_to_hdf5
from qiskit_nature.properties.second_quantization.driver_metadata import DriverMetadata
from qiskit_nature.properties.second_quantization.electronic import (
    AngularMomentum,
    DipoleMoment,
    ElectronicDipoleMoment,
    ElectronicEnergy,
    ElectronicStructureDriverResult,
    Magnetization,
    ParticleNumber,
)
from qiskit_nature.properties.second_quantization.electronic.bases import (
    ElectronicBasis,
    ElectronicBasisTransform,
)
from qiskit_nature.properties.second_quantization.electronic.integrals import (
    OneBodyElectronicIntegrals,
    TwoBodyElectronicIntegrals,
)

# Fix the geometry location & orientation and force c1 symmetry on the active
# molecule before any calculation is run on it.
_q_active_mol = core.get_active_molecule()
_q_active_mol.fix_com(True)
_q_active_mol.fix_orientation(True)
_q_active_mol.reset_point_group("c1")

# Run the SCF calculation, asking for the wavefunction in addition to the energy.
_q_hf_energy, _q_hf_wavefn = energy("scf", return_wfn=True)

# Handles reused throughout the rest of the script.
_q_mol = _q_hf_wavefn.molecule()
_q_mints = MintsHelper(_q_hf_wavefn.basisset())

# Beta-spin data is only emitted below when the wavefunction reports that its
# alpha and beta orbitals are not the same.
_q_has_B = not _q_hf_wavefn.same_a_b_orbs()

# Result container; every property built by this script is attached to it.
_q_driver_result = ElectronicStructureDriverResult()

# One (symbol, [x, y, z]) entry per atom, each coordinate scaled by BOHR.
_q_geometry = []
for _n in range(_q_mol.natom()):
    _q_xyz = [_q_mol.x(_n) * BOHR, _q_mol.y(_n) * BOHR, _q_mol.z(_n) * BOHR]
    _q_geometry.append((_q_mol.symbol(_n), _q_xyz))

_q_driver_result.molecule = Molecule(
    _q_geometry,
    _q_mol.multiplicity(),
    _q_mol.molecular_charge(),
)

# Record which program and version produced the data; the third field is left empty.
_q_metadata = DriverMetadata("PSI4", psi4.__version__, "")
_q_driver_result.add_property(_q_metadata)

# AO -> MO coefficient matrices as numpy arrays. The beta matrix is only
# supplied when _q_has_B is set; otherwise None is passed in its place.
_q_coeff_a = numpy.asarray(_q_hf_wavefn.Ca())
_q_coeff_b = None
if _q_has_B:
    _q_coeff_b = numpy.asarray(_q_hf_wavefn.Cb())

_q_basis_transform = ElectronicBasisTransform(
    ElectronicBasis.AO,
    ElectronicBasis.MO,
    _q_coeff_a,
    _q_coeff_b,
)
_q_driver_result.add_property(_q_basis_transform)

# Two spin orbitals per molecular orbital; particles are given as (alpha, beta).
_q_particle_number = ParticleNumber(
    num_spin_orbitals=2 * _q_hf_wavefn.nmo(),
    num_particles=(_q_hf_wavefn.nalpha(), _q_hf_wavefn.nbeta()),
)
_q_driver_result.add_property(_q_particle_number)

# One-electron AO matrices. The kinetic and overlap matrices are kept under
# their own names; _q_h1 starts as the potential and has the kinetic matrix
# added to it in place.
_q_kinetic = _q_mints.ao_kinetic()
_q_overlap = _q_mints.ao_overlap()
_q_h1 = _q_mints.ao_potential()
_q_h1.add(_q_kinetic)

# Separate copy for the beta spin, taken while _q_h1 is still in the AO basis.
_q_h1b = None
if _q_has_B:
    _q_h1b = _q_h1.clone()

# The AO array is taken from a clone because _q_h1 itself is transformed below.
_q_h1_ao = numpy.asarray(_q_h1.clone())
_q_one_body_ao = OneBodyElectronicIntegrals(ElectronicBasis.AO, (_q_h1_ao, None))

_q_eri_ao = numpy.asarray(_q_mints.ao_eri())
_q_two_body_ao = TwoBodyElectronicIntegrals(ElectronicBasis.AO, (_q_eri_ao, None, None, None))

# Transform the one-electron matrices with the alpha (and, if present, beta)
# coefficients.
_q_Ca = _q_hf_wavefn.Ca()
_q_Cb = None
_q_h1.transform(_q_Ca)
if _q_has_B:
    _q_Cb = _q_hf_wavefn.Cb()
    _q_h1b.transform(_q_Cb)

_q_h1_mo_a = numpy.asarray(_q_h1.clone())
_q_h1_mo_b = None
if _q_has_B:
    _q_h1_mo_b = numpy.asarray(_q_h1b.clone())
_q_one_body_mo = OneBodyElectronicIntegrals(ElectronicBasis.MO, (_q_h1_mo_a, _q_h1_mo_b))

# Two-electron MO integrals. The first slot uses alpha coefficients only; the
# second (Cb, Cb, Ca, Ca) and third (Cb, Cb, Cb, Cb) slots are filled only when
# beta coefficients are in use; the fourth slot is always left as None.
_q_eri_mo_1 = numpy.asarray(_q_mints.mo_eri(_q_Ca, _q_Ca, _q_Ca, _q_Ca))
_q_eri_mo_2 = None
_q_eri_mo_3 = None
if _q_has_B:
    _q_eri_mo_2 = numpy.asarray(_q_mints.mo_eri(_q_Cb, _q_Cb, _q_Ca, _q_Ca))
    _q_eri_mo_3 = numpy.asarray(_q_mints.mo_eri(_q_Cb, _q_Cb, _q_Cb, _q_Cb))

_q_two_body_mo = TwoBodyElectronicIntegrals(
    ElectronicBasis.MO,
    (_q_eri_mo_1, _q_eri_mo_2, _q_eri_mo_3, None),
)

# Bundle the AO and MO one- and two-body integrals into the energy property,
# together with the nuclear repulsion energy and the SCF energy as reference.
_q_electronic_energy = ElectronicEnergy(
    [_q_one_body_ao, _q_two_body_ao, _q_one_body_mo, _q_two_body_mo],
    nuclear_repulsion_energy=_q_mol.nuclear_repulsion_energy(),
    reference_energy=_q_hf_energy,
)

# Orbital energies: a single array when alpha and beta share orbitals,
# otherwise a two-element list [alpha, beta].
if _q_has_B:
    _q_orbital_energies = [
        numpy.asarray(_q_hf_wavefn.epsilon_a()),
        numpy.asarray(_q_hf_wavefn.epsilon_b()),
    ]
else:
    _q_orbital_energies = numpy.asarray(_q_hf_wavefn.epsilon_a())

_q_electronic_energy.orbital_energies = _q_orbital_energies

# Convert the matrices returned by MintsHelper to numpy arrays before wrapping
# them, as is done for every other integral in this script, so the stored
# integrals are plain arrays rather than Psi4 matrix objects.
_q_electronic_energy.kinetic = OneBodyElectronicIntegrals(
    ElectronicBasis.AO, (numpy.asarray(_q_kinetic), None)
)
_q_electronic_energy.overlap = OneBodyElectronicIntegrals(
    ElectronicBasis.AO, (numpy.asarray(_q_overlap), None)
)

_q_driver_result.add_property(_q_electronic_energy)

# Dipole property, seeded with the three components of the nuclear dipole.
_q_nuclear_dipole = _q_mol.nuclear_dipole()
_q_dipole = ElectronicDipoleMoment(
    nuclear_dipole_moment=[_q_nuclear_dipole[0], _q_nuclear_dipole[1], _q_nuclear_dipole[2]],
    reverse_dipole_sign=False,
)

# Request the AO dipole integrals once; the call does not depend on the axis,
# so repeating it inside the loop would only redo the same work per component.
_q_ao_dipoles = _q_mints.ao_dipole()

# Loop names carry the _q_ prefix used by the rest of this script.
for _q_axis, _q_dip_ao in zip(("x", "y", "z"), _q_ao_dipoles):
    # Transform clones so the AO matrix stays available for the AO entry below.
    _q_dip_mo = _q_dip_ao.clone()
    _q_dip_mo.transform(_q_hf_wavefn.Ca())
    _q_dip_mo_b = None
    if _q_has_B:
        _q_dip_mo_b = _q_dip_ao.clone()
        _q_dip_mo_b.transform(_q_hf_wavefn.Cb())

    _q_dipole.add_property(
        DipoleMoment(
            _q_axis,
            [
                OneBodyElectronicIntegrals(
                    ElectronicBasis.AO,
                    (
                        numpy.asarray(_q_dip_ao),
                        None,
                    ),
                ),
                OneBodyElectronicIntegrals(
                    ElectronicBasis.MO,
                    (
                        numpy.asarray(_q_dip_mo),
                        numpy.asarray(_q_dip_mo_b) if _q_has_B else None,
                    ),
                ),
            ],
        )
    )

_q_driver_result.add_property(_q_dipole)
