#!/usr/bin/env python3
from tkinter import *
from numpy import arctan2,cos,sin,sqrt
import time
# .................................................. Global variables
# Loop control flags: RunAll keeps the program alive, RunIter tells
# whether the balls are moving, GetData asks for the entries to be read
RunAll=True
RunIter=False
GetData=False
# Width of the toolbar buttons, in characters
ButtWidth=9
# ........................................ Canvas width and height
cw,ch=800,600
# ................................... Delay between frames, milliseconds
tau=20
# ........................ Ball 1: starts in the bottom-left corner
m1,r1=200,40
x1=y1=r1
vx1,vy1=5.0,5.0
# ....................... Ball 2: starts in the bottom-right corner
m2,r2=150,30
x2,y2=cw-r2,r2
vx2,vy2=-5.0,5.0
# ........................................................ Class ball
class ball:
  """Disc with mass, radius, position and velocity drawn on the canvas.

  Positions are stored with y growing upwards from the bottom edge of
  the canvas; they are converted with ``ch-y`` whenever the oval is
  drawn.  Velocities are displacements per call of move().  The class
  relies on the module-level names canvas, cw and ch.
  """
  def __init__(self,mass,radius,x,y,vx,vy,color):
    """Store the state of the ball and create its oval on the canvas.

    mass, radius -- mass and radius of the ball
    x, y         -- coordinates of the centre
    vx, vy       -- velocity components (displacement per step)
    color        -- Tk color used for both fill and outline
    """
    self.m=mass
    self.rad=radius
    self.x=x
    self.y=y
    self.vx=vx
    self.vy=vy
    self.col=color
    self.image=canvas.create_oval(*self._bbox(),fill=self.col,
                                  outline=self.col)
  # ........................................ Bounding box on the canvas
  def _bbox(self):
    """Return the canvas coordinates (x0,y0,x1,y1) enclosing the ball."""
    return (self.x-self.rad,ch-(self.y+self.rad),
            self.x+self.rad,ch-(self.y-self.rad))
  # ....................................................... Move ball
  def move(self):
    """Advance the centre by one velocity step and redraw the oval."""
    self.x+=self.vx
    self.y+=self.vy
    canvas.coords(self.image,*self._bbox())
  # ........................................ Bounce on canvas borders
  def bounce(self):
    """Reflect the ball off every canvas border it touches or crosses.

    The velocity component normal to the border is forced to point back
    into the canvas, and the centre is mirrored about the position at
    which the ball just touches that border.  The oval is not redrawn
    here; that happens on the next call of move().
    """
    if (self.x+self.rad)>=cw:
      self.vx=-abs(self.vx)
      self.x=2.0*(cw-self.rad)-self.x
    if (self.x-self.rad)<=0:
      self.vx=abs(self.vx)
      self.x=2.0*self.rad-self.x
    if (self.y+self.rad)>=ch:
      self.vy=-abs(self.vy)
      self.y=2.0*(ch-self.rad)-self.y
    if (self.y-self.rad)<=0:
      self.vy=abs(self.vy)
      self.y=2.0*self.rad-self.y
  # ............................................... Elastic collision
  def ElastColl(self,other):
    """Resolve an elastic collision between this ball and other.

    Nothing happens unless the two balls touch or overlap.  If they
    overlap, both are first moved back along their velocities to the
    instant at which they were just touching.  The velocity components
    along the line joining the centres are then exchanged according to
    the one-dimensional elastic collision formulas, the components
    perpendicular to it are kept, and both balls are moved forward
    again with their new velocities for the time that was rewound.
    Positions and velocities of both balls are modified in place.
    """
    X=other.x-self.x
    Y=other.y-self.y
    distsq=X**2+Y**2
    R12sq=(self.rad+other.rad)**2
    if distsq>R12sq:
      return
    tc=0.0
    # ...................................... Adjust overlapping balls
    if distsq<R12sq:
      Xdot=other.vx-self.vx
      Ydot=other.vy-self.vy
      aa=Xdot**2+Ydot**2
      # Balls without relative motion never had an instant of contact:
      # dividing by aa would give NaN positions, and the velocity
      # exchange below would leave their velocities unchanged anyway
      if not aa>0.0:
        return
      bbhalf=X*Xdot+Y*Ydot
      cc=distsq-R12sq
      # ......................... Time elapsed since "real" collision
      # Earlier root of aa*t**2+2*bbhalf*t+cc=0, i.e. the (negative)
      # time at which the distance between centres was rad1+rad2
      tc=(-bbhalf-sqrt(bbhalf**2-aa*cc))/aa
      # .......................... Time reversal to collision instant
      other.x+=tc*other.vx
      other.y+=tc*other.vy
      self.x+=tc*self.vx
      self.y+=tc*self.vy
      # .............................. Distances at collision instant
      X=other.x-self.x
      Y=other.y-self.y
    # ..................................... Collision reference frame
    # Xi axis along the line of centres, Eta axis perpendicular to it
    alpha=arctan2(Y,X)
    csalpha=cos(alpha)
    snalpha=sin(alpha)
    SelfVelXi=self.vx*csalpha+self.vy*snalpha
    SelfVelEta=-self.vx*snalpha+self.vy*csalpha
    OtherVelXi=other.vx*csalpha+other.vy*snalpha
    OtherVelEta=-other.vx*snalpha+other.vy*csalpha
    # Only the Xi components change in the collision
    SelfNewVelXi=((self.m-other.m)*SelfVelXi+2.0*other.m*OtherVelXi)/\
      (self.m+other.m)
    OtherNewXi=((other.m-self.m)*OtherVelXi+2.0*self.m*SelfVelXi)/\
      (self.m+other.m)
    # ................................... Back to the canvas frame
    self.vx=SelfNewVelXi*csalpha-SelfVelEta*snalpha
    self.vy=SelfNewVelXi*snalpha+SelfVelEta*csalpha
    other.vx=OtherNewXi*csalpha-OtherVelEta*snalpha
    other.vy=OtherNewXi*snalpha+OtherVelEta*csalpha
    # ........... Travel with the new velocities for the rewound time
    other.x-=tc*other.vx
    other.y-=tc*other.vy
    self.x-=tc*self.vx
    self.y-=tc*self.vy
