#! /usr/bin/env python
# ===================================================================
# file:   model.py
# author: k.koepernik@ifw-dresden.de
# date:   15 Okt 2016

# Write FPLO band structure and band weight files with the low-level
# pyfplo routines, using a simple model Hamiltonian.
from __future__ import (print_function)

import filecmp
import os
import subprocess
import sys
# You can set the pyfplo path explicitly:
#sys.path.insert(0,"/home/magru/FPLO/FPLO22.00-62/PYTHON");

import numpy as np
import numpy.linalg as LA
import pyfplo.slabify as sla

# ===================================================================
#  a simple model Hamiltonian:
#    2-atom chain, first atom s, second s and p orbitals
#    hopping=1
#    onsite energy difference=delta
# ===================================================================

def ham(x,delta,t=1.):
    """Return the 3x3 Bloch Hamiltonian of the 2-atom model chain at x.

    Basis order: atom1 s, atom2 s, atom2 p.

    Parameters
    ----------
    x : float
        Scalar wave vector along the chain (callers pass k[0]).
    delta : float
        Onsite energy difference: atom1 s sits at -delta/2, atom2 s at
        +delta/2 and atom2 p at 0.
    t : float, optional
        Hopping amplitude (default 1, the original model). The s-s element
        is 2*t*cos(x) and the s-p element is -2*t*sin(x).

    Returns
    -------
    numpy.ndarray
        Real symmetric (3, 3) array of dtype float.
    """
    # with t=1. these are bit-identical to 2*cos(x) and -2*sin(x)
    c=2*t*np.cos(x)
    s=-2*t*np.sin(x)
    return np.array([[-delta/2., c       , s  ],
                     [c        , delta/2., 0. ],
                     [s        , 0.      , 0. ]],'float')


# ===================================================================
# main code
# ===================================================================

print('\npyfplo version {}\n'.format(sla.version))


# model parameter: onsite energy difference between the two atoms
delta = 2
# number of orbitals = dimension of the model Hamiltonian matrix
nw = len(ham(0, delta))




# Brillouin-zone path: a straight line from -pi to +pi along x,
# sampled with 100 divisions
bp = sla.BandPlot()
bp.points = [
    ['-$~p', [-np.pi, 0, 0]],
    ['$~p', [np.pi, 0, 0]],
]
bp.ndiv = 100
bp.calculateBandPlotMesh('.')
dists, kpts = bp.kdists, bp.kpnts




# Band weight labels, one per orbital. Orbitals are numbered over the whole
# basis: atom1 has orbital 1, atom2 has orbitals 2..nw.
labels=['atom1 orb{}'.format(i+1) for i in range(0,1)]
labels.extend(['atom2 orb{}'.format(i+1) for i in range(1,nw) ])

print ('labels are',labels,'\n')

# Open the bandplot (+b) and bandweight (+w) files.
# We use mixed semantics: fb is opened and closed explicitly (the
# try/finally guarantees the close even if writing fails), while fw is
# opened in a with-statement, which closes it implicitly.
fb=bp.openBandFile('+b',1,len(kpts))
try:
    with bp.openBandFile('+w',1,len(kpts),weightlabels=labels,
                         progress='bandplot') as fw:
        for dk,k in zip(dists,kpts):
            # ham(x,delta) takes a scalar, not a vector, so pass only k[0]
            Hk=ham(k[0],delta)
            # eigh returns ascending eigenvalues EV and the eigenvectors as
            # the columns of C, so |C[orbital,band]|**2 is the orbital weight
            (EV,C)=LA.eigh(Hk)
            fw.write(0,dk,k,EV,np.absolute(C)**2) # write weights
            fb.write(0,dk,k,EV) # write band file
finally:
    fb.close()
# Now we have +w (and +b)






# Define summed weights for addWeights: one entry per atom.
wds=sla.WeightDefinitions()
w=wds.add('atom1') # sum of all atom1 weights
w.addLabels(labels[0:1]) # atom1 owns only the first label
w=wds.add('atom2') # sum of all atom2 weights
# atom2 owns all remaining labels (orbitals 2..nw); slicing to the end
# instead of a hard-coded 3 keeps this in step with how labels is built
w.addLabels(labels[1:])

# Read the band weights file ...
bw=sla.BandWeights("+w")
# ... sum the weights as defined above and write the result to +bwsum
bw.addWeights(wds,'+bwsum',vlevel=sla.Vlevel.All)




# For illustration we read the files back into numpy arrays.
# First the band file:

[bh,dists2,kptns2,erg]=bp.readBands('+b')
print( 'bandheader +b: nkp={0} nband={1} nspin={2}'
       .format(bh.nkp,bh.nband,bh.nspin))
print( 'kdists2 shape :',dists2.shape)
print( 'kptns2 shape :',kptns2.shape)
print( 'erg shape :',erg.shape)
# Write it to another file using only the data just read, so that the
# comparison below checks a genuine round trip.
with bp.openBandFile('+b2',1,len(dists2)) as fb:
    # kptns2 is C-ordered with last dimension of size 3, so iterating it
    # yields one k-point per row; erg is indexed here as erg[0,:,k-index].
    for i,(dk,k) in enumerate(zip(dists2,kptns2)):
        fb.write(0,dk,k,erg[0,:,i])

