#!/usr/bin/env python3
from tkinter import *
import numpy as np
import time
from scipy.integrate import odeint
from scipy.optimize import fsolve
# .................................................. Global variables
RunAll=True                       # main loop keeps running while True
GetData=Grabbed=RunMotion=False   # entry-reload request / bob dragged / simulating
# ....................................................... Canvas data
ButtWidth=9     # width of toolbar buttons and entries
cw=800          # canvas width (pixels)
ch=640          # canvas height (pixels)
Ox=cw/2         # pivot position on the canvas (pixels)
Oy=ch/2
# ............................................... Physical parameters
g=9.8           # m/s**2
L=4.0           # m, unstretched band length
m=5.0           # kg
k=500.0         # N/m, band stiffness (acts only while stretched)
eta=0.0         # kg/s, linear drag coefficient
dt=0.01         # s, time advanced per animation frame
# ................................................................
prad=3          #  pivot radius (pixels)
rad=12          #  bob radius (pixels)
bColor='red'    #  bob color
# ..............................................................
scale=50.0      # pixels/m
tau=20          # target duration of one animation frame, milliseconds
TrailLength=400 # number of points kept in the trail
# ..................................... Initial position and velocity
rv=[1.1*L,0.0,0.0,0.0]  #   [x0,y0,vx0,vy0]
# The trail buffers (trail, ScaledTrail) are built further down, once the
# coordinate helpers cvx/cvy exist; building them here was dead code.
# ........................................................ Start/Stop
def StartStop():
  """Toggle the simulation between running and paused.

  While running, the start button reads 'Stop' and the Exit button plus
  every parameter entry are disabled; when paused, the button reads
  'Restart' and those widgets are enabled again.
  """
  global RunMotion
  RunMotion=not RunMotion
  StartButton['text']='Stop' if RunMotion else 'Restart'
  new_state=DISABLED if RunMotion else NORMAL
  for widget in [ExitButton]+VarEntry:
    widget['state']=new_state
# ...................................................... Exit Program
def StopAll():
  """Request program exit: the main loop ends and destroys the window."""
  global RunAll
  RunAll=False
# ...................................................... Read Entries      
def ReadData(*_ignored):
  """Flag that the parameter entries should be re-read.

  Bound to <Return>; any event argument is ignored.  The main loop
  applies the entry values on its next pass while the simulation is
  paused.
  """
  global GetData
  GetData=True
# ......................................................... Grab ball
def GrabBall(event):
  """Start dragging the bob if the mouse press lands on it.

  Ignored while the simulation runs.  Otherwise sets the global Grabbed
  flag to whether the press lies strictly inside the bob's radius
  (distances measured in canvas pixels).
  """
  global Grabbed
  if RunMotion:
    return
  dx=cvx(rv[0])-event.x
  dy=cvy(rv[1])-event.y
  Grabbed=dx**2+dy**2<rad**2
# ......................................................... Drag ball
def DragBall(event):
  """Move a grabbed bob to follow the mouse pointer.

  The pointer is clamped so the whole bob stays on the canvas, then
  converted from pixels back to physical coordinates (m) and written
  into rv[0], rv[1] in place.
  """
  if not Grabbed:
    return
  px=np.clip(event.x,rad,cw-rad)
  py=np.clip(event.y,rad,ch-rad)
  rv[0]=(px-Ox)/scale
  rv[1]=(Oy-py)/scale
# ...................................................... Release ball
def ReleaseBall(event):
  """Drop a grabbed bob at rest and reset the trail and energy readouts.

  Bound to <ButtonRelease-1>.  Does nothing unless GrabBall started a
  drag, so a stray click on the canvas no longer zeroes the velocity of
  a moving (or paused) pendulum.  On a real release the velocity is set
  to zero, the trail collapses onto the new position, and both energy
  labels show the energy of the new state.
  """
  global Grabbed,trail,ScaledTrail
  if not Grabbed:
    return
  rv[2:]=[0.0,0.0]
  # rv is a NumPy array once the integrator has run; convert to a list so
  # that "* TrailLength" repeats the point instead of scaling it.
  trail=list(rv[:2])*TrailLength
  ScaledTrail=[cvx(rv[0]),cvy(rv[1])]*TrailLength
  en=ener(rv)
  Lab[ENER0].config(text='{:.8e}'.format(en))
  Lab[ENER1].config(text='{:.8e}'.format(en))
  Grabbed=False