# .................................................. Button functions
def ReadData(*arg):
  """Request that the parameter entries be read into the simulation.

  Bound to the <Return> key of the root window; whatever arguments the
  caller passes (Tk supplies the event object) are ignored.  Only the
  global GetData flag is raised here: the animation loop reads the
  entries when it finds the flag set while the animation is not running.
  """
  global GetData
  GetData=True
#
def StartStop():
  """Toggle the animation between running and paused.

  Command of the start button: inverts the global RunIter flag and
  relabels the button with the action the next click will perform,
  "Stop" while the balls are moving and "Restart" while they are paused.
  """
  global RunIter
  RunIter=not RunIter
  StartButton["text"]="Stop" if RunIter else "Restart"
#
def StopAll():
  """Ask the animation loop to terminate.

  Command of the "Exit" button: clears the global RunAll flag, which the
  main ``while RunAll`` loop tests at the start of every iteration.  The
  root window is destroyed by the code that follows that loop, not here.
  """
  global RunAll
  RunAll=False
# ................................................ Create root window
root=Tk()
root.title('Class Collide')
# Pressing <Return> asks the animation loop to read the entry fields
root.bind('<Return>',ReadData)
# Closing the window through the window manager must end the animation
# loop the same way the Exit button does.  Without this handler Tk
# destroys the window on its own while the loop is still running, and
# the next canvas call fails with TclError.
root.protocol('WM_DELETE_WINDOW',StopAll)
# ..................................................... Drawing area
canvas=Canvas(root,width=cw,height=ch,background="#ffffff")
canvas.grid(row=0,column=0)
# .................................................... Create toolbar
# The toolbar sits to the right of the canvas, aligned with its top
toolbar=Frame(root)
toolbar.grid(row=0,column=1, sticky=N)
# ........................................................... Buttons
# nr is the next free row of the toolbar grid
nr=0
StartButton=Button(toolbar,text="Start",command=StartStop,
                   width=ButtWidth)
StartButton.grid(row=nr,column=0,sticky=W)
nr+=1
CloseButton=Button(toolbar, text="Exit", command=StopAll,
                   width=ButtWidth)
