# CHP functions
#
# 0.1 - Stephen Longfield, 25 Jun 2014
#       - defined collectProbed
# 0.2 - Stephen Longfield, 30 Jun 2014
#       - adding in newRename
# 0.3 - Stephen Longfield, July and Aug 2014
#       - Adding in more analyses to help with the CHP to Promela translation
#       -  collectProbed
#       -  collectShared

import copy

def newRenameExpr(expr, alias):
	'''Renames channels in an expression.

	Channels cannot currently occur inside expressions, so expr is handed
	back as-is together with the untouched alias map.'''
	unchanged = expr
	return (unchanged, alias)

def newRenameComm(comm, alias):
	'''Renames the channel of a communication statement.

	comm is one of ("SEND", chan, expr), ("RECV", chan, var),
	("SEND_C", chan) or ("RECV_C", chan). alias maps a channel name to a
	(base, count) pair; a nonzero count means the channel is now spelled
	base + str(count). A channel seen for the first time is recorded in
	alias with count 0 and left unchanged.

	Returns (renamed_comm, alias); alias is updated in place.
	Raises NotImplementedError for any other communication type.'''
	kind = comm[0]
	if kind not in ("SEND", "RECV", "SEND_C", "RECV_C"):
		print("Unsupported COMM type  " + str(kind))
		raise NotImplementedError("Unsupported COMM type")
	chan = comm[1]
	if chan not in alias:
		# First sighting: register the channel under its own name.
		alias[chan] = (chan, 0)
		return (comm, alias)
	entry = alias[chan]
	if entry[1] == 0:
		return (comm, alias)
	renamed = entry[0] + str(entry[1])
	if kind in ("SEND", "RECV"):
		# Data communications carry a third field (expression / variable).
		return ((kind, renamed, comm[2]), alias)
	return ((kind, renamed), alias)

def newRenameGuard(g, alias):
	'''Renames the probed channels in a guard.

	Guards are "ELSE", ("GUARD_e", expr), ("GUARD_p", chan) or
	("GUARD", chan, rest), where rest is a further guard or a bare
	("EXPR", ...) ending the chain. alias maps a channel name to a
	(base, count) pair; a probed channel with a nonzero count is renamed to
	base + str(count), and a channel seen for the first time is recorded
	with count 0 and left unchanged.

	Returns (renamed_guard, alias); alias is updated in place.
	Raises NotImplementedError for an unrecognised guard.'''
	def current(chan):
		# Current spelling of chan; registers it (count 0) on first sight.
		if chan not in alias:
			alias[chan] = (chan, 0)
			return chan
		entry = alias[chan]
		if entry[1] == 0:
			return chan
		return entry[0] + str(entry[1])

	if g == "ELSE" or g[0] in ("GUARD_e", "EXPR"):
		# Nothing is probed here, and expressions contain no channels.
		return (g, alias)
	if g[0] == "GUARD_p":
		name = current(g[1])
		if name == g[1]:
			return (g, alias)
		return ((g[0], name), alias)
	if g[0] == "GUARD":
		# Rename the rest of the chain first, then this probe.
		(rest, alias) = newRenameGuard(g[2], alias)
		return ((g[0], current(g[1]), rest), alias)
	raise NotImplementedError("Unexpected guard type: " + str(g[0]))

def newRenameGC(gc, alias):
	'''Renames channels in a guarded command ("GC", guard, body).

	The guard is renamed first, then the body, threading alias through
	both. Returns (("GC", guard, body), alias).'''
	assert gc[0] == "GC"
	renamedGuard, alias = newRenameGuard(gc[1], alias)
	renamedBody, alias = newRename(gc[2], alias)
	result = ("GC", renamedGuard, renamedBody)
	return result, alias

