#!/usr/bin/env python3
#coding: utf8
import time
import numpy as np
from tkinter import *
from scipy.integrate import odeint
from scipy.optimize import fsolve
# .................................................. Global Variables
# Flags shared between the Tk callbacks and the animation loop.
RunAll=True            # cleared by StopAll(); ends the animation loop
RunMotion=False        # toggled by StartStop(); integrate while True
DrawTrails=False       # toggled by SwitchTrails()
GetData=False          # set by ReadData() to request re-reading the entries
# ....................................................... Canvas data
ButtWidth=9            # character width of the toolbar Entry widgets
cw,ch=800,640          # canvas size in pixels
Ox,Oy=cw/2,ch/4        # pixel position of the pivot (physical origin)
scale=50               # pixels per unit of physical length
# ............................................... Simulation settings
g=9.8                  # gravitational acceleration
dt=0.02                # time span between successive odeint outputs
prad=3                 # radius of the pivot marker in pixels
tau=20                 # target frame period in milliseconds
TrailLength=400        # number of (x,y) points stored per trail
# ........................................................ pendulum 1
# rest length, stiffness, mass and linear damping coefficient
L1,k1,m1,eta1=4.0,500.0,10.0,0.0
rad1=12                # bob radius in pixels
theta1=0.0             # initial direction of band 1 (radians from +x)
pColor1='red'
# ........................................................ pendulum 2
L2,k2,m2,eta2=3.5,500.0,10.0,0.0
rad2=12
theta2=0.25*np.pi      # initial direction of band 2 (radians from +x)
pColor2='blue'
# .................................................... initial values
# Each band starts slack, at 80% of its rest length, along its theta.
x1,y1=0.8*L1*np.cos(theta1),0.8*L1*np.sin(theta1)
x2,y2=x1+0.8*L2*np.cos(theta2),y1+0.8*L2*np.sin(theta2)
vx1,vy1,vx2,vy2=0.0,0.0,0.0,0.0
# ......................................................... StartStop
def StartStop():
  """Toggle the simulation between running and paused.

  While running, the first button reads 'Stop' and every widget in
  butt[2:]+VarEntry is disabled; when paused it reads 'Restart' and those
  widgets are re-enabled.
  NOTE(review): butt[2:] currently holds only the Exit button, so Exit is
  disabled while the motion runs -- confirm that is intended.
  """
  global RunMotion
  RunMotion=not RunMotion
  caption,widget_state=('Stop',DISABLED) if RunMotion else ('Restart',NORMAL)
  butt[0]['text']=caption
  for widget in butt[2:]+VarEntry:
    widget['state']=widget_state
# ..................................................... Trails On/Off
def SwitchTrails():
  """Toggle display of the bob trails and relabel the trails button.

  Trails are shown and hidden through the canvas item 'state' option
  rather than by repainting them in the background colour.  The trail
  items are created after the bands and the horizon line, so they are
  drawn on top of them.  A white-painted "hidden" trail therefore erased
  whatever it overlapped and left stale lines behind.
  """
  global DrawTrails
  DrawTrails=not DrawTrails
  if DrawTrails:
    butt[1]['text']='Hide Trails'
    canvas.itemconfig(TrailImg1,fill=pColor1,state='normal')
    canvas.itemconfig(TrailImg2,fill=pColor2,state='normal')
  else:
    butt[1]['text']='Show Trails'
    canvas.itemconfig(TrailImg1,state='hidden')
    canvas.itemconfig(TrailImg2,state='hidden')
# ........................................................... StopAll
def StopAll():
  """Ask the animation loop to terminate.

  Clears the module-level RunAll flag.  The `while RunAll` loop at the
  bottom of the file then stops and destroys the root window.
  """
  global RunAll
  RunAll=False
# .......................................................... ReadData
def ReadData(*arg):
  """Request that the entry fields be re-read.

  Bound to <Return>; any Tk event argument is ignored.  The animation loop
  does the actual parsing the next time the motion is not running.
  """
  global GetData
  GetData=True
# ............................................................... cvx
def cvx(x):
  """Map a physical x coordinate to a canvas (pixel) x coordinate."""
  return scale*x+Ox
# ............................................................... cvy
def cvy(x):
  """Map a physical y coordinate to a canvas y coordinate (canvas y grows downward)."""
  return Oy-x*scale
# ........................................................... CateFun
def CateFun(x,CatePar):
  """Residual of the catenary shape equation, for use with fsolve.

  CatePar is (L, dx, dy): band length and the horizontal/vertical
  separation of its end points.  The root x = A of
      sqrt(L**2 - dy**2) / dx = sinh(A) / A
  is returned as the residual's zero; catenary() then uses a = dx / (2 A).
  """
  length,span,rise=CatePar
  return np.sqrt(length**2-rise**2)/span-np.sinh(x)/x