# ............................................................... cvx
def cvx(x):
  """Map a physical x coordinate (m) to a canvas column (pixels)."""
  return x*scale+Ox
# ............................................................... cvy
def cvy(x):
  """Map a physical y coordinate (m, upward) to a canvas row (pixels, downward)."""
  return Oy-x*scale
# ........................................................... CateFun
def CateFun(x,CatePar):
  """Residual of the catenary shape equation, for use with fsolve.

  CatePar is (L, cx, cy): band length, horizontal span and rise between
  the two end points.  The root x of
      sqrt(L**2 - cy**2)/cx - sinh(x)/x = 0
  is the dimensionless catenary parameter used by catenary().
  """
  length,span,rise=CatePar
  target=np.sqrt(length**2-rise**2)/span
  return target-np.sinh(x)/x
# .......................................................... catenary
def catenary(xy):
  """Return the flat [x0,y0,x1,y1,...] canvas coordinates of the band.

  xy is the bob position (m) relative to the pivot.  If the bob is at
  least L from the pivot, the band is taut and drawn as one straight
  segment.  Otherwise the slack band of length L hangs as a catenary
  y = a*cosh((x-b)/a) + c through the pivot and the bob, where, with
  span h = |x| and rise v = y:
    * A = h/(2a) solves sinh(A)/A = sqrt(L**2-v**2)/h (see CateFun);
    * b = h/2 - a*artanh(v/L) is the x of the curve's vertex;
    * c = (v - L/tanh(A))/2 makes the curve pass through the pivot.
  The curve is sampled at 19 interior points.  When the bob is within
  4 pixels of the pivot's vertical, the catenary degenerates, and the
  band is drawn as a 'V' instead.
  """
  bx,by=xy[0],xy[1]
  if np.linalg.norm(xy)>=L:
    # Taut band: a single straight segment pivot -> bob.
    return [Ox,Oy,cvx(bx),cvy(by)]
  if abs(bx*scale)<4:
    # Nearly vertical: a 'V' whose lowest point (bx/2,(by-L)/2) makes
    # the two legs add up to the band length L.
    return [Ox,Oy,cvx(0.5*bx),cvy(0.5*(by-L)),cvx(bx),cvy(by)]
  span=abs(bx)
  A=fsolve(CateFun,0.01,[L,span,by])[0]
  a=0.5*span/A
  b=0.5*span-a*np.arctanh(by/L)
  c=0.5*(by-L/np.tanh(A))
  pts=[Ox,Oy]
  for i in range(1,20):
    xi=bx*i/20.0
    # The curve is solved for |x|; using abs() mirrors it when the bob
    # is left of the pivot.
    pts+=[cvx(xi),cvy(a*np.cosh((abs(xi)-b)/a)+c)]
  pts+=[cvx(bx),cvy(by)]
  return pts
# ............................................................ Energy
def ener(rv):
  """Return the total mechanical energy of the state rv = [x,y,vx,vy].

  Gravitational potential m*g*y, plus the elastic energy 0.5*k*(r-L)**2
  while the band is stretched (r > L), plus kinetic energy 0.5*m*|v|**2.
  """
  potential=m*g*rv[1]
  dist=np.linalg.norm(rv[:2])
  if dist>L:
    potential+=0.5*k*(dist-L)**2
  kinetic=0.5*m*np.dot(rv[2:],rv[2:])
  return potential+kinetic
# .................................... variable and parameter vectors
# Trail buffers: TrailLength copies of the current (x, y), stored flat as
# [x, y, x, y, ...] -- in metres (trail) and in canvas pixels (ScaledTrail,
# the coordinate list given to the canvas line item).
trail=list(rv[:2])*TrailLength
ScaledTrail=[cvx(rv[0]),cvy(rv[1])]*TrailLength
# .................................... derivatives-computing function
def dfdt(rv,t):
  """Right-hand side for odeint: time derivative of [x, y, vx, vy].

  The band pulls the bob toward the pivot with magnitude k*(r-L) only
  while stretched (r > L).  Linear drag -eta*v and gravity -g (along y)
  also act.  t is unused but required by odeint's call signature.
  """
  x,y,vx,vy=rv
  angle=np.arctan2(y,x)
  stretch=np.linalg.norm(rv[:2])-L
  tension=-k*stretch if stretch>0 else 0.0
  ax=(tension*np.cos(angle)-eta*vx)/m
  ay=(tension*np.sin(angle)-eta*vy)/m-g
  return [vx,vy,ax,ay]
