#!/usr/bin/env python3
import os
from numpy import *
from tkinter import *
from random import *
from scipy.integrate import odeint
import time

# ............................................. subscripts for labels
sub=['\u2080','\u2081','\u2082','\u2083','\u2084','\u2085',\
  '\u2086','\u2087','\u2088','\u2089']
# .................................................. Global variables
RunAll=True
RunIter=False
NewBaryc=False
GetData=False
ReWrite=False
scale=1.0
# ..................................................... Class CelBody
class CelBody:
  """A disc-shaped body: mass, radius, friction, position, velocity, colour.

  Besides the state passed to the constructor, every body keeps a wake
  (wakex/wakey, the positions recorded by previous draw() calls) and a
  `collide` counter that is positive while a blast should still be shown.
  Drawing reads the module-level `scale` factor.
  """
  def __init__(self,mass,radius,frict,x,y,vx,vy,color):
    self.m,self.rad,self.fr=mass,radius,frict
    self.x,self.y=x,y
    self.vx,self.vy=vx,vy
    self.col=color
    # the wake starts with the initial position only
    self.wakex=[x]
    self.wakey=[y]
    self.collide=0
  # ..................................................... move body
  def move(self):
    """Advance the position by one velocity increment."""
    self.x+=self.vx
    self.y+=self.vy
  # ......................................................... blast
  def blast(self,canvas,orig):
    """Draw a yellow disc around the body while `collide` is positive.

    Each drawn blast decrements `collide`; nothing happens at zero or below.
    orig is the [x,y] canvas position of the origin; the y axis points up.
    """
    if self.collide>0:
      px=scale*self.x
      py=scale*self.y
      left=orig[0]+int(px-100)
      lower=int(orig[1]-py+100)
      right=int(orig[0]+px+100)
      upper=int(orig[1]-py-100)
      canvas.create_oval(left,lower,right,upper,\
        outline='yellow',fill='yellow')
      self.collide-=1
  # ............................................ draw body and wake
  def draw(self,canvas,orig,nWake):
    """Draw the body and its wake, then record the current position.

    The disc radius is never drawn smaller than 4.  After the current
    position is appended, the wake is trimmed from the oldest end until
    it holds at most nWake points.
    """
    ox,oy=orig[0],orig[1]
    px=scale*self.x
    py=scale*self.y
    half=scale*self.rad
    if half<4.0:
      half=4.0
    canvas.create_oval(ox+int(px-half),int(oy-py+half),\
      int(ox+px+half),int(oy-py-half),\
        outline=self.col,fill=self.col)
    # one segment between each pair of consecutive wake points
    for k in range(1,len(self.wakex)):
      canvas.create_line(ox+scale*self.wakex[k-1],\
        oy-scale*self.wakey[k-1],ox+scale*self.wakex[k],\
          oy-scale*self.wakey[k],fill=self.col)
    self.wakex.append(self.x)
    self.wakey.append(self.y)
    while len(self.wakex)>nWake:
      del self.wakex[0]
      del self.wakey[0]
# .................................................... create bodies
bd=[
  CelBody(mass=200.0,radius=10.0,frict=0.01,\
    x=0.0,y=0.0,vx=0.0,vy=0.0,color='red'),
  CelBody(mass=10.0,radius=6.0,frict=0.01,\
    x=300.0,y=0.0,vx=0.0,vy=40.0,color='blue'),
  CelBody(mass=5.0,radius=4.0,frict=0.01,\
    x=0.0,y=200.0,vx=50.0,vy=0.0,color='green'),
]
nB=len(bd)
# ...................................... masses and friction coefficients
# kept as separate lists because they are shared with the integrator
# through `params` below
masses=[body.m for body in bd]
frictions=[body.fr for body in bd]
# ............................................ finite time difference
dt=0.05
# ............................................ gravitational constant
KG=5000.0
# ................................................. cycle duration/ms
cycle=20
# ....................................................... tail length
nWake=200
# ................................................. barycenter radius
bcrad=2
# ........................................................ value list
values=[KG,dt,cycle,nWake]
# ........................................... integration time interval
t=[0.0,dt]
# ................................................... parameter array
params=[values,masses,frictions]
# ........................................... evaluate center of mass
def baryc(bd):
  """Return the mass-weighted mean position and velocity of the bodies.

  bd is a sequence of objects exposing m, x, y, vx and vy.  The result
  is [[cx,cy],[cvx,cvy]].  A total mass of zero (for instance an empty
  sequence) is not guarded against and makes the division fail.
  """
  mtot=0
  sx=sy=svx=svy=0
  for body in bd:
    sx+=body.m*body.x
    sy+=body.m*body.y
    svx+=body.m*body.vx
    svy+=body.m*body.vy
    mtot+=body.m
  return [[sx/mtot,sy/mtot],[svx/mtot,svy/mtot]]
