import csv
import sys
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from collections import *
import os

########################################################################################
##Function creates two dicts storing the service numbers between which linking is#######
##############################required between up and down services#####################

def link():
  """Group the collected link pairs by their first element.

  Every two-item entry ``[key, event]`` of the module-level ``link_upfile``
  is appended to ``link_uplist[key]``; ``link_downfile`` feeds
  ``link_downlist`` in the same way. Nothing is returned.
  """
  sources_and_targets = (
    (link_upfile, link_uplist),
    (link_downfile, link_downlist),
  )
  for pairs, grouped in sources_and_targets:
    for number, event in pairs:
      grouped[number].append(event)

########################################################################################
##Function creates two dictionaries with arrival and departure events at each terminal##
##station. These lists will be used along with the platform availability at these#######
##stations to decide linking of arrival and departure events.###########################
########################################################################################

def turnaround():
  """Group terminal arrival/departure events by station, skipping linked events.

  Reads the module-level ``turnaround_arr`` and ``turnaround_dep`` lists of
  ``[station, event]`` pairs and appends each event to
  ``turnaround_arr_list[station]`` / ``turnaround_dep_list[station]``, unless
  the event already appears in ``link_uplist`` or ``link_downlist``.
  Nothing is returned.
  """
  # The link dictionaries map a key to a *list* of events, so the events have
  # to be flattened before testing membership: comparing a single event
  # against the lists themselves never matches and would filter nothing.
  linked = set()
  for events in list(link_uplist.values()) + list(link_downlist.values()):
    linked.update(events)

  for station, event in turnaround_arr:
    if event not in linked:
      turnaround_arr_list[station].append(event)
  for station, event in turnaround_dep:
    if event not in linked:
      turnaround_dep_list[station].append(event)

#######################################################################################
##Function creates set containing arrival events of services between which dwell time##
##################################will be specified####################################
#######################################################################################

def dwell():
  """Group the collected dwell entries by their first element.

  Each ``[station, event]`` pair of the module-level ``dwell_up`` list is
  appended to ``dwell_list_up[station]``, and each pair of ``dwell_down`` to
  ``dwell_list_down[station]``. Nothing is returned.
  """
  for station, event in dwell_up:
    bucket = dwell_list_up[station]
    bucket.append(event)
  for station, event in dwell_down:
    bucket = dwell_list_down[station]
    bucket.append(event)

#######################################################################################
##Function creates traversal constraints between two consecutive stations. Four dicts##
##are created to hold the arrival and departure time of up and down trains. The########
##traversal time is stored in dict called odtrav#######################################
#######################################################################################

def traversal(tau,tdu,tad,tdd):
  """Group traversal events by their segment label.

  Each argument is an iterable of ``[label, event]`` pairs. The pairs are
  appended, in order, to the matching module-level dictionary:

  * ``tau`` -> ``traversal_arr_up_list``
  * ``tdu`` -> ``traversal_dep_up_list``
  * ``tad`` -> ``traversal_arr_down_list``
  * ``tdd`` -> ``traversal_dep_down_list``

  Nothing is returned.
  """
  routing = (
    (tau, traversal_arr_up_list),
    (tdu, traversal_dep_up_list),
    (tad, traversal_arr_down_list),
    (tdd, traversal_dep_down_list),
  )
  for pairs, grouped in routing:
    for label, event in pairs:
      grouped[label].append(event)
  
######################################################################################
##Function creates symmetry constraints among services plying between same origin-####
##destination pairs. Therefore lists containing keys among which symmetry/frequency###
##constraints are to be defined are made by this function#############################
######################################################################################

def frequency(frequp,freqdown):
  """Group first-departure events by their origin-destination label.

  ``frequp`` and ``freqdown`` are iterables of ``[label, event]`` pairs; each
  event is appended to ``frequency_list_up[label]`` or
  ``frequency_list_down[label]`` respectively. Nothing is returned.
  """
  for pairs, grouped in ((frequp, frequency_list_up), (freqdown, frequency_list_down)):
    for label, event in pairs:
      grouped[label].append(event)
    