def newRename(program, alias=None):
	'''Renames channels so that repeated NEW declarations of one name
	become distinct.

	alias maps a channel name to (base, count). Each NEW of a name bumps
	its count; the first declaration keeps the name, later ones are spelled
	base + str(count), both at the NEW itself and in the communications
	and guard probes renamed beneath it.

	program -- the program tree to rename.
	alias   -- optional map to continue from (updated in place). A fresh
	           map is used when omitted, so separate calls never share state.

	Returns (renamed_program, alias).
	Raises NotImplementedError for an unrecognised program node.'''
	if alias is None:
		alias = {}
	tag = program[0]
	if tag == "NEW":
		name = program[1]
		if name in alias:
			entry = alias[name]
			alias[name] = (entry[0], entry[1] + 1)
		else:
			alias[name] = (name, 0)
		entry = alias[name]
		# Rename the binder too, so it matches the renamed uses below it.
		if entry[1] == 0:
			declared = name
		else:
			declared = entry[0] + str(entry[1])
		# NOTE(review): alias is never restored when this scope ends, so
		# references to `name` after this NEW also see the bumped count.
		(body, alias) = newRename(program[2], alias)
		return (("NEW", declared, body), alias)
	elif tag == "REP":
		(body, alias) = newRename(program[1], alias)
		return (("REP", body), alias)
	elif tag in ["PAR", "SEQ"]:
		children = []
		for child in program[1]:
			# Thread alias through every child so declarations seen in one
			# child are visible to the next.
			(renamed, alias) = newRename(child, alias)
			children.append(renamed)
		if tag == "PAR":
			return ((tag, children, program[2]), alias)
		return ((tag, children), alias)
	elif tag == "SELECT_ONE":
		(gc, alias) = newRenameGC(program[1], alias)
		return ((tag, gc), alias)
	elif tag in ["SELECT_DET", "SELECT_UDET"]:
		gcs = program[1]
		renamedGCs = []
		for gc in gcs[1]:
			(renamed, alias) = newRenameGC(gc, alias)
			renamedGCs.append(renamed)
		return ((tag, (gcs[0], renamedGCs)), alias)
	elif tag == "COMM":
		(comm, alias) = newRenameComm(program[1], alias)
		return ((tag, comm), alias)
	elif program == "SKIP":
		return (program, alias)
	print("Not implemented: " + str(tag))
	raise NotImplementedError("Tried to rename an unknown program type")

def collectChansGuard(guard, chans):
	'''Returns chans with the channels probed by guard placed in front.

	"ELSE", ("GUARD_e", expr) and a bare ("EXPR", ...) ending a GUARD chain
	probe nothing. ("GUARD_p", chan) probes chan; ("GUARD", chan, rest)
	probes chan and everything rest probes (outermost first). chans itself
	is never modified.'''
	if guard == "ELSE" or guard[0] in ("GUARD_e", "EXPR"):
		return chans
	if guard[0] == "GUARD_p":
		return [guard[1]] + chans
	assert guard[0] == "GUARD"
	# Thread chans through the rest of the chain so nothing is lost.
	return [guard[1]] + collectChansGuard(guard[2], chans)

def collectChansRec(program, chans):
	'''Recursive step for collectChans: gathers every channel that program
	communicates on or probes in a selection guard.

	Duplicates are kept. Communications are appended to chans in place,
	while guard probes go through collectChansGuard, which returns a new
	list, so callers must use the returned list.
	Raises NotImplementedError for an unrecognised program node.'''
	tag = program[0]
	if tag == "NEW":
		chans = collectChansRec(program[2], chans)
	elif tag == "REP":
		chans = collectChansRec(program[1], chans)
	elif tag in ["SEQ", "PAR"]:
		for p in program[1]:
			chans = collectChansRec(p, chans)
	elif tag in ["SELECT_ONE", "SELECT_DET", "SELECT_UDET"]:
		# A lone guarded command is handled like a one-branch selection, so
		# probes in its guard count as channel uses too.
		if tag == "SELECT_ONE":
			gcs = [program[1]]
		else:
			gcs = program[1][1]
		for gc in gcs:
			chans = collectChansGuard(gc[1], chans)
			chans = collectChansRec(gc[2], chans)
	elif tag == "COMM":
		chans.append(program[1][1])
	elif program == "SKIP":
		pass
	else:
		raise NotImplementedError("Unexpected Program: " + str(program))
	return chans

def collectChans(program):
	'''Returns a list of the channel names used by program, as gathered by
	collectChansRec starting from an empty list. Duplicates are kept.'''
	result = collectChansRec(program, list())
	return result