# ................................................ Create root window
root=Tk()
root.title('Catenary Pendulum')
root.bind('<Return>',ReadData)
# The script drives Tk with its own loop (canvas.update() in "while RunAll").
# Closing the window from the title bar must end that loop the same way the
# Exit button does; otherwise Tk destroys the widgets while the loop keeps
# calling canvas.update(), which raises TclError.
root.protocol('WM_DELETE_WINDOW',StopAll)
# ......................................... Add canvas to root window
canvas=Canvas(root,width=cw,height=ch,background='#ffffff')
canvas.grid(row=0,column=0)
# ...................................................... Mouse button
# Left button: press grabs the bob, motion drags it, release drops it.
for MouseSeq,MouseHandler in (('<Button-1>',GrabBall),
                              ('<B1-Motion>',DragBall),
                              ('<ButtonRelease-1>',ReleaseBall)):
  canvas.bind(MouseSeq,MouseHandler)
# ........................................ Add toolbar to root window
toolbar=Frame(root)
toolbar.grid(row=0,column=1,sticky=N)
toolbar.option_add('*Font','Helvetica 11')
# ................................................... Toolbar buttons
# Widgets are stacked one per grid row; nr is the next free row and is
# carried on into the entry and label sections below.
StartButton=Button(toolbar,text='Start',command=StartStop,width=ButtWidth)
ExitButton=Button(toolbar,text='Exit',command=StopAll,width=ButtWidth)
nr=0
for ToolButton in (StartButton,ExitButton):
  ToolButton.grid(row=nr,column=0,sticky=W)
  nr+=1
# ............................................ Label and Entry arrays
VarLab=['x\u2080','y\u2080','vx\u2080','vy\u2080','Length','k',
  'Mass','\u03B7','scale','Time step','\u03C4/ms']
# inputs holds the last accepted value of every entry, in VarLab order.
inputs=rv+[L,k,m,eta,scale,dt,tau]
VarEntry=[]
for EntryName,EntryValue in zip(VarLab,inputs):
  Label(toolbar,text=str(EntryName)).grid(row=nr,column=0)
  NewEntry=Entry(toolbar,bd=5,width=ButtWidth)
  NewEntry.grid(row=nr,column=1)
  NewEntry.insert(0,'{:.3f}'.format(EntryValue))
  VarEntry.append(NewEntry)
  nr+=1
# ............................................................ Labels
# Read-only readouts; PERIOD..ITER index into Lab.
LabList=['Period','Initial Energy','Energy','Iterations']
PERIOD,ENER0,ENER1,ITER=range(4)
Lab=[]
for LabName in LabList:
  Label(toolbar,text=LabName).grid(row=nr,column=0)
  Readout=Label(toolbar,text='     ')
  Readout.grid(row=nr,column=1,sticky=W)
  Lab.append(Readout)
  nr+=1
# ................................... Draw Circle and Horizontal Line
# Circle of radius L around the pivot: beyond it the band is stretched.
circle=canvas.create_oval(cvx(-L),cvy(L),cvx(L),cvy(-L),outline='green')
HorizonY=ch-Oy
canvas.create_line(0,HorizonY,cw,HorizonY,fill='green')
# ..................................................... Draw Pendulum
# Items are created in stacking order: pivot, band, bob, then trail.
canvas.create_oval(Ox-prad,Oy-prad,Ox+prad,Oy+prad,fill='black')
BobX,BobY=cvx(rv[0]),cvy(rv[1])
BandImg=canvas.create_line(Ox,Oy,BobX,BobY,fill='black')
BobImg=canvas.create_oval(BobX-rad,BobY-rad,BobX+rad,BobY+rad,fill=bColor)
TrailImg=canvas.create_line(ScaledTrail,fill=bColor)
# ...................................................................
t=[0.0,dt]              # odeint output times: one step of length dt
tcount=nIter=0          # main-loop passes / integration steps so far
tt0=time.time()         # start of the current 20-pass timing window
en=ener(rv)
for LabIndex in (ENER0,ENER1):
  Lab[LabIndex].config(text='{:.8e}'.format(en))