# ..................................... move origin to center of mass
def SetBaryc():
  """Shift all bodies into the center-of-mass frame.

  The barycenter position and velocity returned by baryc() are
  subtracted from every body of the global list bd, and the NewBaryc
  flag is raised.  The wake trails are not touched here.
  """
  global NewBaryc
  (xcm,ycm),(cvx,cvy)=baryc(bd)
  for body in bd:
    body.x-=xcm
    body.y-=ycm
    body.vx-=cvx
    body.vy-=cvy
  NewBaryc=True
# ......................................................... bodies2vect
#   x1,y1,x2,y2, ...,vx1,vy1,vx2,vy2, ...=vect
def bodies2vect(bodies,vect,nB):
  """Copy the state of the first nB bodies into the flat vector vect.

  Layout: positions first (x, y for each body), then velocities
  (vx, vy for each body) starting at index 2*nB.  vect is modified in
  place and must already hold at least 4*nB elements.
  """
  voff=2*nB
  for k in range(nB):
    body=bodies[k]
    vect[2*k]=body.x
    vect[2*k+1]=body.y
    vect[voff+2*k]=body.vx
    vect[voff+2*k+1]=body.vy
# ......................................................... vect2bodies
def vect2bodies(vect,bodies,nB):
  """Copy the flat state vector vect back into the first nB bodies.

  Inverse of bodies2vect: positions are read from the first 2*nB
  elements (x, y for each body), velocities from the following 2*nB.
  """
  voff=2*nB
  for k in range(nB):
    body=bodies[k]
    body.x=vect[2*k]
    body.y=vect[2*k+1]
    body.vx=vect[voff+2*k]
    body.vy=vect[voff+2*k+1]
# ............................................... create input vector
def WriteInput(bodies,val,InpV):
  """Rebuild the flat input vector InpV from the bodies and the value list.

  InpV is emptied and refilled in place (the list object is kept, so
  other references to it stay valid).  Layout: six entries per body
  (m, fr, x, y, vx, vy), followed by every element of val.

  The number of trailing values is taken from the val argument itself,
  not from the module-level list `values`.
  """
  del InpV[:]
  for body in bodies:
    InpV.extend((body.m,body.fr,body.x,body.y,body.vx,body.vy))
  InpV.extend(val)
# ................................................. read input vector
def ReadInput(InpV,bodies,masses,frictions,val,vect):
  """Distribute the flat input vector InpV over all simulation containers.

  InpV holds six entries per body (m, fr, x, y, vx, vy) followed by
  len(val) further values.  Body attributes, the masses and frictions
  lists, the state vector vect (x, y per body, then vx, vy per body)
  and val are all overwritten in place.
  """
  count=len(bodies)
  voff=2*count
  for k in range(count):
    body=bodies[k]
    base=6*k
    mass=InpV[base]
    body.m=mass
    masses[k]=mass
    friction=InpV[base+1]
    body.fr=friction
    frictions[k]=friction
    px=InpV[base+2]
    body.x=px
    vect[2*k]=px
    py=InpV[base+3]
    body.y=py
    vect[2*k+1]=py
    pvx=InpV[base+4]
    body.vx=pvx
    vect[voff+2*k]=pvx
    pvy=InpV[base+5]
    body.vy=pvy
    vect[voff+2*k+1]=pvy
  # the trailing entries refresh the value list
  tail=6*count
  for j in range(len(val)):
    val[j]=InpV[tail+j]
