#!/usr/bin/env python3
from tkinter import *
from scipy.integrate import odeint,solve_ivp
from datetime import timedelta
import numpy as np
import time
# .................................................. Global variables
RunAll = True            # cleared by StopAll to end the animation loop
RunMotion = False        # toggled by StartStop
GetStep = False          # set by ReadData when the time-step entry is confirmed
GetCycle = False         # set by ReadData when the cycle entry is confirmed
# ........................................................... Methods
# Indices into method/col/step/algo.
EULER, VERLET, ODEINT, YOSHIDA = range(4)
method = ['Euler', 'Verlet', 'odeint', 'Yoshida']
# One colour per method, plus a last one for the 'Initial' energy row.
col = ['magenta', 'blue', 'red', 'green', 'black']
# Indices into quant (toolbar captions) and the matching QtLab labels.
ITER, ELAPSED, ORBITS, PERIOD, CYCLE, SCALE = range(6)
quant = ['Iterations', 'Elapsed Time', 'Orbits', 'Orbital Period',
         'Cycle', 'Scale']
# ................................................... Physical values
ME = 5.9722e24           # Earth mass/kg
G = 6.674e-11            # Gravitational constant [m^3/(kg s^2)]
GM = ME * G              # [m^3/s^2]
dt = 10.0                # base time step/s (Euler and Verlet do 3 sub-steps)
dT = 3 * dt              # time advanced per animation frame/s
# Step handed to each method, in EULER, VERLET, ODEINT, YOSHIDA order.
step = [dt, dt, dT, dT]
# ................................................. Yoshida Constants
Theta = 1.0 / (2 - 2**(1.0 / 3.0))
CX1 = CX4 = 0.5 * Theta          # position (drift) coefficients
CX2 = CX3 = 0.5 * (1.0 - Theta)
CV1 = CV3 = Theta                # velocity (kick) coefficients
CV2 = 1.0 - 2.0 * Theta
# .................................. Drawing and Animation Parameters
cycle = 1                # target frame duration/ms
scale = 1.0e-5           # px/m
cw = 600                 # canvas width/px
ch = 500                 # canvas height/px
Ox = 150                 # canvas x of the origin (Earth)/px
Oy = ch / 2.0            # canvas y of the origin (Earth)/px
rad = 4                  # satellite marker radius/px
ms = 1.0e3               # satellite mass/kg
TrailLength = 400        # points kept in each trail
x0 = 4.0e7               # satellite apogee/m
vx0 = 0.0
y0 = 0.0
vy0 = 0.5 * np.sqrt(GM / x0)   # half the circular speed at x0, ~1580 m/s
# ............................................... Start/Stop function
def StartStop():
  """Toggle the animation between running and paused.

  Flips RunMotion, relabels the first button ('Stop' while running,
  'Restart' while paused) and disables/re-enables butts[2:] and the
  TimeEntry widgets to match.
  NOTE(review): with ButtLab as defined, butts[2:] is only the 'Exit'
  button -- confirm that disabling Exit while running is intended.
  """
  global RunMotion
  RunMotion = not RunMotion
  caption, widget_state = ('Stop', DISABLED) if RunMotion else ('Restart', NORMAL)
  butts[0]['text'] = caption
  for widget in butts[2:] + TimeEntry:
    widget['state'] = widget_state
# ..................................................... Exit function
def StopAll():
  """Clear the RunAll flag so the main animation loop terminates."""
  global RunAll
  RunAll = False
# ..................................................... Scale up/down
def ScaleUpDown(event,ud=0):
  """Zoom the view by a factor sqrt(2)**ud and refresh what depends on it.

  event: unused (the key bindings pass None).
  ud: +1 zooms in, -1 zooms out, 0 (default) leaves the scale unchanged.

  Updates the global `scale` (px/m) and its toolbar label, then
  re-projects every stored trail so it matches the new scale right away.
  Without that, the trails keep the old scale until the next trail point
  is appended, and never update while the animation is paused.
  """
  global scale
  scale*=np.sqrt(2)**ud
  QtLab[SCALE]['text']='{:10.3e}'.format(scale)
  # Assign into the existing lists (slice assignment) so the objects the
  # drawing loop hands to canvas.coords stay the same.
  for tr,st in zip(Trail,ScaledTrail):
    st[::2]=[Ox+scale*zz for zz in tr[::2]]
    st[1::2]=[Oy-scale*zz for zz in tr[1::2]]