err=''
# Byte-wise comparison without spawning a shell or depending on `diff`;
# a missing/unreadable file counts as a mismatch, as a failing diff did.
try:
    same=filecmp.cmp('+b','+b2',shallow=False)
except OSError:
    same=False
if not same:
    err+='\nSomething went wrong: the two files +b and +b2 should be equal\n'
else:
    print( '\nThe two files +b and +b2 are equal. Good.\n')


    
# Now read the summed band weights back.
bw=sla.BandWeights('+bwsum')
[bh,dists3,erg2,wei]=bw.readBandWeights()
print( 'bandheader +bwsum: nkp={0} norb={1} nband={2} nspin={3}\n\
                   labels={4}'
       .format(bh.nkp,bh.norb,bh.nband,bh.nspin,bh.labels))

print( 'weights shape :',wei.shape)

# Write it to another file using only the data just read. The labels are
# passed by keyword, as when +w was written.
with bp.openBandFile('+w2',1,len(dists3),weightlabels=bh.labels) as fb:
    for i,dk in enumerate(dists3):
        # We do not have k-points in weight files, but we do not need
        # them either (at least here). If you really need them, read the
        # corresponding band file to get the k-points.
        k=[0,0,0]
        # transposing wei[0,:,:,i] gives the same layout as the
        # |C|**2 array used when +w was written
        fb.write(0,dk,k,erg2[0,:,i],wei[0,:,:,i].T)

# Byte-wise comparison without spawning a shell or depending on `diff`;
# a missing/unreadable file counts as a mismatch, as a failing diff did.
try:
    same=filecmp.cmp('+bwsum','+w2',shallow=False)
except OSError:
    same=False
if not same:
    err+='\nSomething went wrong: the two files +bwsum and +w2 should be equal\n'
else:
    print( '\nThe two files +bwsum and +w2 are equal. Good.\n')










    
# Simple xfbp command script (show_w.cmd): graph 1 shows the three raw
# orbital weights w1..w3 from +w, and graph g2 (shifted right via vx0)
# shows the two summed atom weights w1, w2 from +bwsum.
with open('show_w.cmd','w') as f:
    f.write('''kill all
read bandweight "+w"

with g1.gr1
weight style individual
weight factor 2
w1 on
skip 3
symbol fill off
symbol line width 2
symbol style square
w2 on
skip 3
symbol fill off
symbol line width 2
symbol style square
w3 on
skip 3
symbol fill off
symbol line width 2
symbol style square
title "2-atom 3-orbital chain"

vx0=0.126565
vy0=0.16384
vw=0.32
vh=0.655041
view vx0,vy0,vw,vh
autoscale offset 0,0.1
autoscale

############################################################

read bandweight "+bwsum" into g2

with g2
vx0=0.626565
view vx0,vy0,vw,vh

with gr1
weight style individual
weight factor 2
w1 on
color 0xcccc00
skip 3
symbol fill off
symbol line width 2
symbol style circle
w2 on
color 0xaaaa
skip 3
symbol fill off
symbol line width 2
symbol style diamond
title "2-atom 3-orbital chain"

autoscale offset 0,0.1
autoscale
''')

# The same plot as a Python-syntax xfbp script (show_w.xpy): draw(ng,...)
# reads a band weight file into graph ng and styles its first nw weights.
# NOTE(review): only show_w.cmd is handed to xfbp at the end of this file;
# show_w.xpy is presumably meant to be loaded by hand as an alternative.
# NOTE(review): both title branches inside draw() set identical text.
with open('show_w.xpy','w') as f:
    f.write('''
killall()
g=G[1]


#============================================================
def draw(ng,fname,nw,shape):
    g=G[ng]
    gr=g.read('bandweight',fname)[0].group
    gr.setWeightsStyle(style='individual',factor=2,min=0.2,max=6,
        showinlegend=True)
    
    
    cols=[0xaa0000,0xaa00,0xff]
    if  ng==2: 
        cols[1]=0xaaaa
    
    for i in range(nw):
        sh=shape
        if i==0 and ng==2: sh='o'
        w=gr.W[i+1]
        w.on()
        w.setStyle(style=sh,color=cols[i],skip=3,linewidth=2,
            fill=False)
    
    
    vx0=0.126565
    vy0=0.16384
    vw=0.32
    vh=0.655041
    
    if ng==2: vx0=0.626565
    View(ng).setGeometry(vx0,vy0,vw,vh)
    
    if ng==1:
        g.title.text="2-atom 3-orbital chain"
    else:
        g.title.text="2-atom 3-orbital chain"
    
    g.world.offset=[0,0.1]
    g.autoscale()

#############################################################

draw(1,'+w',3,'q')
draw(2,'+bwsum',2,'d')
''')

# Show the plot. xfbp is started directly (no shell, argument list) with its
# stdout discarded. If it cannot be started, report that on stderr and carry
# on, so that the consistency check below still runs.
with open(os.devnull,'w') as devnull:
    try:
        subprocess.call(['xfbp','show_w.cmd'],stdout=devnull)
    except OSError as e:
        print('could not start xfbp: {}'.format(e),file=sys.stderr)




# Report any mismatch collected in err by the round-trip comparisons.
if err:
    raise RuntimeError(err)