def collectSharedComm(comm, shared, send, recv, sp, rp):
	'''Checks one communication action for sharing.

	A SEND/SEND_C on a channel already in send (sent on by a parallel
	process) marks the channel shared; likewise a RECV/RECV_C on a channel
	in recv. Either way the channel is recorded in sp (sends) or rp
	(receives) for the current process.

	Returns (shared, sp, rp); all three sets are updated in place.
	Raises NotImplementedError for an unknown communication type.'''
	kind = comm[0]
	if kind in ["SEND", "SEND_C"]:
		if comm[1] in send:
			shared.add(comm[1])
		sp.add(comm[1])
	elif kind in ["RECV", "RECV_C"]:
		if comm[1] in recv:
			shared.add(comm[1])
		rp.add(comm[1])
	else:
		raise NotImplementedError("Unsupported COMM type " + str(kind))
	return (shared, sp, rp)

def collectSharedRec(program, shared, send, recv, sp, rp):
	'''Recursive step for collectShared.

	program -- the program to scan.
	shared  -- channels found to be shared so far (updated in place).
	send    -- channels sent on by processes running in parallel with
	           program that have been seen so far (not modified).
	recv    -- likewise for receives (not modified).
	sp, rp  -- channels sent / received by program's own process so far
	           (updated in place).

	A channel is shared when two parallel processes both send on it, or
	both receive on it.
	Returns (shared, sp, rp).
	Raises NotImplementedError for an unrecognised program node.'''
	tag = program[0]
	if tag == "NEW":
		(shared, sp, rp) = collectSharedRec(program[2], shared, send, recv, sp, rp)
	elif tag == "REP":
		(shared, sp, rp) = collectSharedRec(program[1], shared, send, recv, sp, rp)
	elif tag == "SEQ":
		for p in program[1]:
			(shared, sp, rp) = collectSharedRec(p, shared, send, recv, sp, rp)
	elif tag == "PAR":
		# Each branch runs in parallel with the enclosing context and with
		# every earlier sibling, so later branches are checked against the
		# context plus everything earlier siblings sent/received.
		parSend = set(send)
		parRecv = set(recv)
		for branch in program[1]:
			(shared, branchSend, branchRecv) = collectSharedRec(
				branch, shared, parSend, parRecv, set(), set())
			parSend |= branchSend
			parRecv |= branchRecv
			# The whole PAR belongs to the current process from the
			# enclosing point of view.
			sp |= branchSend
			rp |= branchRecv
	elif tag == "SELECT_ONE":
		(shared, sp, rp) = collectSharedRec(program[1][2], shared, send, recv, sp, rp)
	elif tag in ["SELECT_DET", "SELECT_UDET"]:
		# Probes in guards are not considered for sharing.
		for gc in program[1][1]:
			(shared, sp, rp) = collectSharedRec(gc[2], shared, send, recv, sp, rp)
	elif tag == "COMM":
		(shared, sp, rp) = collectSharedComm(program[1], shared, send, recv, sp, rp)
	elif program == "SKIP":
		pass
	else:
		raise NotImplementedError("Unexpected Program: " + str(program))
	return (shared, sp, rp)

def collectShared(program):
	'''Returns the set of channels collectSharedRec reports as shared
	between processes in program. The per-process send/receive sets it
	also computes are discarded.'''
	found = collectSharedRec(program, set(), set(), set(), set(), set())
	return found[0]

def collectProbedGuard(guard, probed):
	'''Appends the channels probed by guard to probed and returns it.

	"ELSE", ("GUARD_e", expr) and a bare ("EXPR", ...) ending a GUARD chain
	probe nothing. ("GUARD_p", chan) probes chan; ("GUARD", chan, rest)
	probes chan and whatever rest probes.
	Raises NotImplementedError for any other guard.'''
	if guard == "ELSE" or guard[0] in ("GUARD_e", "EXPR"):
		return probed
	if guard[0] == "GUARD_p":
		probed.append(guard[1])
		return probed
	if guard[0] == "GUARD":
		probed.append(guard[1])
		return collectProbedGuard(guard[2], probed)
	raise NotImplementedError("Unexpected guard type: " + str(guard[0]))