# ................................................. Elastic collision
def ElastColl(bd1,bd2):
  """Apply an elastic collision between two bodies if they touch.

  Returns False when the centers are farther apart than the sum of the
  radii, or when bd1.collide is non-zero (bd1 is still flagged from an
  earlier collision).  Otherwise both `collide` counters are set to 4,
  the velocity components along the line of centers are exchanged
  according to the 1-D elastic collision formulas, the perpendicular
  components are kept, and True is returned.
  """
  deltax=bd2.x-bd1.x
  deltay=bd2.y-bd1.y
  # .................................................... out of range
  if deltax**2+deltay**2>(bd1.rad+bd2.rad)**2:
    return False
  # ................................. collision already being handled
  if bd1.collide!=0:
    return False
  bd1.collide=4
  bd2.collide=4
  # rotate into the frame whose first axis joins the two centers
  alpha=arctan2(deltay,deltax)
  cs=cos(alpha)
  sn=sin(alpha)
  along1=bd1.vx*cs+bd1.vy*sn
  across1=-bd1.vx*sn+bd1.vy*cs
  along2=bd2.vx*cs+bd2.vy*sn
  across2=-bd2.vx*sn+bd2.vy*cs
  # only the components along the line of centers change
  msum=bd1.m+bd2.m
  new1=((bd1.m-bd2.m)*along1+2.0*bd2.m*along2)/msum
  new2=((bd2.m-bd1.m)*along2+2.0*bd1.m*along1)/msum
  # rotate back to the canvas frame
  bd1.vx=new1*cs-across1*sn
  bd1.vy=new1*sn+across1*cs
  bd2.vx=new2*cs-across2*sn
  bd2.vy=new2*sn+across2*cs
  return True
# ............................................... Start/Stop function
def StartStop():
  """Toggle the RunIter flag and relabel the start button to match.

  The button reads "Stop" while RunIter is set and "Restart" otherwise.
  """
  global RunIter
  RunIter=not RunIter
  StartButton["text"]="Stop" if RunIter else "Restart"
# ..................................................... Exit function
def StopAll():
  """Clear the RunAll flag; nothing else is done here."""
  global RunAll
  RunAll=False
# ................................................ Read Data function
def ReadData(*arg):
  """Raise the GetData flag.

  Any positional arguments (such as the event object passed by a key
  binding) are accepted and ignored.
  """
  global GetData
  GetData=True
# .......................................................... Scale Up
def ScaleUp(*arg):
  """Double the global drawing scale and show it in the scale label.

  Positional arguments (e.g. a key-binding event) are ignored.
  """
  global scale
  scale=scale*2.0
  ScaleLab['text']="%8.3f"%scale
# ........................................................ Scale Down
def ScaleDown(*arg):
  """Halve the global drawing scale and show it in the scale label.

  Positional arguments (e.g. a key-binding event) are ignored.
  """
  global scale
  scale=scale/2.0
  ScaleLab['text']="%8.3f"%scale
# ....................................................... root window
root=Tk()
root.title('Gravitation with friction')
# keyboard shortcuts: Return reads the entries, Ctrl +/- rescale the view
for sequence,handler in (('<Return>',ReadData),\
  ('<Control-plus>',ScaleUp),('<Control-minus>',ScaleDown)):
  root.bind(sequence,handler)
# ............................................................ canvas
cw=ch=800
canvas=Canvas(root,width=cw,height=ch,background="#ffffff")
canvas.grid(row=0,column=0)
# ........................................................... toolbar
toolbar=Frame(root)
toolbar.grid(row=0,column=1,sticky=N)
# ........................................................... buttons
StartButton=Button(toolbar,text="Start",command=StartStop,width=11)
AdjustButton=Button(toolbar,text="Set Barycenter",\
  command=SetBaryc,width=11)
