# Simulator for CHP 
#
#  Depends on ply
#
# 0.1 - Stephen Longfield, 11 Dec, 2012
#       - Initial creation
# 0.2 - Stephen Longfield, 10 June 2014
#       - Adding a bunch of stuff, fixing a few things

import chp_parse
from chp_funs import *
import copy
import sys

# Class to hold the CHP and related functions
class CHP:
  '''A parsed CHP program together with the state used to simulate it.

  ``ast`` is the tuple-based syntax tree (``None`` once the program has
  terminated).  ``state`` maps variable and channel names to values.  A
  channel's entry moves through the handshake implemented by ``step``:

    False   idle
    True    the sender has raised a request
    <name>  a receiver acknowledged, naming the variable to write into
    "C"     a data-less receiver (``??``) acknowledged
    "DONE"  the sender finished; the receiver resets the entry to False
  '''

  # Print suffixes for communication nodes carrying an operand in ast[2] ...
  _DATA_COMM_OPS = {'SEND': '!', 'SEND`': '`!', 'RECV': '?', 'RECV`': '`?'}
  # ... and for data-less communication nodes.
  _BARE_COMM_OPS = {'SEND_C': '!!', 'SEND_C`': '`!!',
                    'RECV_C': '??', 'RECV_C`': '`??'}

  # Node tags handled by eval_expr (boolean-valued).  Anything else -- INT,
  # the PLUS variants and bare identifiers -- is an integer expression for
  # eval_iexpr.
  _BOOL_TAGS = ('GUARD_e', 'GUARD_p', 'GUARD', 'EXPR', 'CONJ', 'PRIM',
                'PRIM_p', 'BEQ')

  def __init__(self, ast):
    '''Store the AST; the simulation state starts empty.'''
    self.ast = ast
    self.state = {}

  def pretty_print(self, ast):
    '''Render an AST node (recursively) back into CHP source text.

    Plain-string leaves are either keywords ('ELSE', 'SKIP', 'TRUE',
    'FALSE') or identifiers; identifiers and unrecognised nodes are returned
    unchanged.
    '''
    if ast == 'ELSE':
      return "else"
    if ast == 'SKIP':
      return "skip"
    if ast == 'TRUE':
      return "true"
    if ast == 'FALSE':
      return "false"

    tag = ast[0]
    if tag == 'NEW':
      return "new:" + ast[1] + ".(" + self.pretty_print(ast[2]) + ")"
    if tag == 'SEQ':
      return "; ".join(self.pretty_print(s) for s in ast[1])
    if tag == 'PAR':
      return "\n||  ".join(self.pretty_print(s) for s in ast[1])
    if tag == 'REP':
      return "*[" + self.pretty_print(ast[1]) + "]"
    if tag == 'SELECT_ONE':
      # ast[1] is one guarded command laid out as (tag, guard, body), the
      # same layout step() and the SELECT_DET arms use.
      return ("[" + self.pretty_print(ast[1][1]) + " -> " +
              self.pretty_print(ast[1][2]) + "]")
    if tag in ('SELECT_DET', 'SELECT_UDET'):
      arms = [self.pretty_print(g[1]) + " -> " + self.pretty_print(g[2])
              for g in ast[1][1]]
      return "[" + " [] ".join(arms) + "]"
    if tag in ('COMM', 'GUARD_e'):
      return self.pretty_print(ast[1])
    if tag in self._DATA_COMM_OPS:
      return (str(ast[1]) + self._DATA_COMM_OPS[tag] +
              self.pretty_print(ast[2]))
    if tag in self._BARE_COMM_OPS:
      return str(ast[1]) + self._BARE_COMM_OPS[tag]
    if tag == 'GUARD_p':
      return '#' + ast[1]
    if tag == 'ASSIGN':
      return ast[1] + ":=" + self.pretty_print(ast[2])
    if tag in ('EXPR', 'CONJ'):
      if len(ast) == 2:
        return self.pretty_print(ast[1])
      # EXPR is evaluated as "and", CONJ as "or" (see eval_expr).
      op = "&&" if tag == 'EXPR' else "||"
      return self.pretty_print(ast[1]) + op + self.pretty_print(ast[2])
    if tag == 'PRIM':
      if len(ast) == 2:
        if isinstance(ast[1], bool):
          return str(ast[1])
        return ast[1]
      if len(ast) == 3:
        return "~" + self.pretty_print(ast[2])
    if tag == 'PRIM_p':
      return "(" + self.pretty_print(ast[1]) + ")"
    if tag == 'INT':
      return str(ast[1])
    if tag.startswith('PLUS'):
      return self.pretty_print(ast[1]) + "+" + self.pretty_print(ast[2])
    if tag == 'BEQ':
      return self.pretty_print(ast[1]) + "=" + self.pretty_print(ast[2])
    return ast

  def __str__(self):
    '''Pretty-print the AST, followed by the state if it is non-empty.'''
    me = self.pretty_print(self.ast) if self.ast is not None else ""
    if self.state:
      me += "\n\t" + str(self.state)
    return me

  def eval_expr(self, expr):
    '''Evaluate a boolean expression node against the current state.

    Returns None for nodes it does not recognise (integer expressions belong
    to eval_iexpr).  Raises KeyError for an unknown variable name.
    '''
    if expr == 'TRUE':
      return True
    if expr == 'FALSE':
      return False
    tag = expr[0]
    if tag in ('GUARD_e', 'PRIM_p'):
      return self.eval_expr(expr[1])
    if tag == 'GUARD_p':
      # Probe: the raw value currently stored under the name.
      return self.state[expr[1]]
    if tag == 'GUARD':
      return self.state[expr[1]] and self.eval_expr(expr[2])
    if tag == 'EXPR':
      if len(expr) == 2:
        return self.eval_expr(expr[1])
      return self.eval_expr(expr[1]) and self.eval_expr(expr[2])
    if tag == 'CONJ':
      if len(expr) == 2:
        return self.eval_expr(expr[1])
      return self.eval_expr(expr[1]) or self.eval_expr(expr[2])
    if tag == 'PRIM':
      if len(expr) == 2:
        # A literal boolean, otherwise a variable name.
        if isinstance(expr[1], bool):
          return expr[1]
        return self.state[expr[1]]
      return not self.eval_expr(expr[2])
    if tag == 'BEQ':
      return self.eval_iexpr(expr[1]) == self.eval_iexpr(expr[2])
    return None

  def eval_iexpr(self, expr):
    '''Evaluate an integer expression: an INT literal, a PLUS variant, or a
    bare identifier looked up in the state.'''
    tag = expr[0]
    if tag == "INT":
      return int(expr[1])
    # The PLUS_ID* variants mark which operands are bare identifiers; the
    # other operands are themselves integer expressions.
    if tag == "PLUS_ID0":
      return self.state[expr[1]] + self.eval_iexpr(expr[2])
    if tag == "PLUS_ID1":
      return self.eval_iexpr(expr[1]) + self.state[expr[2]]
    if tag == "PLUS_ID":
      return self.state[expr[1]] + self.state[expr[2]]
    if tag == "PLUS":
      return self.eval_iexpr(expr[1]) + self.eval_iexpr(expr[2])
    return self.state[expr]

  def _eval_value(self, expr):
    '''Evaluate a value expression with eval_expr or eval_iexpr by node kind.'''
    if expr in ('TRUE', 'FALSE') or expr[0] in self._BOOL_TAGS:
      return self.eval_expr(expr)
    return self.eval_iexpr(expr)

  def _step_par(self, ast):
    '''Round-robin step of ('PAR', branches, turn).

    Steps the branch whose turn it is, then advances the turn.  Always
    reports progress so the driver keeps scheduling the other branches even
    when the current one is blocked.
    '''
    branches = ast[1]
    turn = ast[2] if len(ast) > 2 else 0
    new_branch, progressed = self.step(branches[turn])
    if new_branch is None:
      # The branch terminated: drop it.  The branch that followed it has
      # shifted down to index `turn`, so it is next.
      remaining = list(branches)
      remaining.pop(turn)
      if not remaining:
        return (None, True)
      if len(remaining) == 1:
        return (remaining[0], True)
      return (('PAR', remaining, turn % len(remaining)), True)
    if progressed:
      branches = list(branches)
      branches[turn] = new_branch
    return (('PAR', branches, (turn + 1) % len(branches)), True)

  def _step_seq(self, ast):
    '''Step the head of ('SEQ', stmts); drop the head once it terminates.'''
    stmts = ast[1]
    if not stmts:
      return (None, True)
    new_head, progressed = self.step(stmts[0])
    rest = list(stmts[1:])
    if new_head is not None:
      return (('SEQ', [new_head] + rest), progressed)
    if not rest:
      return (None, progressed)
    if len(rest) == 1:
      return (rest[0], progressed)
    return (('SEQ', rest), progressed)

  def _step_select(self, ast, arms):
    '''Take the first arm (tag, guard, body) whose guard holds.

    An 'ELSE' guard is taken only when no other guard holds.  Determined and
    undetermined choice are treated the same (first true guard wins).
    Blocks, returning (ast, False), when no arm applies.
    '''
    else_bodies = []
    for arm in arms:
      guard, body = arm[1], arm[2]
      if guard == 'ELSE':
        else_bodies.append(body)
      elif self.eval_expr(guard):
        return (body, True)
    if else_bodies:
      return (else_bodies[0], True)
    return (ast, False)

  def step(self, ast):
    '''Apply one reduction rule to ``ast``.

    Returns ``(new_ast, progressed)``.  ``new_ast`` is None once the process
    has terminated.  ``progressed`` is False when no rule applied: ``ast`` is
    None, blocked on a channel or guard, or not a recognised node.  Variable
    and channel updates are written to ``self.state``.
    '''
    if ast is None:
      return (None, False)
    if ast == 'SKIP':
      return (None, True)

    tag = ast[0]
    if tag == 'NEW':
      # Declare the name (initially False / idle) and continue with the body.
      self.state[ast[1]] = False
      return (ast[2], True)
    if tag == 'PAR':
      return self._step_par(ast)
    if tag == 'REP':
      # *[S]  ->  S; *[S]
      return (('SEQ', [ast[1], ast]), True)
    if tag == 'SEQ':
      return self._step_seq(ast)
    if tag == 'COMM':
      return self.step(ast[1])

    if tag in ('SEND', 'SEND_C'):
      # Raise a request on a declared, idle channel; otherwise block.
      if self.state.get(ast[1]) is not False:
        return (ast, False)
      self.state[ast[1]] = True
      if tag == 'SEND':
        return (('SEND`', ast[1], ast[2]), True)
      return (('SEND_C`', ast[1]), True)
    if tag == 'SEND`':
      # Once the receiver has replaced the request with its variable name,
      # deliver the value into that variable and signal completion.
      if self.state[ast[1]] is True:
        return (ast, False)
      self.state[self.state[ast[1]]] = self._eval_value(ast[2])
      self.state[ast[1]] = "DONE"
      return ('SKIP', True)
    if tag == 'SEND_C`':
      if self.state[ast[1]] is True:
        return (ast, False)
      self.state[ast[1]] = "DONE"
      return ('SKIP', True)

    if tag in ('RECV', 'RECV_C'):
      # An undeclared channel is implicitly declared idle.  Acknowledge a
      # pending request, publishing the destination variable for RECV.
      if self.state.setdefault(ast[1], False) is not True:
        return (ast, False)
      if tag == 'RECV':
        self.state[ast[1]] = ast[2]
        return (('RECV`', ast[1], ast[2]), True)
      self.state[ast[1]] = "C"
      return (('RECV_C`', ast[1]), True)
    if tag in ('RECV`', 'RECV_C`'):
      # Wait for the sender to finish, then return the channel to idle.
      if self.state[ast[1]] == "DONE":
        self.state[ast[1]] = False
        return ('SKIP', True)
      return (ast, False)

    if tag == 'SELECT_ONE':
      return self._step_select(ast, [ast[1]])
    if tag in ('SELECT_DET', 'SELECT_UDET'):
      return self._step_select(ast, ast[1][1])

    if tag == 'ASSIGN':
      self.state[ast[1]] = self._eval_value(ast[2])
      return ('SKIP', True)
    return (ast, False)

