#!/usr/bin/env python3
from tkinter import *
from scipy.integrate import odeint
from numpy.linalg import norm
import numpy as np
import time
# ............................................................. Lists
# Editable particle quantities: attribute key, toolbar label, display format
# (index 0, the name, is shown in a label rather than an entry).
QtKey=['name','m','q','fr','x','y','vx','vy']
QtLab=['Name','Mass','Charge','\u03B7','x','y','vx','vy']
nQt=len(QtKey)
QtForm=['{:s}']+['{:.3e}' for _ in QtKey[1:]]
# Editable simulation parameters: toolbar label and display format.
PrLab=['dt','\u03C4/ms','Trail length']
PrForm=['{:.3e}','{:d}','{:d}']
# .................................................. Global Variables
RunAll=True           # the animation loop runs while this is True
RunMotion=False       # integrate the equations of motion each frame
GetPr=False           # a parameter entry is waiting to be read by the loop
GetQt=False           # a particle-quantity entry is waiting to be read
# ................................................... Physical values
q=1.602176e-19        # elementary charge / C
me=9.10938e-31        # electron mass / kg
mp=1.67262e-27        # proton mass / kg
ke=8.98755179e9       # Coulomb's constant / (N m**2/C**2)
r2=1.0e-10            # initial radius of the outer (second) electron / m
r1=r2/3.0             # initial radius of the inner (first) electron / m
# Circular-orbit speeds: v1 for a charge -q around 2q at r1; v2 around an
# effective charge q at r2 (presumably the nucleus screened by electron 1).
v1=np.sqrt(ke*2.0*q**2/(me*r1)) # m/s
v2=np.sqrt(ke*q**2/(me*r2))     # m/s
dt=2.0e-19            # simulated time advanced per animation frame / s
# .................................. Drawing and Animation Parameters
tau=10                # target duration of one loop iteration / ms
scale=3.0e12          # px/m
cw,ch=900,900         # canvas width and height / px
Ox,Oy=cw/2.0,ch/2.0   # canvas position of the physical origin / px
bcrad=2               # radius of the barycenter marker / px
TrailLength=400       # number of points kept in each particle trail
SelP,SelPr,SelQt=0,0,0  # selected particle / parameter / quantity index
# ......................................................... Parameters
param=[dt,tau,TrailLength]   # values mirrored in the parameter entries
# ..................................................... Class Particle
class particle:
  """A charged point particle drawn on the global Tk canvas.

  Holds name, mass m, charge q, friction coefficient fr, position (x, y) and
  velocity (vx, vy), plus two canvas items: a filled circle (image) and a
  polyline of its recent positions (TrailImg).  Positive charges are drawn
  as large blue circles, the others as small red ones.
  """
  def __init__(self,name,mass,charge,frict,x,y,vx,vy):
    # Keep this assignment order: it matches QtKey, and code that reads
    # the instance __dict__ positionally relies on it.
    self.name,self.m,self.q,self.fr=name,mass,charge,frict
    self.x,self.y,self.vx,self.vy=x,y,vx,vy
    self.col,self.rad=('blue',8) if self.q>0 else ('red',4)
    self.image=canvas.create_oval(*self._bbox(),fill=self.col,\
      outline=self.col)
    # trails are flat [x0,y0,x1,y1,...] lists, initially all at the start
    # point: `trail` in metres, `ScaledTrail` in canvas pixels
    px,py=cvx(self.x),cvy(self.y)
    self.trail=[self.x,self.y]*TrailLength
    self.ScaledTrail=[px,py]*TrailLength
    self.TrailImg=canvas.create_line(self.ScaledTrail,fill=self.col)

  def _bbox(self):
    """Return the circle's canvas corners (x0, y0, x1, y1) at the current position."""
    px,py=cvx(self.x),cvy(self.y)
    return px-self.rad,py+self.rad,px+self.rad,py-self.rad

  @staticmethod
  def _push(buf,a,b):
    """Drop the oldest point of a flat [x,y,...] buffer and append (a, b)."""
    del buf[:2]
    buf.extend((a,b))

  # ................................................... Move Particle
  def move(self):
    """Redraw the particle; extend the trail once it moved more than 10 px."""
    canvas.coords(self.image,*self._bbox())
    stepx=self.x-self.trail[-2]
    stepy=self.y-self.trail[-1]
    if norm([stepx,stepy])*scale>10:
      self._push(self.trail,self.x,self.y)
      self._push(self.ScaledTrail,cvx(self.x),cvy(self.y))
    canvas.coords(self.TrailImg,self.ScaledTrail)