ReadButton=Button(toolbar,text="Read Data",command=ReadData,width=11)
CloseButton=Button(toolbar,text="Exit",command=StopAll,width=11)
# one toolbar row per button; nr ends up as the next free row
nr=0
for button,extra in ((StartButton,{}),(AdjustButton,{'columnspan':2}),\
  (ReadButton,{'columnspan':2}),(CloseButton,{})):
  button.grid(row=nr,column=0,sticky=W,**extra)
  nr+=1
# ............................................................ Origin
Ox,Oy=cw/2,ch/2
orig=[Ox,Oy]
# ...................................................... input vector
InpV=[]
WriteInput(bd,values,InpV)
nInp=len(InpV)
# ..................................................... odeint vector
# x,y of every body followed by vx,vy of every body
y=[0]*(4*nB)
bodies2vect(bd,y,nB)
# ........................................................ Input list
# captions of the entry fields, in the same order as InpV
InputList=[]
for k in range(nB):
  for stem in ('m','\u03B7','x','y','vx','vy'):
    InputList.append(stem+sub[k])
InputList.extend(['G','dt','Cycle/ms','Tail'])
# ................................................ Labels and Entries
LabInput=[]
EntryInput=[]
# one toolbar row (caption + entry field) per element of InpV
for k in range(nInp):
  caption=Label(toolbar,text=InputList[k],font=("Helvetica",12))
  caption.grid(row=nr,column=0)
  LabInput.append(caption)
  field=Entry(toolbar,bd=5,width=8)
  field.grid(row=nr,column=1)
  field.insert(0,"{:.3f}".format(InpV[k]))
  EntryInput.append(field)
  nr+=1
# ........................................................ time label
smallfont=("Helvetica",11)
CycleLab0=Label(toolbar,text="Period:",font=smallfont)
CycleLab=Label(toolbar,text="     ",font=smallfont)
CycleLab0.grid(row=nr,column=0)
CycleLab.grid(row=nr,column=1,sticky=W)
nr+=1
# ....................................................... scale label
ScaleLab0=Label(toolbar,text="Scale:",font=smallfont)
ScaleLab=Label(toolbar,text="%8.3f"%(scale),font=smallfont)
ScaleLab0.grid(row=nr,column=0)
ScaleLab.grid(row=nr,column=1,sticky=W)
nr+=1
# ............................... right-hand side of the ODE system
def fun(yInp,t,params):
  """Return the time derivative of the state vector yInp.

  yInp holds x, y for every body followed by vx, vy for every body.
  params is [values,masses,frictions]; only the first element of
  values (the gravitational constant) is used here.  t is not used.

  The result is a list with the velocities first, followed by the
  accelerations: the pairwise attraction KG*m/r**2 between all bodies
  minus a friction term proportional to each body's own velocity.
  Coincident bodies (r == 0) are not guarded against.
  """
  KG,_dt,_cycle,_nWake=params[0]          # only KG is needed
  mm,fric=params[1],params[2]
  nB=len(mm)
  voff=2*nB                               # offset of the velocities
  ax=[0]*nB
  ay=[0]*nB
  # ................................ visit every unordered pair once
  for i in range(1,nB):
    for j in range(i):
      deltax=yInp[2*i]-yInp[2*j]
      deltay=yInp[2*i+1]-yInp[2*j+1]
      r2=deltax**2+deltay**2
      alpha=arctan2(deltay,deltax)
      ff=KG/r2
      fx=ff*cos(alpha)
      fy=ff*sin(alpha)
      # equal and opposite pulls, each scaled by the other body's mass
      ax[i]-=fx*mm[j]
      ax[j]+=fx*mm[i]
      ay[i]-=fy*mm[j]
      ay[j]+=fy*mm[i]
  # ............................................... insert velocities
  derivs=[yInp[k] for k in range(voff,2*voff)]
  # ............................................ insert accelerations
  for i in range(nB):
    derivs.append(ax[i]-fric[i]*yInp[voff+2*i])
    derivs.append(ay[i]-fric[i]*yInp[voff+2*i+1])
  return derivs