def collectProbedRec(program, probed):
	'''Recursive step for collectProbed: walks program and appends every
	channel probed in a selection guard to probed.

	Returns probed, extended in place.
	Raises NotImplementedError for an unrecognised program node.'''
	if program == "SKIP":
		return probed
	tag = program[0]
	if tag == "COMM":
		# A bare communication contains no guard.
		return probed
	if tag == "NEW_PAR":
		for child in (program[2], program[3]):
			probed = collectProbedRec(child, probed)
		return probed
	if tag in ("NEW", "REP"):
		body = program[2] if tag == "NEW" else program[1]
		return collectProbedRec(body, probed)
	if tag in ("SEQ", "PAR"):
		for child in program[1]:
			probed = collectProbedRec(child, probed)
		return probed
	if tag == "SELECT_ONE":
		branches = [program[1]]
	elif tag in ("SELECT_DET", "SELECT_UDET"):
		branches = program[1][1]
	else:
		raise NotImplementedError("Unexpected Program: " + str(program))
	for branch in branches:
		assert branch[0] == "GC"
		probed = collectProbedGuard(branch[1], probed)
		probed = collectProbedRec(branch[2], probed)
	return probed

def collectProbed(program):
	'''Returns the list of channels probed in the guards of program, in
	the order collectProbedRec visits them (duplicates kept).'''
	probed = []
	return collectProbedRec(program, probed)

# Channel classification constants. channelType maps each channel to a
# (direction, payload) pair built from these.
# Direction:
C_SEND = 0  # only sent on
C_RECV = 1  # only received on
C_INTL = 2  # internal: both sent and received on, or bound by NEW
# Payload:
C_INT = 3   # carries integers
C_DATA = 4  # carries data of unknown type
C_BOOL = 5  # carries booleans
C_CTRL = 6  # control channel, carries no value

def chanCommType(comm, ctype):
	'''Updates ctype with the information from one communication action.

	ctype maps a channel name to a (direction, payload) pair. direction is
	C_SEND, C_RECV or C_INTL; a channel both sent and received on becomes
	C_INTL. payload is C_CTRL for control actions (SEND_C/RECV_C). For a
	data SEND it is inferred from the sent expression: C_INT for integer
	forms, C_BOOL for EXPR, C_DATA otherwise. A data RECV keeps any payload
	already known and otherwise records C_DATA.

	Returns ctype, updated in place.
	Raises Exception when a control action turns a data channel internal,
	and NotImplementedError for an unknown communication type.'''
	kind = comm[0]
	if kind in ("RECV_C", "SEND_C"):
		chan = comm[1]
		if kind == "RECV_C":
			(same, verb) = (C_RECV, "Receive")
		else:
			(same, verb) = (C_SEND, "Send")
		if chan not in ctype:
			ctype[chan] = (same, C_CTRL)
		elif ctype[chan][0] in (same, C_INTL):
			pass
		elif ctype[chan][1] != C_CTRL:
			raise Exception(verb + " on non-control channel")
		else:
			ctype[chan] = (C_INTL, C_CTRL)
	elif kind == "SEND":
		chan = comm[1]
		exprTag = comm[2][0]
		if exprTag in ("INT", "PLUS", "PLUS_ID1", "PLUS_ID0", "PLUS_ID"):
			payload = C_INT
		elif exprTag == "EXPR":
			payload = C_BOOL
		else:
			payload = C_DATA
		if chan not in ctype:
			ctype[chan] = (C_SEND, payload)
		elif ctype[chan][0] != C_SEND:
			# Already received on: the channel is now internal.
			ctype[chan] = (C_INTL, payload)
	elif kind == "RECV":
		chan = comm[1]
		if chan not in ctype:
			# The payload type is unknown from a receive alone.
			ctype[chan] = (C_RECV, C_DATA)
		elif ctype[chan][0] not in (C_RECV, C_INTL):
			# Previously only sent on: now it is also received on.
			ctype[chan] = (C_INTL, ctype[chan][1])
	else:
		raise NotImplementedError("Unexpected communication: " + str(comm))
	return ctype