# ............................................................ Energy
def ener(rv):
  """Return the total mechanical energy for the state vector rv.

  rv is (x1, y1, x2, y2, vx1, vy1, vx2, vy2).  The energy is the sum of:
  - the gravitational potential of both bobs;
  - the elastic energy 0.5*k*stretch**2 of each band, counted only while
    the band is longer than its rest length;
  - the kinetic energy of both bobs.
  """
  px1,py1,px2,py2,ux1,uy1,ux2,uy2=rv
  stretch1=np.sqrt(px1**2+py1**2)-L1
  stretch2=np.sqrt((px2-px1)**2+(py2-py1)**2)-L2
  potential=m1*g*py1+m2*g*py2
  for stretch,stiffness in ((stretch1,k1),(stretch2,k2)):
    if stretch>0:    # slack bands store no elastic energy
      potential+=0.5*stiffness*stretch**2
  kinetic=0.5*(m1*(ux1**2+uy1**2)+m2*(ux2**2+uy2**2))
  return potential+kinetic
# .......................................................... catenary
def catenary(x1,y1,x2,y2,L):
  """Return canvas coordinates of a band of length L hung between two points.

  The result is a flat [x, y, x, y, ...] list of canvas pixels, ordered
  from the left end point to the right one, suitable for canvas.coords():
  - if the end points are at least L apart, the band is taut and the
    result is the straight segment between them;
  - if their horizontal separation is under 4 pixels, the slack band is
    drawn as a V whose tip lies at yav - L/2;
  - otherwise the catenary y = a*cosh((x-b)/a) + c is sampled at 9
    interior points.
  """
  if x1>x2:
    x1,x2=x2,x1
    y1,y2=y2,y1
  dx=x2-x1
  dy=y2-y1
  if np.sqrt(dx**2+dy**2)>=L:
    return [cvx(x1),cvy(y1),cvx(x2),cvy(y2)]
  xav=(x1+x2)/2.0
  yav=(y1+y2)/2.0
  band=[cvx(x1),cvy(y1)]
  if abs(dx*scale)<4:
    # Nearly vertical: yav - L/2 is how low a band of length L can hang.
    band.extend([cvx(xav),cvy(yav-0.5*L)])
  else:
    # Solve sinh(A)/A = sqrt(L**2-dy**2)/dx for A = dx/(2a).  The left-hand
    # side is even in A, so fsolve may return the negative root.  Take |A|
    # so that both a and the tanh(A) term in c use the same positive root.
    # A negative root would otherwise put the curve above the end points.
    AA=abs(fsolve(CateFun,0.001,args=([L,dx,dy],))[0])
    aa=0.5*dx/AA
    bb=xav-aa*np.arctanh(dy/L)
    cc=yav-0.5*L/np.tanh(AA)
    for i in range(1,10):
      xx=x1+dx*i/10.0
      band.append(cvx(xx))
      band.append(cvy(aa*np.cosh((xx-bb)/aa)+cc))
  band.extend([cvx(x2),cvy(y2)])
  return band
# .......................................................... function
def _band_force(dx,dy,L,k):
  """Return the (fx, fy) force an elastic band exerts on its far end.

  (dx, dy) is the vector from the band's anchor to its far end, L is the
  rest length and k the stiffness.  A band only pulls when stretched
  beyond L.  Slack or zero-length bands exert no force, so their direction
  (undefined at zero length) is never computed.
  """
  dist=np.sqrt(dx**2+dy**2)
  stretch=dist-L
  if stretch<=0 or dist==0:
    return 0.0,0.0
  f=-stretch*k            # negative: pulls the far end back toward the anchor
  return f*(dx/dist),f*(dy/dist)

def dfdt(rv,t):
  """Right-hand side of the equations of motion, in the form odeint expects.

  rv is the state (x1, y1, x2, y2, vx1, vy1, vx2, vy2).  t is required by
  odeint's signature but unused, since the system is autonomous.
  Returns [vx1, vy1, vx2, vy2, ax1, ay1, ax2, ay2].
  Band 1 joins the pivot (origin) to bob 1; band 2 joins bob 1 to bob 2.
  Gravity acts along -y and each bob has linear drag -eta*v.
  """
  x1,y1,x2,y2,vx1,vy1,vx2,vy2=rv
  fx1,fy1=_band_force(x1,y1,L1,k1)
  fx2,fy2=_band_force(x2-x1,y2-y1,L2,k2)
  # Bob 1 feels band 1 plus the reaction of band 2 (hence -fx2, -fy2).
  ax1=(fx1-fx2-eta1*vx1)/m1
  ay1=(fy1-fy2-m1*g-eta1*vy1)/m1
  ax2=(fx2-eta2*vx2)/m2
  ay2=(fy2-m2*g-eta2*vy2)/m2
  return [vx1,vy1,vx2,vy2,ax1,ay1,ax2,ay2]