# ........................................................ Read Entry
def ReadData(event,tx):
  """Record that one of the time entries was confirmed with <Return>.

  tx selects the entry: 0 = time step (sets GetStep), 1 = frame cycle
  (sets GetCycle); any other value is ignored. event is unused. The main
  loop consumes the flag on a pass where the motion is stopped.
  """
  global GetStep,GetCycle
  if tx == 1:
    GetCycle = True
  elif tx == 0:
    GetStep = True
# ...................................................... acceleration  
def accel(PosVect):
  """Return the gravitational acceleration [ax, ay] at position PosVect.

  PosVect: (x, y) position relative to the Earth's centre, in metres
  (any sequence or array of two numbers).
  Returns a new numpy array -GM * r / |r|**3 (m/s^2). It is undefined at
  the origin, where it yields NaN with a RuntimeWarning.
  """
  r=np.asarray(PosVect,dtype=float)
  r2=np.dot(r,r)
  # Magnitude GM/|r|^2 along -r/|r|, i.e. -GM*r/|r|^3. This replaces the
  # arctan2/cos/sin round-trip: every integrator calls this repeatedly per
  # frame, and the vector form is cheaper and free of angle rounding.
  # The product is a new array, so it never aliases the caller's state.
  return (-GM/(r2*np.sqrt(r2)))*r
# ............................................................ energy
def ener(InputList):
  """Total mechanical energy (J) of the satellite in state [x, y, vx, vy].

  Gravitational potential -GM*ms/r plus kinetic energy ms*|v|^2/2, where
  ms is the satellite mass and r the distance from the Earth's centre.
  """
  position,velocity=InputList[:2],InputList[2:]
  kinetic=0.5*ms*np.dot(velocity,velocity)
  potential=-GM*ms/np.linalg.norm(position)
  return potential+kinetic
# ..................................................... Time Reversal
def TimeReversal():
  """Reverse the direction of time for every integrator.

  Negates the base step dt, re-derives dT (= 3*dt) and rebuilds the
  per-method step list in EULER, VERLET, ODEINT, YOSHIDA order.
  """
  global step,dt,dT
  dt*=-1
  dT=3*dt
  step=[dt]*2+[dT]*2
# ................................................... odeint Function
def dfdt(InputList,t):
  """Right-hand side of the equations of motion in odeint's f(y, t) form.

  InputList is the state [x, y, vx, vy]; t is not used (the force does
  not depend on time). Returns the derivative [vx, vy, ax, ay] as a list.
  """
  vx,vy=InputList[2:]
  ax,ay=accel(InputList[:2])
  return [vx,vy,ax,ay]
# .................................................. odeint Algorithm
def OdeintAlgo(InputList,step):
  """Advance state [x, y, vx, vy] over one interval `step` (s) with odeint.

  Integrates dfdt from t = 0 to t = step and returns the solution row at
  t = step (a new array; the caller rebinds its state to it).
  """
  solution=odeint(dfdt,InputList,[0,step])
  return solution[-1]
# ................................................. Yoshida Algorithm
def yoshida(InputList,h):
  """Advance state [x, y, vx, vy] in place by one 4th-order Yoshida step h.

  Alternates position (drift) and velocity (kick) sub-steps weighted by
  CX1..CX4 and CV1..CV3. InputList is expected to be a float numpy
  array: it is updated through slice views and returned.
  """
  pos=InputList[:2]
  vel=InputList[2:]
  for cx,cv in ((CX1,CV1),(CX2,CV2),(CX3,CV3)):
    pos+=cx*h*vel
    vel+=cv*h*accel(pos)
  pos+=CX4*h*vel
  return InputList