def channelTypeRec(program, ctype):
	'''Recursive step for channelType: folds every communication in program
	into ctype via chanCommType, then marks channels bound by NEW as
	internal (C_INTL) while keeping their payload type.

	Returns ctype, updated in place.
	Raises NotImplementedError for an unrecognised program node.'''
	if program == "SKIP":
		return ctype
	tag = program[0]
	if tag == "NEW":
		name = program[1]
		ctype = channelTypeRec(program[2], ctype)
		if name in ctype:
			ctype[name] = (C_INTL, ctype[name][1])
		return ctype
	if tag == "COMM":
		return chanCommType(program[1], ctype)
	if tag in ("REP", "SELECT_ONE"):
		children = [program[1]]
	elif tag == "GC":
		children = [program[2]]
	elif tag in ("SEQ", "PAR"):
		children = program[1]
	elif tag in ("SELECT_DET", "SELECT_UDET"):
		children = program[1][1]
	else:
		raise NotImplementedError("Unexpected Program: " + str(program))
	for child in children:
		ctype = channelTypeRec(child, ctype)
	return ctype

def channelType(program):
	'''Classifies every channel used in program.

	Returns a dict mapping each channel name to a (direction, payload)
	pair.

	direction is one of:
	  * C_SEND if the channel is only sent on;
	  * C_RECV if it is only received on;
	  * C_INTL (internal) if it is both sent and received on, or is bound
	    by a NEW in program.

	payload is C_INT, C_BOOL or C_DATA for data-carrying channels, and
	C_CTRL for control (valueless) channels.'''
	ctype = {}
	return channelTypeRec(program, ctype)

# Variable type constants produced by varType.
V_DATA = 0  # value of unknown type
V_BOOL = 1  # boolean
V_INT = 2   # integer

def varTypeExpr(expr, vtype, typ=V_DATA):
	'''Infers variable types from their use in an expression.

	typ is the type the enclosing context expects of expr. A bare
	identifier (a str) is recorded in vtype with that type, overwriting any
	earlier entry. Operands of EXPR/CONJ and of PRIM "~" are typed V_BOOL.
	Operands of PLUS/BEQ and the identifiers of PLUS_ID* are typed V_INT.
	Literals (PRIM True/False, INT) contribute nothing.

	Returns vtype, updated in place.'''
	head = expr[0]
	if head in ("EXPR", "CONJ"):
		if len(expr) == 3:
			positions = (1, 2)
		else:
			assert len(expr) == 2
			positions = (1,)
		for i in positions:
			vtype = varTypeExpr(expr[i], vtype, V_BOOL)
	elif head == "PRIM":
		operand = expr[1]
		if operand in [True, False]:
			pass
		elif operand == "~":
			vtype = varTypeExpr(expr[2], vtype, V_BOOL)
		else:
			vtype = varTypeExpr(operand, vtype, typ)
	elif head == "PRIM_p":
		vtype = varTypeExpr(expr[1], vtype, typ)
	elif head in ("PLUS", "BEQ"):
		for i in (1, 2):
			vtype = varTypeExpr(expr[i], vtype, V_INT)
	elif head == "PLUS_ID1":
		vtype[expr[1]] = V_INT
		vtype = varTypeExpr(expr[2], vtype, V_INT)
	elif head == "PLUS_ID0":
		vtype[expr[2]] = V_INT
		vtype = varTypeExpr(expr[1], vtype, V_INT)
	elif head == "PLUS_ID":
		for i in (1, 2):
			if isinstance(expr[i], str):
				vtype[expr[i]] = V_INT
	elif head == "INT":
		pass
	else:
		# Anything else must be a plain identifier.
		assert isinstance(expr, str)
		vtype[expr] = typ
	return vtype