# ........................................................ Initialize
rv=[x1,y1,x2,y2,vx1,vy1,vx2,vy2]   # state vector advanced by odeint
# Trails are flat [x,y,x,y,...] lists pre-filled with each bob's starting
# point, kept both in physical units and in canvas pixels.
trail1,trail2=([px,py]*TrailLength for px,py in ((x1,y1),(x2,y2)))
ScaledTrail1,ScaledTrail2=([cvx(px),cvy(py)]*TrailLength
                           for px,py in ((x1,y1),(x2,y2)))
# ................................................ Create Root Window
root=Tk()
root.title('Double Elastic Band Pendulum')
root.bind('<Return>',ReadData)   # Return asks the loop to re-read the entries
# Route the window-manager close button through StopAll.  Letting it destroy
# the window directly leaves the animation loop drawing on dead widgets,
# which raises TclError.  Instead the loop exits and calls root.destroy().
root.protocol('WM_DELETE_WINDOW',StopAll)
# ........................................................... canvas
canvas=Canvas(root,width=cw,height=ch,background='#ffffff')
canvas.grid(row=0,column=0)
# ........................................................... toolbar
# Right-hand column holding buttons, entry fields and read-outs.
# nr tracks the next free grid row in it.
toolbar=Frame(root)
toolbar.grid(row=0,column=1,sticky=N)
toolbar.option_add('*Font','Helvetica 11')
# ........................................................... buttons
ButtLab=['Start','Show Trails','Exit']
ButtComm=[StartStop,SwitchTrails,StopAll]
butt=[Button(toolbar,text=label,command=action,width=11)
      for label,action in zip(ButtLab,ButtComm)]
for row,button in enumerate(butt):
  button.grid(row=row,column=0,sticky=W)
nr=len(butt)
# ........................................................... Entries
VarEntry=[]
# Labels of the editable quantities, in the same order as `inputs`:
# state (x1..vy2), band/bob 1 (L, k, m, eta), band/bob 2, then dt and tau.
VarLab=['x\u2081','y\u2081','x\u2082','y\u2082',
  'vx\u2081','vy\u2081','vx\u2082','vy\u2082',
  'L\u2081','k\u2081','m\u2081','\u03B7\u2081',
  'L\u2082','k\u2082','m\u2082','\u03B7\u2082',
  'dt','\u03C4']
inputs=list(rv)+[L1,k1,m1,eta1,L2,k2,m2,eta2,dt,tau]
nVar=len(VarLab)
for row,(lab,value) in enumerate(zip(VarLab,inputs),start=nr):
  Label(toolbar,text=str(lab)).grid(row=row,column=0)
  entry=Entry(toolbar,bd=5,width=ButtWidth)
  entry.grid(row=row,column=1)
  entry.insert(0,'{:.3f}'.format(value))
  VarEntry.append(entry)
nr+=nVar
# ............................................................ Labels
# Read-only outputs; PERIOD, ENER0, ENER1 and ITER index into Lab.
LabList=['Period','Initial Energy','Energy','Iterations']
PERIOD,ENER0,ENER1,ITER=range(4)
Lab=[]
for row,caption in enumerate(LabList,start=nr):
  Label(toolbar,text=caption).grid(row=row,column=0)
  readout=Label(toolbar,text='     ')
  readout.grid(row=row,column=1,sticky=W)
  Lab.append(readout)
nr+=len(LabList)
# .............................................. Create Canvas Images
# Items are created in drawing order (later items are drawn on top):
# horizon line, pivot, bands, trails, then the bobs.
canvas.create_line(0,Oy,cw,Oy,fill='green')
PivImg=canvas.create_oval(Ox-prad,Oy-prad,Ox+prad,Oy+prad,fill='black')
BandImg1=canvas.create_line(Ox,Oy,cvx(x1),cvy(y1),fill='black')
BandImg2=canvas.create_line(cvx(x1),cvy(y1),cvx(x2),cvy(y2),fill='black')
TrailImg1,TrailImg2=(canvas.create_line(points,fill=colour)
                     for points,colour in ((ScaledTrail1,pColor1),
                                           (ScaledTrail2,pColor2)))