# .................................................. Verlet Algorithm
def verlet(InputList,h):
  """Advance state [x, y, vx, vy] in place by three velocity-Verlet steps h.

  Reads and refreshes the module-level VerlAccel, which must hold the
  acceleration at the current position on entry. InputList is expected
  to be a float numpy array; it is updated through slice views and
  returned.
  """
  global VerlAccel
  pos,vel=InputList[:2],InputList[2:]
  half=0.5*h
  for _ in range(3):
    vel+=half*VerlAccel      # half kick with the old acceleration
    pos+=h*vel               # full drift
    VerlAccel=accel(pos)     # acceleration at the new position
    vel+=half*VerlAccel      # second half kick
  return InputList
# ................................................... Euler Algorithm
def euler(InputList,h):
  """Advance state [x, y, vx, vy] in place by three semi-implicit Euler steps h.

  Each step moves the position with the current velocity, then updates
  the velocity with the acceleration at the new position. InputList is
  expected to be a float numpy array; it is updated through slice views
  and returned.
  """
  pos,vel=InputList[:2],InputList[2:]
  for _ in range(3):
    pos+=h*vel
    vel+=h*accel(pos)
  return InputList
# ................................................ Canvas Coordinates
def CanvCoord(InputList):
  """Map state [x, y, vx, vy] (metres) to canvas pixels [px, py].

  (Ox, Oy) is the canvas position of the origin, and the current global
  `scale` converts metres to pixels. The y axis is flipped because canvas
  y grows downward. The state must have exactly four entries.
  """
  x,y,_,_=InputList
  px=Ox+scale*x
  py=Oy-scale*y
  return [px,py]
# .................................................... Algorithm List
algo = [euler, verlet, OdeintAlgo, yoshida]   # same order as EULER, VERLET, ODEINT, YOSHIDA
# ....................................................... Root Window
root = Tk()
root.title('Gravitational Orbit')
# Route the title-bar close button through StopAll. The animation loop
# then finishes its pass and the root.destroy() at the end of the script
# tears the window down. Without this, Tk destroys the window at once and
# the loop's next canvas.update() fails on a destroyed widget.
root.protocol('WM_DELETE_WINDOW', StopAll)
# Ctrl+plus / Ctrl+minus zoom in / out by a factor sqrt(2).
root.bind('<Control-plus>', lambda event, num=1: ScaleUpDown(None, num))
root.bind('<Control-minus>', lambda event, num=-1: ScaleUpDown(None, num))
# ............................................................ Canvas
canvas = Canvas(root, width=cw, height=ch, background='#ffffff')
canvas.grid(row=0, column=0)
# ........................................................... Toolbar
# Column 0 holds buttons/captions, column 1 the matching value widgets.
toolbar = Frame(root)
toolbar.grid(row=0, column=1, sticky=N)
toolbar.columnconfigure(1, minsize=130)
font11 = ('Helvetica', 11)
# ............................................................ Buttons
# nr is the next free toolbar row; each section below appends rows.
nr = 0
ButtLab = ['Start', 'Time Reversal', 'Exit']
ButtComm = [StartStop, TimeReversal, StopAll]
butts = [Button(toolbar, text=label, command=action, width=11)
         for label, action in zip(ButtLab, ButtComm)]
for button in butts:
  button.grid(row=nr, column=0, sticky=W)
  nr += 1
# ...................................................... Time Entries
# Entry k triggers ReadData(None, k) when <Return> is pressed in it.
TimeEntry = []
TimeTxt = ['Time Step/s', '\u03C4/ms']
tVal = [dt, cycle]
tfor = ['{:.2f}', '{:d}']
for k, (caption, fmt, value) in enumerate(zip(TimeTxt, tfor, tVal)):
  Label(toolbar, text=caption, font=font11).grid(row=nr, column=0)
  entry = Entry(toolbar, bd=5, width=11)
  entry.grid(row=nr, column=1)
  entry.insert(0, fmt.format(value))
  entry.bind('<Return>', lambda event, num=k: ReadData(None, num))
  TimeEntry.append(entry)
  nr += 1