# ......................................................... Main loop
# Hand-rolled Tk event loop.  Each pass redraws the scene, then either
# advances the simulation by one step (running) or applies edited entry
# values (paused and <Return> pressed), then sleeps so that a pass takes
# about tau ms.  It ends when StopAll clears RunAll.
while RunAll:
  StartIter=time.time()
  # ................................................... Draw pendulum
  # Redraw band, trail and bob from the current state.  canvas.update()
  # is the only place pending Tk events are processed (there is no
  # mainloop), so the button/mouse callbacks run inside it.
  canvas.coords(BandImg,catenary(rv[:2]))
  canvas.coords(TrailImg,ScaledTrail)
  canvas.coords(BobImg,cvx(rv[0])-rad,cvy(rv[1])-rad,\
    cvx(rv[0])+rad,cvy(rv[1])+rad)
  canvas.update()
  if RunMotion:
    nIter+=1
    # .......................... Velocity and position for next frame
    # Integrate from 0 to dt; psoln[1] is the state at t=dt.  From here
    # on rv is a NumPy array rather than a list.
    psoln=odeint(dfdt,rv,t)
    rv=psoln[1]
    # Refresh the energy and iteration readouts every 20 steps.
    if nIter%20==0:
      en=ener(rv)
      Lab[ENER1].config(text='{:.8e}'.format(en))
      Lab[ITER].config(text='{:d}'.format(nIter))
    # .................................................. Update Trail
    # Add a trail point only when the bob has moved more than 10 pixels
    # (10/scale m) from the newest one; dropping the oldest point keeps
    # TrailLength points in the trail.
    if ((rv[0]-trail[-2])**2+(rv[1]-trail[-1])**2)>100/(scale**2):
      trail=trail[2:]
      ScaledTrail=ScaledTrail[2:]
      trail.extend(rv[:2])
      ScaledTrail.extend([cvx(rv[0]),cvy(rv[1])])
  # .................................................... Read Entries
  # Only reached while paused.  Unparseable text keeps the previous
  # value; every entry is rewritten in normalised '%.3f' form.  Then
  # state, parameters, energy readouts, trail and the radius-L circle
  # are all reset from the entries.
  elif GetData:
    for i,ve in enumerate(VarEntry):
      try:
        inputs[i]=float(ve.get())
      except ValueError:
        pass
      ve.delete(0,END)
      ve.insert(0,'{:.3f}'.format(inputs[i]))
    rv=inputs[:4]
    L,k,m,eta,scale,dt,tau=inputs[4:]
    # Keep the delay later passed to canvas.after() an integer (ms).
    tau=int(tau)
    t=[0.0,dt]
    en=ener(rv)
    Lab[ENER0].config(text='{:.8e}'.format(en))
    Lab[ENER1].config(text='{:.8e}'.format(en))
    trail=rv[:2]*TrailLength
    ScaledTrail=[cvx(rv[0]),cvy(rv[1])]*TrailLength
    canvas.coords(circle,cvx(-L),cvy(L),cvx(L),cvy(-L))
    GetData=False
  # .................................................. Cycle Duration
  # Every 20 passes, show the mean pass duration:
  # elapsed s * 1000 ms/s / 20 passes = elapsed*50 ms.
  tcount+=1
  if tcount%20==0:
    ttt=time.time()
    elapsed=ttt-tt0
    Lab[PERIOD]['text']='%8.3f'%(elapsed*50.0)+' ms'
    tt0=ttt
  # .................................................................
  # Sleep for whatever remains of the tau-ms budget (after() with no
  # callback blocks).  NOTE(review): if the pass overran, the delay is
  # negative; Tcl's `after` appears to treat that as 0 -- confirm.
  ElapsIter=int((time.time()-StartIter)*1000.0)
  canvas.after(tau-ElapsIter)
  #------------------------------------------------------------------
# RunAll was cleared by StopAll: tear down the window.
root.destroy()
  