# ........................................... numerical time interval
# Integrate over the dt shown in the entry panel (values[1]) right from
# the first step; a hard-coded 0.1 would disagree with the displayed dt
# until the first rewrite of the inputs.
t=[0.0,values[1]]
# Closing the window through the window manager only requests the end of
# the loop, so no canvas call is made on an already destroyed window.
root.protocol("WM_DELETE_WINDOW",StopAll)
tt0=time.time()
tcount=0
# .......................................................... main loop
# One pass = one frame: redraw everything, then either advance the
# simulation (RunIter set) or process pending input requests (paused).
# NOTE(review): SetBaryc pressed while running changes the bodies but not
# the state vector y, so the next integration step overwrites the shift;
# it only takes effect while paused.
while RunAll:
  StartIter=time.time()
  # ........................................ erase the previous frame
  canvas.delete(ALL)
  # ............................................. draw blasts, if any
  for body in bd:
    body.blast(canvas,orig)
  # ................................................ draw coordinates
  canvas.create_line(0,Oy,cw,Oy,fill="black")
  canvas.create_line(Ox,0,Ox,ch,fill="black")
  canvas.create_oval(Ox-bcrad,Oy-bcrad,Ox+bcrad,Oy+bcrad,fill="black")
  # ..................................................... draw bodies
  for body in bd:
    body.draw(canvas,orig,nWake)
  # .................................................. center of mass
  cx,cy=baryc(bd)[0]
  cx*=scale
  cy*=scale
  canvas.create_oval(Ox+cx-bcrad,Oy-cy-bcrad,Ox+cx+bcrad,Oy-cy+bcrad,\
    fill="black")
  # .......................................................... motion
  if RunIter:
    # ..................................................... next step
    psoln=odeint(fun,y,t,args=(params,))
    y=psoln[1,:]
    vect2bodies(y,bd,nB)
    # ............................................ check if colliding
    # ElastColl must be called for every pair: writing
    # `coll or ElastColl(...)` would short-circuit and skip all
    # remaining pairs once the first collision was found.
    coll=False
    for i in range(1,nB):
      for j in range(i):
        if ElastColl(bd[i],bd[j]):
          coll=True
    if coll:
      # velocities changed: carry them over into the state vector
      bodies2vect(bd,y,nB)
  # .................................................................
  else:
    if NewBaryc:
      # bodies were shifted by SetBaryc: refresh the input vector
      ReWrite=True
      WriteInput(bd,values,InpV)
      NewBaryc=False
    elif GetData:
      # read the entry fields; a field that is not a number keeps its
      # previous value and is overwritten with it below
      for i in range(nInp):
        try:
          InpV[i]=float(EntryInput[i].get())
        except ValueError:
          pass
      ReWrite=True
      GetData=False
    if ReWrite:
      # push the input vector into bodies, parameters and state vector
      ReadInput(InpV,bd,masses,frictions,values,y)
      for i in range(nInp):
        EntryInput[i].delete(0,'end')
        EntryInput[i].insert(0,"{:.3f}".format(InpV[i]))
      t=[0.0,values[1]]
      cycle=int(values[2])
      nWake=int(values[3])
      # restart every wake at the body's current position
      for body in bd:
        body.wakex=[body.x]*nWake
        body.wakey=[body.y]*nWake
      ReWrite=False
  # ................................................ cycle duration
  # every 10 frames show the mean frame period: elapsed seconds for
  # 10 frames, times 100, is milliseconds per frame
  tcount+=1
  if tcount==10:
    tcount=0
    ttt=time.time()
    elapsed=ttt-tt0
    CycleLab['text']="%8.3f"%(elapsed*100.0)+" ms"
    tt0=ttt
  ElapsIter=int((time.time()-StartIter)*1000.0)
  # process pending GUI events, then wait out the rest of the cycle
  canvas.update()
  canvas.after(cycle-ElapsIter)
#----------------------------------------------------------------------
# The loop above is the event loop; once it ends the window is simply
# destroyed (no mainloop() call is needed afterwards).
root.destroy()
  