# ..................................................... Energy Labels
# One row per method plus an 'Initial' row, each value drawn in col[k].
EnLab = []
for k, name in enumerate(method + ['Initial']):
  Label(toolbar, text=name + ' Energy', font=font11).grid(row=nr, column=0)
  value_label = Label(toolbar, text='     ', font=font11, fg=col[k])
  value_label.grid(row=nr, column=1)
  EnLab.append(value_label)
  nr += 1
# ...................................................... Other Labels
QtLab = []
for name in quant:
  Label(toolbar, text=name, font=font11).grid(row=nr, column=0)
  value_label = Label(toolbar, text='0', font=font11)
  value_label.grid(row=nr, column=1)
  QtLab.append(value_label)
  nr += 1
QtLab[SCALE]['text'] = '{:10.3e}'.format(scale)
# .................................... Draw Coordinate Axes and Earth
# Coordinate axes through the origin (Ox, Oy).
canvas.create_line(0, Oy, cw, Oy, fill='black')
canvas.create_line(Ox, 0, Ox, ch, fill='black')
# Earth: a 16x16 px disc whose bounding box is symmetric about (Ox, Oy),
# so it sits exactly on the origin used by CanvCoord.
canvas.create_oval(Ox - 8, Oy - 8, Ox + 8, Oy + 8, fill='#50a0ff', outline='#50a0ff')
# ............................ Create Satellite Image for Each Method
Image, Trail, ScaledTrail, ImTrail, AlgoInput = [], [], [], [], []
# Canvas position of the common starting point (x0, y0).
cx0, cy0 = Ox + scale * x0, Oy - scale * y0
# Per method: marker oval, then its trail line (creation order kept),
# with the trail initially collapsed onto the start point.
for colour in col[:len(method)]:
  Image.append(canvas.create_oval(cx0 - rad, cy0 + rad, cx0 + rad, cy0 - rad,
                                  fill=colour, outline=colour))
  raw = [x0, y0] * TrailLength
  Trail.append(raw)
  # Even entries are x (-> Ox + scale*x), odd entries y (-> Oy - scale*y).
  ScaledTrail.append([Ox + scale * v if k % 2 == 0 else Oy - scale * v
                      for k, v in enumerate(raw)])
  ImTrail.append(canvas.create_line(ScaledTrail[-1], fill=colour))
  AlgoInput.append(np.array([x0, y0, vx0, vy0]))
# Previous/current y of the Yoshida satellite, used to count orbits.
ypair = [y0, y0]
# ........................ Initial Accelerations for Verlet Algorithm
VerlAccel = accel([x0, y0])
# .................................................... Initial Energy
en = ener([x0, y0, vx0, vy0])
init_text = '{:.6e}'.format(en)
for lab_widget in EnLab:
  lab_widget.config(text=init_text)