# ............................................................... cvx
def cvx(x):
  """Convert a physical x coordinate (m) to a canvas column (px)."""
  # Ox and scale are only read here, so no `global` declaration is needed
  return Ox+x*scale
# ............................................................... cvy
def cvy(x):
  """Convert a physical y coordinate (m) to a canvas row (px); y grows upward."""
  # the canvas row axis points down, hence the subtraction from Oy
  return Oy-x*scale
# ............................................... Start/Stop Function
def StartStop():
  """Toggle the simulation on or off.

  While it runs, the first button reads 'Stop' and every other button and
  all entries are disabled; when stopped, it reads 'Restart' and they are
  enabled again.
  """
  global RunMotion
  RunMotion=not RunMotion
  butt[0]['text']='Stop' if RunMotion else 'Restart'
  newstate=DISABLED if RunMotion else NORMAL
  for widget in butt[1:]+QtEntry+PrEntry:
    widget['state']=newstate
# ..................................................... Exit Function
def StopAll():
  """Make the animation loop exit; the window is destroyed after the loop."""
  global RunAll
  RunAll=False
# .......................................... Read Particle Variables
def ReadQt(WhichEntry):
  """Record that quantity entry WhichEntry was confirmed; the loop reads it."""
  global GetQt,SelQt
  SelQt,GetQt=WhichEntry,True
# ................................................... Read Parameters
def ReadPr(WhichEntry):
  """Record that parameter entry WhichEntry was confirmed; the loop reads it."""
  global GetPr,SelPr
  SelPr,GetPr=WhichEntry,True
# ................................................... Select Particle
def SelectPart(delta):
  """Step the selected particle by `delta` (wrapping) and refresh the toolbar.

  delta -- offset into `part`; 0 keeps the selection and only refreshes.
  The selection label shows the particle's name in its colour, and each
  quantity entry is rewritten with the attribute named in QtKey, formatted
  with the matching QtForm entry.
  """
  global SelP
  SelP=(SelP+delta)%nP
  sel=part[SelP]
  SelLab.config(text=sel.name,fg=sel.col)
  # Look attributes up by name instead of by their position in __dict__, so
  # the display does not depend on the order attributes were created in.
  for entry,key,fmt in zip(QtEntry,QtKey[1:],QtForm[1:]):
    entry.delete(0,'end')
    entry.insert(0,fmt.format(getattr(sel,key)))
# ..................................................... Scale Up/Down
def ScaleUpDown(ud):
  """Zoom by a factor sqrt(2)**ud and rescale every stored trail.

  ud -- +1 zooms in, -1 zooms out (bound to Ctrl-plus / Ctrl-minus).
  Updates the scale label; the circles are redrawn by the next move().
  """
  global scale
  scale*=np.sqrt(2)**ud
  ScaleLab['text']='{:10.3e}'.format(scale)
  for p in part:
    xs,ys=p.trail[::2],p.trail[1::2]
    p.ScaledTrail[::2]=list(map(cvx,xs))
    p.ScaledTrail[1::2]=list(map(cvy,ys))
# .......................... Evaluate Center of Mass and its Velocity
def baryc(part):
  """Return [x, y, vx, vy] of the centre of mass of the particles in `part`.

  Each value is the mass-weighted mean of the corresponding particle
  attribute.
  """
  mtot=sum(p.m for p in part)
  def weighted(attr):
    # mass-weighted mean of one attribute, summed in particle order
    return sum(getattr(p,attr)*p.m for p in part)/mtot
  return [weighted(attr) for attr in ('x','y','vx','vy')]