CloseButton.grid(row=nr,column=0,sticky=W)
nr+=1
# .................................................. Parameter arrays
# Names shown next to the entry fields, in the same order as params
ParList=['m\u2081','r\u2081','vx\u2081','vy\u2081','m\u2082','r\u2082',
         'vx\u2082','vy\u2082','\u03C4']
nPar=len(ParList)
# ....................... One label/entry pair per parameter and row
LabPar=[]
EntryPar=[]
for par_name in ParList:
  par_label=Label(toolbar,text=str(par_name),font=("Helvetica",12))
  par_label.grid(row=nr,column=0)
  LabPar.append(par_label)
  par_entry=Entry(toolbar,bd=5,width=10)
  par_entry.grid(row=nr,column=1)
  EntryPar.append(par_entry)
  nr+=1
# .................. Caption and value label for the measured period
CycleLab0=Label(toolbar,text="Period:",font=("Helvetica",11))
CycleLab0.grid(row=nr,column=0)
CycleLab=Label(toolbar,text="     ",font=("Helvetica",11))
CycleLab.grid(row=nr,column=1,sticky=W)
nr+=1
# ................... Show the start values in the entry fields
params=[m1,r1,vx1,vy1,m2,r2,vx2,vy2,tau]
for par_entry,start_value in zip(EntryPar,params):
  par_entry.delete(0,'end')
  par_entry.insert(0,"%.2f" % start_value)
# ............................................ Create colliding balls
ball1=ball(m1,r1,x1,y1,vx1,vy1,"red")
ball2=ball(m2,r2,x2,y2,vx2,vy2,"blue")
# ................................... Validation of typed parameters
def _param_ok(index,value):
  """Tell whether value may replace element number index of params.

  params holds m1,r1,vx1,vy1,m2,r2,vx2,vy2,tau in this order.  "nan"
  and "inf" are valid float syntax but would break the collision
  arithmetic and int(tau), so non-finite values are always refused.
  Masses and radii must be positive, and the delay tau non-negative.
  """
  if not abs(value)<float("inf"):
    return False
  if index in (0,1,4,5):
    return value>0.0
  if index==8:
    return value>=0.0
  return True
# ....................................................... Time origin
tt0=time.time()
tcount=0
# .................................................... Animation loop
while RunAll:
  StartIter=time.time()
  # ...................................................... Move balls
  if RunIter:
    ball1.move()
    ball1.bounce()
    ball2.move()
    ball2.bounce()
    ball1.ElastColl(ball2)
  elif GetData:
    # ............................. Read new parameters while paused
    for i,entry in enumerate(EntryPar):
      try:
        value=float(entry.get())
      except ValueError:
        continue          # not a number: keep the previous value
      if _param_ok(i,value):
        params[i]=value
    ball1.m,ball1.rad,ball1.vx,ball1.vy,\
      ball2.m,ball2.rad,ball2.vx,ball2.vy,tau=params
    tau=int(tau)
    # Write the accepted values back, so that refused input is visible
    for entry,value in zip(EntryPar,params):
      entry.delete(0,'end')
      entry.insert(0,"%.2f" % value)
    # Put the balls one step before their start corners: move() then
    # brings them exactly there and redraws them with the new radii
    ball1.x=ball1.rad-ball1.vx
    ball1.y=ball1.rad-ball1.vy
    ball2.x=cw-ball2.rad-ball2.vx
    ball2.y=ball2.rad-ball2.vy
    ball1.move()
    ball2.move()
    GetData=False
  # ................................................ Cycle duration
  tcount+=1
  if tcount==10:
    tcount=0
    ttt=time.time()
    elapsed=ttt-tt0
    # elapsed covers 10 iterations in seconds: elapsed/10*1000 ms each
    CycleLab['text']="%8.2f"%(elapsed*100.0)+" ms"
    tt0=ttt
  # Redraw and handle pending events before measuring the iteration,
  # so that the time spent drawing is also taken off the delay
  canvas.update()
  ElapsIter=int((time.time()-StartIter)*1000.0)
  # An iteration that took longer than tau must not ask for a
  # negative delay
  canvas.after(max(0,tau-ElapsIter))
  #------------------------------------------------------------------
root.destroy()
  