# ........................................................ Initialize
tt0 = time.time()
tcount = 0
nIter = 0
nOrbits = 0
Telaps = 0.0
# .................................................... Animation Loop
while RunAll:
  StartIter = time.time()
  # ........................................... Draw Satellite Images
  for i, ima in enumerate(Image):
    xx, yy = CanvCoord(AlgoInput[i])
    canvas.coords(ima, xx - rad, yy + rad, xx + rad, yy - rad)
    canvas.coords(ImTrail[i], ScaledTrail[i])
  # .......................................................... update
  # Redraw and process pending Tk events (buttons, key bindings, entries).
  canvas.update()
  # .......................................................... motion
  if RunMotion:
    # ........................................... update elapsed time
    Telaps += dT
    # ......................................... Move Satellite Images
    for i, (al, st) in enumerate(zip(algo, ScaledTrail)):
      AlgoInput[i] = al(AlgoInput[i], step[i])
      # ............................................... Update Trails
      # Store a new trail point only once the satellite is more than
      # 10 px (on screen) away from the newest stored point.
      xx, yy = CanvCoord(AlgoInput[i])
      if np.linalg.norm([xx - st[-2], yy - st[-1]]) > 10:
        del Trail[i][:2]
        Trail[i].append(AlgoInput[i][0])
        Trail[i].append(AlgoInput[i][1])
        st[::2] = [Ox + scale * zz for zz in Trail[i][::2]]
        st[1::2] = [Oy - scale * zz for zz in Trail[i][1::2]]
    # .................................................. Count Orbits
    # One orbit per upward crossing of y = 0 by the Yoshida satellite;
    # '<=' also catches a step that lands exactly on y == 0.
    ypair[0] = ypair[1]
    ypair[1] = AlgoInput[YOSHIDA][1]
    if ypair[0] < 0 <= ypair[1]:
      nOrbits += 1
      QtLab[ORBITS].config(text=str(nOrbits))
      OrbPeriod = Telaps / nOrbits
      QtLab[PERIOD].config(text=str(timedelta(seconds=int(OrbPeriod))))
    # ........................................ show iteration counter
    nIter += 1
    if nIter % 50 == 0:
      QtLab[ELAPSED].config(text=str(timedelta(seconds=int(Telaps))))
      QtLab[ITER].config(text=str(nIter))
      for lab_widget, ai in zip(EnLab, AlgoInput):
        lab_widget.config(text='{:.6e}'.format(ener(ai)))
  elif GetStep:
    # .................................... Restart with New Time Step
    try:
      new_dt = float(TimeEntry[0].get())
    except ValueError:
      new_dt = dt          # unparsable entry: keep the current step
    # A zero, infinite or NaN step would freeze or wreck every integrator.
    if np.isfinite(new_dt) and new_dt != 0:
      dt = new_dt
    dT = 3 * dt
    step = [dt, dt, dT, dT]
    TimeEntry[0].delete(0, END)
    TimeEntry[0].insert(0, '{:.2f}'.format(dt))
    # ................................... Resets Positions and Trails
    for i in range(len(method)):
      # State layout must be [x, y, vx, vy], as used by accel/ener/CanvCoord.
      AlgoInput[i] = np.array([x0, y0, vx0, vy0])
      Trail[i] = [x0, y0] * TrailLength
      ScaledTrail[i][::2] = [Ox + scale * zz for zz in Trail[i][::2]]
      ScaledTrail[i][1::2] = [Oy - scale * zz for zz in Trail[i][1::2]]
    ypair = [y0, y0]
    VerlAccel = accel([x0, y0])
    nOrbits = 0
    Telaps = 0.0
    # Show the reset counters and fresh energies now, rather than leaving
    # stale values until the next orbit / periodic refresh.
    for q in (ORBITS, PERIOD, ELAPSED):
      QtLab[q].config(text='0')
    for lab_widget, ai in zip(EnLab, AlgoInput):
      lab_widget.config(text='{:.6e}'.format(ener(ai)))
    GetStep = False
  elif GetCycle:
    # ............................................ New Frame Duration
    try:
      new_cycle = int(TimeEntry[1].get())
    except ValueError:
      new_cycle = cycle    # unparsable entry: keep the current cycle
    if new_cycle >= 0:     # a negative frame duration is meaningless
      cycle = new_cycle
    TimeEntry[1].delete(0, END)
    TimeEntry[1].insert(0, '{:d}'.format(cycle))
    GetCycle = False
  # ................................................ Cycle Duration
  # Every 10 frames, show the mean frame time in ms (elapsed/10*1000).
  tcount += 1
  if tcount >= 10:
    tcount = 0
    ttt = time.time()
    elapsed = ttt - tt0
    QtLab[CYCLE]['text'] = '{:8.3f}'.format(elapsed * 100.0) + ' ms'
    tt0 = ttt
  # Sleep for what is left of the target frame time; never a negative delay.
  ElapsIter = int((time.time() - StartIter) * 1000.0)
  canvas.after(max(0, cycle - ElapsIter))
#----------------------------------------------------------------------
root.destroy()
  