# ..................................... Move Origin to Center of Mass
def SetBaryc():
  """Move the origin to the centre of mass and remove its drift.

  Subtracts the barycenter position and velocity returned by baryc from
  every particle, copies the new values into posvect/velvect, and refreshes
  the toolbar entries through SelectPart(0).
  """
  shift=baryc(part)            # [x, y, vx, vy] of the centre of mass
  for k,p in enumerate(part):
    p.x,p.y,p.vx,p.vy=(p.x-shift[0],p.y-shift[1],\
      p.vx-shift[2],p.vy-shift[3])
    # posvect/velvect are only modified in place, so no `global` is needed
    posvect[2*k:2*k+2]=[p.x,p.y]
    velvect[2*k:2*k+2]=[p.vx,p.vy]
  SelectPart(0)
# ....................................................... Acceleration
def accel(posvect,velvect):
  """Return the accelerations of all particles as a flat array.

  posvect -- flat positions [x0,y0,x1,y1,...]
  velvect -- flat velocities [vx0,vy0,vx1,vy1,...]
  Particle i feels the Coulomb force ke*q_i*q_j/r_ij**2 from every other
  particle j and a linear drag -frict_i*v_i; the total is divided by
  mass_i.  Reads the module-level arrays charge, mass, frict and ke.
  Returns [ax0,ay0,ax1,ay1,...].
  """
  posvect=np.asarray(posvect,dtype=float)
  velvect=np.asarray(velvect,dtype=float)
  x=posvect[::2]
  y=posvect[1::2]
  # dx[i,j], dy[i,j]: components of the vector from particle i to particle j
  dx=x[np.newaxis,:]-x[:,np.newaxis]
  dy=y[np.newaxis,:]-y[:,np.newaxis]
  r2=dx*dx+dy*dy
  np.fill_diagonal(r2,1.0)        # avoid 0/0 on the self-interaction terms
  # signed force along i->j: negative (repulsive) for like charges
  ff=-ke*np.outer(charge,charge)/r2
  np.fill_diagonal(ff,0.0)
  # project on x and y with the unit vector (dx, dy)/r instead of going
  # through arctan2 and cos/sin
  ff/=np.sqrt(r2)
  fx=(ff*dx).sum(axis=1)
  fy=(ff*dy).sum(axis=1)
  acc=np.empty(len(posvect))
  acc[::2]=(fx-velvect[::2]*frict)/mass
  acc[1::2]=(fy-velvect[1::2]*frict)/mass
  return acc
# .............................................................. dfdt
def dfdt(InpVect,t):
  """Right-hand side for odeint: d/dt [pos, vel] = [vel, accel(pos, vel)].

  InpVect holds the nP positions (2*nP values) followed by the velocities.
  t is required by odeint's calling convention but unused (autonomous system).
  """
  pos,vel=InpVect[:2*nP],InpVect[2*nP:]
  return np.concatenate((vel,accel(pos,vel)))
# ....................................................... Root Window
root=Tk()
root.title('Classical Helium Atom (odeint)')
root.bind('<Control-plus>',lambda event,num=1:ScaleUpDown(num))
root.bind('<Control-minus>',lambda event,num=-1:ScaleUpDown(num))
# Route the window-manager close button through StopAll so the loop below
# exits and root.destroy() runs once at the end.  With Tk's default handler
# the window is destroyed at once and the next canvas call raises TclError.
root.protocol('WM_DELETE_WINDOW',StopAll)
# ............................................................ canvas
canvas=Canvas(root,width=cw,height=ch,background='#ffffff')
canvas.grid(row=0,column=0)
# ........................................................... Toolbar
toolbar=Frame(root)
toolbar.grid(row=0,column=1,sticky=N)
toolbar.option_add('*Font','Helvetica 11')
# ............................................................ Buttons
nr=0                   # next free grid row in the toolbar
butt=[]
ButtLab=['Start','Set Barycenter','Exit']
ButtComm=[StartStop,SetBaryc,StopAll]
for i,(ll,cc) in enumerate(zip(ButtLab,ButtComm)):
  butt.append(Button(toolbar,text=ll,command=cc,width=11))
  butt[i].grid(row=nr,column=0,sticky=W)
  nr+=1
