"""Plot the LDA and Koopmans band structures of ZnO."""

import matplotlib.patheffects as pe
import matplotlib.pyplot as plt
import numpy as np

from koopmans.io import read

# Read the previously-saved workflow back in from disk
workflow = read('zno.pkl')

# The final calculation of the workflow is the one that produced the Koopmans bands
final_calc = workflow.calculations[-1]

# The LDA bands come from the pw.x calculation run with "calculation = bands".
# The single-element unpacking raises unless exactly one such calculation exists.
bands_calcs = [
    calc for calc in workflow.calculations
    if calc.parameters.get('calculation', None) == 'bands'
]
[bands_calc] = bands_calcs

# Pull the band structure out of each calculation's results
koopmans_bs = final_calc.results['band structure']
lda_bs_raw = bands_calc.results['band structure']

# Shift the Koopmans bands so that the valence band maximum is zero...
koopmans_bs_shifted = koopmans_bs.subtract_reference()

# ...and shift the LDA bands by that same Koopmans reference so the two align
shared_reference = koopmans_bs.reference
lda_bs_shifted = lda_bs_raw.subtract_reference(shared_reference)

# Draw both band structures on one set of axes: LDA first (which creates the
# axes), then the Koopmans bands on top of it
lda_style = {'label': 'LDA', 'spin': 0, 'color': 'tab:blue', 'ls': '--'}
koopmans_style = {'label': 'KI@LDA', 'color': 'tab:green'}
ax = lda_bs_shifted.plot(**lda_style)
ax = koopmans_bs_shifted.plot(ax=ax, **koopmans_style)

# Number of conduction (empty) bands in the Koopmans band structure. The search
# below treats the last N_CONDUCTION_BANDS bands as conduction bands and all the
# others as valence bands. Must be non-zero: a slice ending at -0 would be empty.
N_CONDUCTION_BANDS = 2

# x-coordinates along the band path (the first value returned by get_labels());
# fetched once and shared by both searches below
x, _, _ = koopmans_bs.get_labels()

# Find the Koopmans valence band maximum. unravel_index converts the flat argmax
# into one index per axis of the energies array; entry 1 is the position along
# the band path.
# NOTE(review): presumably the axes are (spin, k-point, band) -- confirm
valence = koopmans_bs_shifted.energies[:, :, :-N_CONDUCTION_BANDS]
i_vbm = np.unravel_index(np.nanargmax(valence), valence.shape)
x_vbm = x[i_vbm[1]]
y_vbm = valence[i_vbm]

# Find the Koopmans conduction band minimum in the same way
conduction = koopmans_bs_shifted.energies[:, :, -N_CONDUCTION_BANDS:]
i_cbm = np.unravel_index(np.nanargmin(conduction), conduction.shape)
x_cbm = x[i_cbm[1]]
y_cbm = conduction[i_cbm]

# Mark the band gap with a double-headed arrow running between the valence band
# maximum and the conduction band minimum (empty text: only the arrow is drawn)
arrow_style = {'arrowstyle': '<->', 'shrinkA': 0, 'shrinkB': 0}
ax.annotate(text='',
            xy=(x_vbm, y_vbm), xycoords='data',
            xytext=(x_cbm, y_cbm), textcoords='data',
            arrowprops=arrow_style)

# Write the size of the gap just to the right of the arrow's midpoint; the white
# outline keeps the number legible where it overlaps the plotted bands
label_x = (x_cbm + x_vbm) / 2 + 0.1
label_y = (y_cbm + y_vbm) / 2
white_outline = pe.withStroke(linewidth=4, foreground='white')
ax.text(label_x, label_y, f'{y_cbm - y_vbm:.2f} eV',
        ha='left', va='center', path_effects=[white_outline])

# Tweak the figure aesthetics: legend above the top-right corner of the axes,
# and a fixed energy window
ax.legend(loc='lower right', ncol=2, bbox_to_anchor=(1, 1))
ax.set_ylim([-10, 15])

# Display or save the figure (uncomment as desired)
plt.savefig('zno_bandstructures.png')
# plt.show()
