#!/usr/bin/env python3
from tkinter import *
import numpy as np
import time
from scipy.integrate import odeint
# .................................................. Global variables
RunAll=True              # main loop runs while True; cleared by the Exit button
GetData=RunIter=False    # RunIter: simulation advancing; GetData: re-read entries
ButtWidth=10
scale=30000          # px/m
cw=ch=800            # px
cycle=20             # ms
Ox=Oy=cw/2           # px, canvas position of the physical origin
nTrail=600           # number of points kept in the orbit trail
# .................................................................
rad=5                # px, radius of the drawn particle
r0=0.01              # m, initial distance from the origin (launched from (r0,0))
# ................................................ Initial velocities
speed1=5.868e4         # m/s
SpeedOpen=4.500e4      # m/s  (alternative value tried: 2.934e4 = speed1/2)
speed3=1.835e4         # m/s
speed5=3.8120e3        # m/s
speed7=1.275e3         # m/s
speed9=4.89e2          # m/s
# indexed by the orbit radio buttons; entry 1 is editable via the 'v open' field
speeds=[speed1,SpeedOpen,speed3,speed5,speed7,speed9]
# ............................................... Start/Stop function
def StartStop():
  """Toggle between running and paused; relabel the Start button to match."""
  global RunIter
  RunIter=not RunIter
  # while running the button offers to stop; while paused, to resume
  StartButton['text']='Stop' if RunIter else 'Restart'
# ..................................................... Exit function
def StopAll():
  """Ask the main animation loop to finish (it checks RunAll each frame)."""
  global RunAll
  RunAll=False
# ......................................................... Read data
def ReadData(*arg):
  """Flag that the entry fields should be read; any event arguments are ignored."""
  global GetData
  GetData=True
# ........................................................... Restart
def Restart():
  """Reset the particle to the launch state of the selected orbit.

  Puts the particle at (r0, 0) with velocity (0, v0), where v0 is
  speeds[iRadio.get()]. It collapses the whole trail onto that point,
  redraws the trail line and shows v0 in the toolbar.
  """
  # canvas, Ox, Oy etc. are only read, so they need no global declaration
  global x,y,vx,vy,xs,ys,funcs,trail,TrailImag
  x=r0
  vx=y=0.0
  vy=speeds[iRadio.get()]
  xs=x*scale
  ys=y*scale
  funcs=[x,vx,y,vy]
  # every trail vertex starts on the launch point, so no stale segment shows
  trail=[Ox+xs,Oy-ys]*nTrail
  canvas.delete(TrailImag)
  TrailImag=canvas.create_line(trail,fill='red')
  # same precision as the label's initial text, so it doesn't change on click
  SpeedLab['text']='{:.3e}'.format(vy)
# ....................................................... root window
root=Tk()
root.title('Logarithmic Potential Orbits')
root.bind('<Return>',ReadData)
# Route the window-manager close button through StopAll so the animation
# loop ends normally and destroys the root itself. With Tk's default handler
# the widgets would be destroyed while the loop is still drawing on them.
root.protocol('WM_DELETE_WINDOW',StopAll)
# ................................................................
iRadio=IntVar()       # index into speeds of the selected orbit
iRadio.set(0)
# ...................................................................
canvas=Canvas(root,width=cw,height=ch,background="#ffffff")
canvas.grid(row=0,column=0)
# ........................................................... toolbar
toolbar=Frame(root)
toolbar.grid(row=0,column=1,sticky=N)
# ............................................................ buttons
nr=0                  # next free toolbar row
StartButton=Button(toolbar,text='Start',command=StartStop,\
  width=ButtWidth)
StartButton.grid(row=nr,column=0,sticky=N)
nr+=1
CloseButton=Button(toolbar, text='Exit', command=StopAll,\
  width=ButtWidth)
CloseButton.grid(row=nr,column=0,sticky=N)
nr+=1
# ................................... initial velocity for open orbit
vfreeLab=Label(toolbar,text='v open:',font=("Helvetica",12))
vfreeLab.grid(row=nr,column=0)
vfreeEntry=Entry(toolbar,bd=5,width=ButtWidth)
vfreeEntry.grid(row=nr,column=1)
vfreeEntry.insert(0,'{:.2e}'.format(speeds[1]))
nr+=1
# ............................................................ Period
PeriodLab=Label(toolbar,text='Cycle/ms',font=("Helvetica",12))
PeriodLab.grid(row=nr,column=0)
PeriodEntry=Entry(toolbar,bd=5,width=ButtWidth)
PeriodEntry.grid(row=nr,column=1)
PeriodEntry.insert(0,str(cycle))
nr+=1
# ...................................................... Radio buttons
RadioLabs=['Circular','Open','3-Lobes','5-Lobes','7-Lobes','9-Lobes']
# one button per orbit (RadioLabs and speeds both have six entries);
# the button's value is the index into speeds, and selecting it restarts
for i,lab in enumerate(RadioLabs):
  Radiobutton(toolbar,text=lab+' Orbit',variable=iRadio,\
    value=i,command=Restart).grid(row=nr,column=0,sticky=W)
  nr+=1
# ............................................. show initial velocity
SpeedLab0=Label(toolbar,text="v\u2080:",font=("Helvetica",11))
SpeedLab0.grid(row=nr,column=0)
SpeedLab=Label(toolbar,text='{:.3e}'.format(speeds[iRadio.get()]),\
  font=("Helvetica",11))