def varTypeComm(comm, ctype, vtype):
	'''Infers variable types from their use in a communication.

	For a SEND, the sent expression is typed with varTypeExpr. For a RECV,
	the receiving variable takes the channel's payload type from ctype
	(as built by channelType): V_INT for C_INT, V_BOOL for C_BOOL.
	Control actions (SEND_C/RECV_C) carry no value.

	Returns vtype, updated in place.
	Raises NotImplementedError when a RECV's channel payload is neither
	C_INT nor C_BOOL.'''
	kind = comm[0]
	if kind == "SEND":
		vtype = varTypeExpr(comm[2], vtype)
	elif kind == "RECV":
		payload = ctype[comm[1]][1]
		if payload == C_INT:
			vtype[comm[2]] = V_INT
		elif payload == C_BOOL:
			vtype[comm[2]] = V_BOOL
		else:
			# NOTE(review): C_DATA channels (the default recorded for a
			# receive-only channel) also land here; mapping them to V_DATA
			# may be what is intended.
			raise NotImplementedError("Unexpected channel type on " + str(comm[1]))
	else:
		assert kind in ("SEND_C", "RECV_C")
	return vtype

def varTypeRec(program, ctype, vtype):
	'''Recursive step for varType: infers variables' types from their use
	in program.

	ctype is the channel typing from channelType, consulted for receives.
	Guard expressions are typed as well; channel probes in guards are
	skipped since they name channels, not variables.

	Returns vtype (variable -> V_* type), updated in place.
	Raises NotImplementedError for an unrecognised program node or guard
	component.'''
	tag = program[0]
	if tag == "NEW":
		vtype = varTypeRec(program[2], ctype, vtype)
	elif tag in ["REP", "SELECT_ONE"]:
		vtype = varTypeRec(program[1], ctype, vtype)
	elif tag in ["SEQ", "PAR"]:
		for p in program[1]:
			vtype = varTypeRec(p, ctype, vtype)
	elif tag == "GC":
		guard = program[1]
		if guard == "ELSE" or guard[0] == "GUARD_p":
			pass
		elif guard[0] == "GUARD":
			# ("GUARD", chan, rest): walk past the chained probes. The chain
			# ends either in another probe or in an EXPR worth typing.
			e = guard
			while e[0] == "GUARD":
				e = e[2]
			if e[0] == "EXPR":
				vtype = varTypeExpr(e, vtype)
			elif e[0] != "GUARD_p":
				raise NotImplementedError("Unexpected guard component: " + str(e))
		else:
			assert guard[0] == "GUARD_e"
			vtype = varTypeExpr(guard[1], vtype)
		vtype = varTypeRec(program[2], ctype, vtype)
	elif tag in ["SELECT_DET", "SELECT_UDET"]:
		for gc in program[1][1]:
			vtype = varTypeRec(gc, ctype, vtype)
	elif tag == "COMM":
		vtype = varTypeComm(program[1], ctype, vtype)
	elif program == "SKIP":
		pass
	else:
		raise NotImplementedError("Unexpected Program: " + str(program))
	return vtype

def varType(program, ctype):
	'''Identifies each variable as storing a boolean (V_BOOL), an integer
	(V_INT), or some unknown data value (V_DATA), inferring the type from
	how the variable is used in program.

	ctype is the channel typing from channelType; it supplies the types
	of variables that receive values.

	Returns a dict mapping each variable name to its V_* type.'''
	return varTypeRec(program, ctype, {})

def updateMapChan(map_chan, val, update):
	'''Substitutes update for val in every set stored in map_chan.

	The sets are changed in place; map_chan itself is returned for
	convenience.'''
	for members in map_chan.values():
		if val in members:
			members.discard(val)
			members.add(update)
	return map_chan

def searchMapChan(map_chan, val):
	'''Returns True if val occurs in any of the sets stored in map_chan,
	False otherwise.'''
	return any(val in members for members in map_chan.values())