# Command-line driver: parse a CHP file and single-step its execution
if __name__ == "__main__":
  # raw_input() became input() in Python 3 (Python 2's input() evaluates its
  # text), so pick whichever line reader exists.
  try:
    read_line = raw_input
  except NameError:
    read_line = input

  if len(sys.argv) != 2:
    sys.stderr.write(str(sys.argv) + "\n")
    sys.stderr.write("Usage: python chp_sim.py input.chp\n")
    sys.exit(1)

  # open() raises on failure rather than returning None; the with-block
  # also makes sure the file is closed.
  try:
    with open(sys.argv[-1], 'r') as f:
      source = f.read()
  except IOError as err:
    sys.stderr.write("Cannot open file: %s (%s)\n" % (sys.argv[-1], err))
    sys.exit(1)

  chp_tree = chp_parse.parser.parse(source)
  chp = CHP(chp_tree)

  sys.stdout.write(str(chp) + "\n")
  read_line()
  (chp.ast, _) = newRename(chp.ast)
  sys.stdout.write(str(chp) + "\n")

  # Step until no rule applies (or the program terminates), pausing for
  # Enter after each step.
  run = True
  while run:
    (chp.ast, run) = chp.step(chp.ast)
    sys.stdout.write(str(chp) + "\n")
    read_line()