BobImg1,BobImg2=(canvas.create_oval(cvx(px)-r,cvy(py)-r,cvx(px)+r,cvy(py)+r,
                                    fill=colour)
                 for px,py,r,colour in ((x1,y1,rad1,pColor1),
                                        (x2,y2,rad2,pColor2)))
# ........................................... numerical time interval
t=[0.0,dt]            # odeint output times: current state and one step ahead
tcount=nIter=0        # animation-cycle and integration-step counters
tt0=time.time()       # start of the current 20-cycle timing window
en=ener(rv)
for EnerLab in (Lab[ENER0],Lab[ENER1]):
  EnerLab.config(text='{:.8e}'.format(en))
# .................................................... Animation Loop
# Entry indices (see VarLab) whose values must be strictly positive:
# dfdt divides by the masses and odeint needs a forward time step.
PositiveEntries={VarLab.index('m\u2081'),VarLab.index('m\u2082'),
  VarLab.index('dt')}
while RunAll:
  StartIter=time.time()
  # ..................................... Draw Bands, Bobs and Trails
  canvas.coords(BandImg1,catenary(0.0,0.0,x1,y1,L1))
  canvas.coords(BandImg2,catenary(x1,y1,x2,y2,L2))
  if DrawTrails:
    canvas.coords(TrailImg1,ScaledTrail1)
    canvas.coords(TrailImg2,ScaledTrail2)
  canvas.coords(BobImg1,cvx(x1)-rad1,cvy(y1)-rad1,cvx(x1)+rad1,cvy(y1)+rad1)
  canvas.coords(BobImg2,cvx(x2)-rad2,cvy(y2)-rad2,cvx(x2)+rad2,cvy(y2)+rad2)
  canvas.update()
  if RunMotion:
    nIter+=1
    # ......................................... Update Bob Positions
    # odeint returns the states at t=[0,dt]; row 1 is the state one step on.
    psoln=odeint(dfdt,rv,t)
    rv=psoln[1]
    x1,y1,x2,y2,vx1,vy1,vx2,vy2=rv
    # ................................................ Update Trails
    # Store a new trail point only once the bob is more than 10 pixels from
    # the last stored point (same criterion for both bobs).
    if scale*np.hypot(x1-trail1[-2],y1-trail1[-1])>10:
      trail1=trail1[2:]+[x1,y1]
      ScaledTrail1=ScaledTrail1[2:]+[cvx(x1),cvy(y1)]
    if scale*np.hypot(x2-trail2[-2],y2-trail2[-1])>10:
      trail2=trail2[2:]+[x2,y2]
      ScaledTrail2=ScaledTrail2[2:]+[cvx(x2),cvy(y2)]
    # .................................... Update Iterations Counter
    if nIter%20==0:
      en=ener(rv)
      Lab[ENER1].config(text='{:.8e}'.format(en))
      Lab[ITER].config(text='{:d}'.format(nIter))
  # .................................................... Read Entries
  elif GetData:
    for i,ve in enumerate(VarEntry):
      try:
        value=float(ve.get())
      except ValueError:
        value=inputs[i]          # unparsable text: keep the previous value
      # Also keep the previous value for NaN/inf and for non-positive
      # masses or time step.
      if np.isfinite(value) and (i not in PositiveEntries or value>0):
        inputs[i]=value
      ve.delete(0,END)
      ve.insert(0,'{:.3f}'.format(inputs[i]))
    rv=inputs[:8]
    L1,k1,m1,eta1,L2,k2,m2,eta2,dt,tau=inputs[8:]
    x1,y1,x2,y2,vx1,vy1,vx2,vy2=rv
    tau=int(tau)
    t=[0.0,dt]
    en=ener(rv)
    Lab[ENER0].config(text='{:.8e}'.format(en))
    Lab[ENER1].config(text='{:.8e}'.format(en))
    trail1=[rv[0],rv[1]]*TrailLength
    ScaledTrail1=[cvx(rv[0]),cvy(rv[1])]*TrailLength
    trail2=[rv[2],rv[3]]*TrailLength
    ScaledTrail2=[cvx(rv[2]),cvy(rv[3])]*TrailLength
    GetData=False
  # .................................................. Cycle Duration
  # Every 20 cycles show the mean cycle time: elapsed/20 s = elapsed*50 ms.
  tcount+=1
  if tcount%20==0:
    ttt=time.time()
    elapsed=ttt-tt0
    Lab[PERIOD]['text']='%8.3f'%(elapsed*50.0)+' ms'
    tt0=ttt
  # ..................................... Wait for the rest of the frame
  ElapsIter=int((time.time()-StartIter)*1000.0)
  canvas.after(max(0,tau-ElapsIter))
# All windows are gone after destroy(), so no mainloop() is needed here.
root.destroy()
  