SpeedLab.grid(row=nr,column=1,sticky=W)
nr+=1
# ........................................................ show speed
CurrSpeedLab0=Label(toolbar,text="Velocity:",font=("Helvetica",11))
CurrSpeedLab0.grid(row=nr,column=0)
CurrSpeedLab=Label(toolbar,text="     ",font=("Helvetica",11))
CurrSpeedLab.grid(row=nr,column=1,sticky=W)
nr+=1
# ....................................................... show period
CycleLab0=Label(toolbar,text="Period:",font=("Helvetica",10))
CycleLab0.grid(row=nr,column=0)
CycleLab=Label(toolbar,text="     ",font=("Helvetica",10))
CycleLab.grid(row=nr,column=1,sticky=W)
nr+=1
# .......................................................... function
def dfdt(FuncIn,t,k=None):
  """Right-hand side of the equations of motion for odeint.

  The force per unit mass is central and attractive, with magnitude k/r,
  i.e. the force from a logarithmic potential.

  FuncIn -- state [x, vx, y, vy] (m, m/s)
  t      -- time (unused; the force does not depend on time)
  k      -- force constant in m^2/s^2. When None, the module-level K is used,
            so odeint(dfdt, y0, t) keeps working. It can also be passed as
            odeint(dfdt, y0, t, args=(k,)).

  Returns [vx, ax, vy, ay].
  """
  if k is None:
    k=K
  x,vx,y,vy=FuncIn
  # a = -k/r along the unit radial vector (x/r, y/r), so each component is
  # -k*x/r**2 and -k*y/r**2. This is the same as a*cos/sin(atan2(y,x))
  # without evaluating any trig.
  r2=x*x+y*y
  ax=-k*x/r2
  ay=-k*y/r2
  return [vx,ax,vy,ay]
# .................................................. numerical values
# initial state: the same launch state Restart() sets for the selected orbit
x=r0               #  m
vx=y=0.0
vy=speeds[iRadio.get()]   #  m/s
K=3.443e9          #  m2/s2, constant of the central force a=-K/r in dfdt
# ...................................................... input lists
t=[0.0,2.0e-9]     #  s; each frame odeint advances the state by t[1]-t[0]
funcs=[x,vx,y,vy]  #  state vector in the order dfdt unpacks it
xs=x*scale         #  px offset from the canvas origin (Ox,Oy)
ys=y*scale
# flat [x0,y0,x1,y1,...] canvas coordinates (canvas y grows downward)
trail=[Ox+xs,Oy-ys]*nTrail
TrailImag=canvas.create_line(trail,fill='red')
# ....................................................... first frame
proton=canvas.create_oval(Ox+xs+rad,Oy-ys+rad,Ox+xs-rad,Oy-ys-rad,\
  fill='red')
canvas.create_oval(Ox+2,Oy-2,Ox-2,Oy+2,fill='black')   # origin marker
canvas.create_line(0,Oy,cw,Oy,fill='green')           # x axis
canvas.create_line(Ox,0,Ox,ch,fill='green')           # y axis
# ..................................................................
tt0=time.time()    #  start of the current 10-frame timing window
tcount=0           #  frame counter; the animation loop wraps it at 10
# .................................................... animation loop
while RunAll:
  StartIter=time.time()
  # draw the latest position, then let Tk process pending events
  canvas.coords(proton,Ox+xs+rad,Oy-ys+rad,Ox+xs-rad,Oy-ys-rad)
  canvas.coords(TrailImag,trail)
  canvas.update()
  if RunIter:
    # integrate over t and keep only the end point as the new state
    psoln=odeint(dfdt,funcs,t)
    funcs=psoln[1,:]
    x,vx,y,vy=funcs
    xs=x*scale
    ys=y*scale
    tcount+=1
    if tcount%3==0:
      # sliding window: drop the oldest (x, y) pair, append the newest
      del trail[:2]
      trail.append(Ox+xs)
      trail.append(Oy-ys)
    if tcount%5==0:
      speed=np.sqrt(vx**2+vy**2)
      CurrSpeedLab['text']='{:.3e}'.format(speed)
    if tcount==10:
      tcount=0
      ttt=time.time()
      elapsed=ttt-tt0
      # elapsed spans 10 frames: ms per frame = elapsed*1000/10
      CycleLab['text']="%8.2f"%(elapsed*100.0)+" ms"
      tt0=ttt
  elif GetData:
    # Entries are applied only while paused. Unparsable, non-finite or
    # negative input keeps the previous value, and the field is rewritten
    # with the value actually in use.
    try:
      NewSpeed=float(vfreeEntry.get())
    except ValueError:
      NewSpeed=SpeedOpen
    if np.isfinite(NewSpeed):      # float() accepts 'nan'/'inf'; reject them
      SpeedOpen=NewSpeed
    try:
      NewCycle=int(PeriodEntry.get())
    except ValueError:
      NewCycle=cycle
    if NewCycle>=0:
      cycle=NewCycle
    speeds[1]=SpeedOpen
    vfreeEntry.delete(0,'end')
    vfreeEntry.insert(0,'{:.2e}'.format(SpeedOpen))
    PeriodEntry.delete(0,'end')
    PeriodEntry.insert(0,'{:d}'.format(cycle))
    GetData=False
  # ................................................ cycle duration
  ElapsIter=int((time.time()-StartIter)*1000.0)
  # sleep for the rest of the frame budget; never pass a negative delay
  # when the frame overran it
  canvas.after(max(0,cycle-ElapsIter))
root.destroy()