def newParRec(program, new):
	'''Recursive step of newPar.

	program -- the (sub)program to restructure; it is not modified.
	new     -- names bound by enclosing NEW nodes already peeled off,
	           outermost first.

	NEW nodes are peeled off into new. On reaching a non-empty PAR, each
	peeled channel is re-scoped over just the processes that use it, as
	reported by collectChans:
	  * a channel used by exactly two processes becomes
	    ("NEW_PAR", chan, left, right), which then stands in for both;
	  * any remaining channel becomes ("NEW", chan, proc) around its single
	    user, or ("NEW", chan, ("PAR", users, 0)) around several.
	Peeled channels no process uses are dropped. For anything other than
	a non-empty PAR, the peeled NEW nodes are re-applied unchanged.'''
	if program[0] == "NEW":
		return newParRec(program[2], new + [program[1]])
	if program[0] != "PAR" or not new or not program[1]:
		# Nothing to regroup: restore the NEW scopes peeled off on the way
		# down (the innermost binding wraps first).
		for name in reversed(new):
			program = ("NEW", name, program)
		return program

	# Work on a copy so the caller's process list is left untouched; new
	# combined processes are appended and referred to by index.
	procs = list(program[1])

	# Map each peeled channel to the indices of the processes that use it.
	map_chan = {}
	mapped = set()
	for i, proc in enumerate(procs):
		chans = set(collectChans(proc))
		for name in new:
			if name in chans:
				mapped.add(i)
				map_chan.setdefault(name, set()).add(i)
	# Processes using none of the peeled channels need no new scope.
	no_map = [i for i in range(len(procs)) if i not in mapped]

	def absorb(chan, members, node):
		# Append node as the process replacing members, retire chan, and
		# mark node finished once no remaining channel refers to it.
		procs.append(node)
		idx = len(procs) - 1
		for m in members:
			updateMapChan(map_chan, m, idx)
		map_chan.pop(chan)
		if not searchMapChan(map_chan, idx):
			no_map.append(idx)

	# Candidates are always scanned in binding order, so the result does
	# not depend on dict iteration order.
	# Pair up processes that share a channel used by exactly two of them.
	while True:
		pairs = [c for c in new if c in map_chan and len(map_chan[c]) == 2]
		if not pairs:
			break
		chan = pairs[0]
		(a, b) = sorted(map_chan[chan])
		absorb(chan, (a, b), ("NEW_PAR", chan, procs[a], procs[b]))

	# Scope every remaining channel over all of its users.
	while map_chan:
		chan = [c for c in new if c in map_chan][0]
		users = sorted(map_chan[chan])
		if len(users) == 1:
			node = ("NEW", chan, procs[users[0]])
		else:
			node = ("NEW", chan, ("PAR", [procs[i] for i in users], 0))
		absorb(chan, users, node)

	# procs is non-empty, so at least one process always ends up unscoped.
	assert no_map
	if len(no_map) == 1:
		return procs[no_map[0]]
	return ("PAR", [procs[i] for i in no_map], 0)

def newPar(program):
	'''Collapses NEW scopes over PAR compositions.

	Entry point for newParRec, starting with no enclosing NEW bindings.
	A channel shared by exactly two parallel processes becomes a
	("NEW_PAR", chan, left, right) node, i.e. a channel scope over exactly
	two children.'''
	enclosing = []
	return newParRec(program, enclosing)

def wideParRec(program, new):
	'''Recursive step of widePar.

	program -- the (sub)program to restructure.
	new     -- names bound by enclosing NEW nodes already peeled off,
	           outermost first.

	NEW nodes are peeled off into new. On reaching a PAR, returns
	("WIDE_PAR", map_chan, processes): map_chan maps each peeled channel
	that some process uses (per collectChans) to the set of indices, into
	processes, of its users. Peeled channels no process uses are dropped.
	For any other node the peeled NEW nodes are re-applied unchanged.'''
	if program[0] == "NEW":
		return wideParRec(program[2], new + [program[1]])
	if program[0] != "PAR":
		# Restore the NEW scopes peeled off on the way down rather than
		# silently dropping them (the innermost binding wraps first).
		for name in reversed(new):
			program = ("NEW", name, program)
		return program
	map_chan = {}
	for i, proc in enumerate(program[1]):
		chans = set(collectChans(proc))
		for name in new:
			if name in chans:
				map_chan.setdefault(name, set()).add(i)
	return ("WIDE_PAR", map_chan, program[1])


def widePar(program):
	'''Collapses NEW scopes over a PAR into a single WIDE_PAR node.

	Entry point for wideParRec, starting with no enclosing NEW bindings.
	The node is ("WIDE_PAR", map_chan, processes): first a map from each
	bound channel to the indices of the processes using it, then the list
	of processes.'''
	return wideParRec(program, new=[])