# .................................................. Create Particles
# nucleus of charge 2q and mass 4*mp at rest at the origin; two electrons on
# opposite sides of it with initial velocities perpendicular to the x axis
part=[]
part.append(particle('Nucleus',4.0*mp,2.0*q,0.0,0.0,0.0,0.0,0.0))
part.append(particle('Electron 1',me,-q,0.0,-r1,0.0,0.0,-v1))
part.append(particle('Electron 2',me,-q,0.0,r2,0.0,0.0,v2))
nP=len(part)
# ....................................... Position and Velocity Lists
# flat state arrays used by odeint: [x0,y0,x1,y1,...] and [vx0,vy0,...]
posvect=np.array([0.0,0.0]*nP)
velvect=np.array([0.0,0.0]*nP)
charge=np.array([0.0]*nP)
mass=np.array([0.0]*nP)
frict=np.array([0.0]*nP)
for i,p in enumerate(part):
  posvect[2*i:2*i+2]=[p.x,p.y]
  velvect[2*i:2*i+2]=[p.vx,p.vy]
  charge[i]=p.q
  mass[i]=p.m
  frict[i]=p.fr
# ........................................... Selected Particle Label
Label(toolbar,text='Selected Particle:',pady=20).grid(row=nr,column=0)
SelLab=Label(toolbar,text=part[0].name,width=15,bg='#ffffff')
SelLab.grid(row=nr,column=1)
# Scroll over the label to change the selection.  X11 Tk reports the wheel
# as Button-4/5; Windows and macOS report <MouseWheel> (delta>0 = up).
SelLab.bind('<Button-5>',lambda event,num=-1:SelectPart(num))
SelLab.bind('<Button-4>',lambda event,num=1:SelectPart(num))
SelLab.bind('<MouseWheel>',\
  lambda event:SelectPart((event.delta>0)-(event.delta<0)))
nr+=1
# .............. Entries for Physical Quantities of Selected Particle
QtEntry=[]
for i,kk in enumerate(QtKey[1:]):
  Label(toolbar,text=QtLab[i+1]).grid(row=nr,column=0)
  QtEntry.append(Entry(toolbar,bd=3,width=16))
  QtEntry[i].grid(row=nr,column=1)
  QtEntry[i].insert(0,QtForm[i+1].format(getattr(part[0],kk)))
  QtEntry[i].bind('<Return>',lambda event,num=i:ReadQt(num))
  nr+=1
# ......................................................... Separator
Label(toolbar,text='  ').grid(row=nr,column=0)
nr+=1
# ............................................ Entries for Parameters
PrEntry=[]
for i,pl in enumerate(PrLab):
  Label(toolbar,text=pl).grid(row=nr,column=0)
  PrEntry.append(Entry(toolbar,bd=3,width=16))
  PrEntry[i].grid(row=nr,column=1)
  PrEntry[i].insert(0,PrForm[i].format(param[i]))
  PrEntry[i].bind('<Return>',lambda event,num=i:ReadPr(num))
  nr+=1