######################################################################################
##Function  creates the services between which headway constraints are to be defined##
##If services have same two consecutive stations headways are to be put between the###
##departure times of these services.##################################################
######################################################################################

def headway(infoup,infodown):
  """Group departure events by the segment label they were recorded with.

  ``infoup`` and ``infodown`` are iterables of ``[label, event]`` pairs; each
  event is appended to ``headway_list_up[label]`` or
  ``headway_list_down[label]`` respectively. Nothing is returned.
  """
  for entry in infoup:
    label, event = entry
    headway_list_up[label].append(event)
  for entry in infodown:
    label, event = entry
    headway_list_down[label].append(event)
    
#####################################################################################
##Function writes lines in up.csv file. Apart from this, this function also creates##
##some lists using which lists of services for different types of constraints are####
##created in other functions such as headway,frequency etc.##########################
#####################################################################################

def upwrite(l,demand,sup,f):
  """Write ``demand`` rows of event keys to ``up`` and record constraint data.

  Parameters:
    l: flat list of 0/1 flags, two consecutive slots per station (the slot at
       an even index is the arrival column, the following odd index the
       departure column). Station numbers are derived from the slot index as
       ``index // 2 + 1``.
    demand: number of rows to write; every row consumes fresh keys.
    sup: label stored with the first departure of each row in ``frequencyup``.
    f: when equal to 1, the first departure of each row is also registered in
       ``link_upfile`` under a freshly incremented ``cntu``.

  Side effects: advances the module-level counters ``a`` (arrival keys),
  ``d`` (departure keys) and ``cntu``; appends to ``traversal_arr_up``,
  ``traversal_dep_up``, ``dwell_up``, ``turnaround_arr``, ``turnaround_dep``,
  ``frequencyup``, ``headwayinfoup`` and ``link_upfile``; writes one
  comma-terminated line per row to the open file ``up``.
  """
  global a
  global d
  global cntu
  for _ in range(demand):
    cells=[]
    count=0  # number of flagged slots already handled in this row
    for j,slot in enumerate(l):
      if slot==0:
        cells.append(',')
      elif slot==1 and count==0:
        # First flagged slot of the row: left empty, no key is issued for it.
        cells.append(',')
        count=count+1
      elif count%2==0:
        # Arrival slot. Floor division keeps the station numbers integral on
        # Python 3 as well, so the labels stay of the form "12", not "1.02.0".
        prevst=[x for x in range(1,j) if l[x]==1]
        here=str(j//2+1)
        deparr=str((prevst[-1]+1)//2)+here
        traversal_arr_up.append([deparr,a])
        if l[j+1]==1:
          # A departure follows at the same station.
          dwell_up.append([here,a])
        nextst=[x for x in range(j+1,len(l)) if l[x]==1]
        if not nextst:
          # No flagged slot after this one: last arrival of the row.
          turnaround_arr.append([here,a])
        cells.append(str(a)+",")
        a=a+1
        count=count+1
      else:
        # Departure slot.
        nextst=[x for x in range(j+1,len(l)) if l[x]==1]
        here=str((j+1)//2)
        if count==1:
          # First departure of the row.
          frequencyup.append([sup,d])
          turnaround_dep.append([here,d])
          if f==1:
            cntu=cntu+1
            link_upfile.append([cntu,d])
        deparr=here+str(nextst[0]//2+1)
        headwayinfoup.append([deparr,d])
        traversal_dep_up.append([deparr,d])
        cells.append(str(d)+",")
        d=d+1
        count=count+1
    up.write("".join(cells))
    up.write("\n")
  
######################################################################################
##Function writes lines in down.csv file. Apart from this, this function also creates#
##some lists using which lists of services for different types of constraints are#####
##created in other functions such as headway,frequency etc.###########################
######################################################################################


def downwrite(l,demand,sdown,f):
  """Write ``demand`` rows of event keys to ``down`` and record constraint data.

  Parameters:
    l: flat list of 0/1 flags, two consecutive slots per station (the slot at
       an even index is the arrival column, the following odd index the
       departure column), as built by the caller for the down file.
    demand: number of rows to write; every row consumes fresh keys.
    sdown: label stored with the first departure of each row in
       ``frequencydown``.
    f: when equal to 1, the last arrival of each row is also registered in
       ``link_downfile`` under a freshly incremented ``cntd``.

  Side effects: advances the module-level counters ``a`` (arrival keys),
  ``d`` (departure keys) and ``cntd``; appends to ``traversal_arr_down``,
  ``traversal_dep_down``, ``dwell_down``, ``turnaround_arr``,
  ``turnaround_dep``, ``frequencydown``, ``headwayinfodown`` and
  ``link_downfile``; writes one comma-terminated line per row to the open
  file ``down``.
  """
  global a
  global d
  global cntd
  width=len(l)
  for _ in range(demand):
    cells=[]
    count=0  # number of flagged slots already handled in this row
    for j,slot in enumerate(l):
      if slot==0:
        cells.append(',')
      elif slot==1 and count==0:
        # First flagged slot of the row: left empty, no key is issued for it.
        cells.append(',')
        count=count+1
      elif count%2==0:
        # Arrival slot. Floor division keeps the station numbers integral on
        # Python 3 as well, so the labels stay of the form "12", not "1.02.0".
        # NOTE(review): this label and the dwell station use index//2+1,
        # while the departure branch and the turnaround entry use
        # (len(l)-index)//2 -- confirm the two numbering schemes are intended.
        prevst=[x for x in range(1,j) if l[x]==1]
        deparr=str((prevst[-1]+1)//2)+str(j//2+1)
        traversal_arr_down.append([deparr,a])
        if l[j+1]==1:
          # A departure follows at the same station.
          dwell_down.append([str(j//2+1),a])
        nextst=[x for x in range(j+1,width) if l[x]==1]
        if not nextst:
          # No flagged slot after this one: last arrival of the row.
          turnaround_arr.append([str((width-j)//2),a])
          if f==1:
            cntd=cntd+1
            link_downfile.append([cntd,a])
        cells.append(str(a)+",")
        a=a+1
        count=count+1
      else:
        # Departure slot.
        nextst=[x for x in range(j+1,width) if l[x]==1]
        here=str((width-j+1)//2)
        if count==1:
          # First departure of the row.
          frequencydown.append([sdown,d])
          turnaround_dep.append([here,d])
        deparr=str((width-nextst[0])//2)+here
        headwayinfodown.append([deparr,d])
        traversal_dep_down.append([deparr,d])
        cells.append(str(d)+",")
        d=d+1
        count=count+1
    down.write("".join(cells))
    down.write("\n")
    
    
######################################################################
##Function creates the headers for the two files up.csv and down.csv##
######################################################################

def headers(station_name):
  """Write the header row of the ``up`` and ``down`` files.

  For every station two comma-terminated column titles are emitted: the
  station name followed by ``a`` and the station name followed by ``d``.
  ``up`` lists the stations in index order 0..n-1, ``down`` in the reverse
  order. Each header is terminated by a newline.
  """
  def titles(positions):
    # Two column titles per station, in the order given by ``positions``.
    row=[]
    for pos in positions:
      name=station_name[pos]
      row.extend((name+"a,",name+"d,"))
    return row

  last=len(station_name)-1
  up.write("".join(titles(range(last+1))))
  up.write("\n")
  down.write("".join(titles(range(last,-1,-1))))
  down.write("\n")

########################################################################
##This function writes the keys in up.csv and down.csv files for those##
##services which use a part of the up tracks and down tracks both.######
########################################################################

def updown(l):
  """Return the end points of the rising and the falling part of a sequence.

  Consecutive elements of ``l`` are compared pairwise. Both members of every
  increasing step are collected as "up" values and both members of every
  decreasing step as "down" values.

  Returns a 4-tuple ``(up_or, up_dest, down_or, down_dest)``: the smallest and
  largest "up" value, then the largest and smallest "down" value. An
  ``IndexError`` is raised when ``l`` has no increasing or no decreasing step.
  """
  rising,falling=set(),set()
  for here,there in zip(l,l[1:]):
    if here>there:
      falling.update((here,there))
    if here<there:
      rising.update((here,there))
  up_sorted=sorted(rising)
  down_sorted=sorted(falling,reverse=True)
  return up_sorted[0],up_sorted[-1],down_sorted[0],down_sorted[-1]

  
#####################################################################################  
##Function initiates creation of the entire key file. The graph function calls this##
##function with the graph and lists containing origin-destination lists and demands##
##This function finds paths between all the origin-destination pairs using inbuilt###
##shortest-path algorithm of NetworkX and calls the upwrite and downwrite functions##
########################for writing the keys#########################################
#####################################################################################

def keys(src,demand,des,gph,station_number):
  """Generate the key rows for every origin-destination pair with demand.

  Parameters:
    src, des: origin and destination station numbers, indexed 0..len(src)-1.
    demand: number of rows to write per pair; pairs with demand 0 are skipped.
    gph: NetworkX graph searched with ``nx.shortest_path``.
    station_number: station numbers in up order, indexed 0..n-1.

  For each pair a path is looked up. If its station numbers are neither in
  ascending nor in descending order, ``updown`` supplies the end points of an
  up leg and a down leg, and both legs are written with flag 1. Otherwise the
  path is written with flag 0 through ``upwrite`` when ``src < des`` and
  through ``downwrite`` when ``src > des``.

  After all pairs are processed the collected module-level lists are grouped
  by calling ``headway``, ``frequency``, ``traversal``, ``dwell``, ``link``
  and ``turnaround`` in that order.
  """
  def slots(positions,path):
    # Two 0/1 flags per station: [1,0] for the last stop of ``path``,
    # [1,1] for any other stop on it, [0,0] for stations not on it.
    flags=[]
    final=len(path)-1
    for pos in positions:
      station=station_number[pos]
      if station in path and path.index(station)==final:
        flags.extend((1,0))
      elif station in path:
        flags.extend((1,1))
      else:
        flags.extend((0,0))
    return flags

  for ind in range(len(src)):
    if demand[ind]==0:
      continue
    route=nx.shortest_path(gph,src[ind],des[ind])
    monotone=route==sorted(route,key=int) or route==sorted(route,key=int,reverse=True)
    flag=0
    if not monotone:
      flag=1
      up_start,up_end,down_start,down_end=updown(route)

    if flag==1 or src[ind]<des[ind]:
      if flag==1:
        route=nx.shortest_path(gph,up_start,up_end)
      else:
        route=nx.shortest_path(gph,src[ind],des[ind])
      odup=str(route[0])+str(route[-1])
      uplist=slots(range(len(station_number)),route)
      upwrite(uplist,demand[ind],odup,flag)

    if flag==1 or src[ind]>des[ind]:
      if flag==1:
        route=nx.shortest_path(gph,down_start,down_end)
      else:
        route=nx.shortest_path(gph,src[ind],des[ind])
      oddown=str(route[0])+str(route[-1])
      # The down file lists the stations in reverse order.
      downlist=slots(range(len(station_number)-1,-1,-1),route)
      downwrite(downlist,demand[ind],oddown,flag)

  headway(headwayinfoup,headwayinfodown)
  frequency(frequencyup,frequencydown)
  traversal(traversal_arr_up,traversal_dep_up,traversal_arr_down,traversal_dep_down)
  dwell()
  link()
  turnaround()
    
    
########################################################################################  
##Function draws the network graph; the graph that is drawn is currently undirected#####
##it takes the edges from the edge file and the nodes from the stations.csv file and####
###################creates an appropriate graph. Tool used is NetworkX##################
########################################################################################
    
def graph():
  """Build the station graph from the CSV inputs and generate the key files.

  Reads ``Edge.csv`` (columns ``From``, ``To``, ``Traversal_time``), ``OD.csv``
  (columns ``SourceSrNum``, ``DestinationSrNum``, ``Demand``) and
  ``Stations.csv`` (columns ``Number``, ``Station_name``, ``Junction``,
  ``Turnaround``) from the current directory.

  Side effects: fills the module-level ``st_numplatform`` and ``odtrav``
  dictionaries, saves a drawing of the graph to ``Networkup.png``, then calls
  ``headers`` and ``keys`` to write the up/down files.

  Returns the ``Station_name`` column.
  """
  gup=nx.Graph()
  edge,od,stations=pd.read_csv('Edge.csv'),pd.read_csv('OD.csv'),pd.read_csv('Stations.csv')
  source,dest,dem=od["SourceSrNum"],od["DestinationSrNum"],od["Demand"]
  From,to,trav_time=edge["From"],edge["To"],edge["Traversal_time"]
  # ``junc`` is read but not used in this function.
  stnum,stname,junc,platform=stations["Number"],stations["Station_name"],stations["Junction"],stations["Turnaround"]
  for ind in range(len(stnum)):
    gup.add_node(stnum[ind])
    st_numplatform[stnum[ind]]=platform[ind]
  # Only every second row of Edge.csv is used -- presumably the file lists
  # each link once per direction; confirm against the input data.
  for ind in range(0,len(edge),2):
    gup.add_edge(From[ind],to[ind],weight=trav_time[ind])
    # Use a dedicated name for the key instead of overwriting the ``od`` frame.
    segment=str(From[ind])+str(to[ind])
    odtrav[segment]=trav_time[ind]
  nx.draw(gup)
  plt.savefig('Networkup.png')
  # pyplot.hold was deprecated in Matplotlib 2.0 and removed in 3.0; call it
  # only where it still exists so that newer installations do not crash.
  if hasattr(plt,'hold'):
    plt.hold()
  headers(stname)
  keys(source,dem,dest,gup,stnum)
  return stname

####################################################################################  
######################Global variables and files are opened here####################
####################################################################################

# Running key counters: ``a`` is the next arrival key, ``d`` the next
# departure key; ``cntu`` / ``cntd`` number the entries of the link lists.
# NOTE(review): ``a`` starts at 100 and ``d`` at 1000, so after 900 arrival
# keys the two ranges overlap -- confirm the inputs stay below that.
a=100
d=1000
cntu=0
cntd=0

# Headway: raw [label, departure] pairs and their grouping by label.
headway_list_up=defaultdict(list)
headwayinfoup=[]
headwayinfodown=[]
headway_list_down=defaultdict(list)

# Frequency: raw [label, first departure] pairs and their grouping by label.
frequency_list_up=defaultdict(list)
frequency_list_down=defaultdict(list)
frequencyup=[]
frequencydown=[]

# Traversal: ``odtrav`` maps a segment label to its traversal time; the other
# containers hold raw [label, event] pairs and their grouping by label.
odtrav={}
traversal_arr_up=[]
traversal_dep_up=[]
traversal_arr_down=[]
traversal_dep_down=[]
traversal_arr_up_list=defaultdict(list)
traversal_dep_up_list=defaultdict(list)
traversal_arr_down_list=defaultdict(list)
traversal_dep_down_list=defaultdict(list)

# Dwell: raw [station, arrival] pairs and their grouping by station.
dwell_up=[]
dwell_down=[]
dwell_list_up=defaultdict(list)
dwell_list_down=defaultdict(list)

# Linking: raw [counter, event] pairs and their grouping by counter.
link_upfile=[]
link_downfile=[]
link_uplist=defaultdict(list)
link_downlist=defaultdict(list)

# Turnaround: raw [station, event] pairs, their grouping by station, and the
# per-station value of the ``Turnaround`` column of Stations.csv.
turnaround_arr_list=defaultdict(list)
turnaround_dep_list=defaultdict(list)
turnaround_arr=[]
turnaround_dep=[]
st_numplatform={}


# ``up`` and ``down`` stay module-level names because the writer functions
# use them directly; the context manager closes both files even when
# graph() raises.
with open('up.csv','w') as up, open('down.csv','w') as down:
  station_list = graph()