# ....................................................... Cycle Label
Label(toolbar,text='Period:',).grid(row=nr,column=0)
CycleLab=Label(toolbar,text='     ')
CycleLab.grid(row=nr,column=1,sticky=W)
nr+=1
# ....................................................... Scale Label
Label(toolbar,text='Scale:').grid(row=nr,column=0)
ScaleLab=Label(toolbar,text='{:10.3e}'.format(scale))
ScaleLab.grid(row=nr,column=1,sticky=W)
nr+=1
# ................................................... Initialize Time
tt0=time.time()
tcount=0
t=[0,dt]               # odeint is asked for the state after one step dt
# .............................................. Draw Coordinate Axes
canvas.create_line(0,Oy,cw,Oy,fill='black')
canvas.create_line(Ox,0,Ox,ch,fill='black')
# ................................. Create Barycenter Image on Canvas
bc=canvas.create_oval(Ox-bcrad,Oy-bcrad,Ox+bcrad,Oy+bcrad,fill='black')
# .................................................... Animation Loop
while RunAll:
  StartIter=time.time()
  # .................................................. Draw Particles
  for p in part:
    p.move()
  # ................................................. Draw Barycenter
  cx,cy=baryc(part)[:2]
  canvas.coords(bc,cvx(cx)-bcrad,cvy(cy)-bcrad,cvx(cx)+bcrad,\
    cvy(cy)+bcrad)
  canvas.update()      # process pending Tk events, i.e. run the callbacks
  # .......................................................... motion
  if RunMotion:
    # ................................................... Call odeint
    # row 1 of the solution is the state at time dt
    psoln=odeint(dfdt,np.append(posvect,velvect),t)
    posvect=psoln[1,:][:2*nP]
    velvect=psoln[1,:][2*nP:]
    for i,p in enumerate(part):
      p.x,p.y=posvect[2*i:2*i+2]
      p.vx,p.vy=velvect[2*i:2*i+2]
  elif GetPr: # ................................... Read Parameters
    try:
      vv=float(PrEntry[SelPr].get())
      if SelPr==1 or SelPr==2:
        vv=int(vv)             # tau and the trail length are integers
    except (ValueError,OverflowError):   # not a number, or int(nan/inf)
      vv=None
    # Reject input that would break the loop (a non-finite dt, or a trail
    # with fewer than the two points a canvas line needs) and show the
    # value currently in use instead.
    if vv is None or not np.isfinite(vv) or (SelPr==2 and vv<2):
      vv=param[SelPr]
    param[SelPr]=vv
    PrEntry[SelPr].delete(0,'end')
    PrEntry[SelPr].insert(0,PrForm[SelPr].format(vv))
    # ........................................ If TrailLength Changed
    dTrail=param[2]-TrailLength
    if dTrail<0:
      # drop the oldest points
      for p in part:
        del p.trail[:2*abs(dTrail)]
        del p.ScaledTrail[:2*abs(dTrail)]
    elif dTrail>0:
      # pad at the old end by repeating the oldest point
      for p in part:
        NewPoints=[p.trail[0],p.trail[1]]*dTrail
        p.trail=NewPoints+p.trail
        NewPoints=[p.ScaledTrail[0],p.ScaledTrail[1]]*dTrail
        p.ScaledTrail=NewPoints+p.ScaledTrail
    # ...............................................................
    dt,tau,TrailLength=param
    t=[0,dt]
    GetPr=False
  elif GetQt:  # ............................ Read Particle Variables
    key=QtKey[SelQt+1]         # QtEntry[k] edits attribute QtKey[k+1]
    try:
      vv=float(QtEntry[SelQt].get())
    except ValueError:
      vv=None
    # Values must be finite and the mass strictly positive (accel divides
    # by it); otherwise show the value currently in use again.
    if vv is None or not np.isfinite(vv) or (key=='m' and vv<=0.0):
      vv=getattr(part[SelP],key)
    else:
      setattr(part[SelP],key,vv)
    QtEntry[SelQt].delete(0,'end')
    # the format list is aligned with QtKey, so it is offset by one as well
    QtEntry[SelQt].insert(0,QtForm[SelQt+1].format(vv))
    posvect[2*SelP:2*SelP+2]=[part[SelP].x,part[SelP].y]
    velvect[2*SelP:2*SelP+2]=[part[SelP].vx,part[SelP].vy]
    charge[SelP]=part[SelP].q
    mass[SelP]=part[SelP].m
    frict[SelP]=part[SelP].fr
    GetQt=False
  # ................................................ Cycle Duration
  # every 10 iterations show the mean iteration time in ms
  tcount+=1
  if tcount>=10:
    tcount=0
    ttt=time.time()
    elapsed=ttt-tt0
    CycleLab['text']='{:8.3f}'.format(elapsed*100.0)+' ms'
    tt0=ttt
  # wait for whatever is left of the tau-ms budget of this iteration
  ElapsIter=int((time.time()-StartIter)*1000.0)
  canvas.after(tau-ElapsIter)
#--------------------------------------------------------------- Exit
root.destroy()
  
