import numpy;
import math;
from pylab import *;
import sys;
from matplotlib.colors import ColorConverter

# Switch read by main(): when truthy, the leading (non-song) points of each
# cluster plot are annotated with text labels for inter-cluster transition
# types 0 and 1. Set to 0 to draw the plots without those labels.
plot_portals = 1

def self_omit(cluster_id, current_cluster_id):
	"""Translate an index that skips the current cluster into a cluster id.

	Indices below current_cluster_id are returned unchanged; every other
	index is shifted up by one, so current_cluster_id itself is never
	produced.
	"""
	if not cluster_id < current_cluster_id:
		return cluster_id + 1
	return cluster_id

def _usage_error(message):
	"""Print an error line followed by the usage line, then exit with status 1."""
	# Single-argument print(...) behaves identically under Python 2 and 3.
	print("Error: " + message)
	print("Usage: python plot.py embedding_file {tag_file tag_hash_file topTagCount=3}")
	sys.exit(1)


def _header_value(f):
	"""Read the next line of f and return its second space-separated token as int."""
	return int(f.readline().split(' ')[1])


def _extra_rows_per_cluster(transitionType, numClusters, numPortals):
	"""Return how many non-song coordinate rows precede the songs of a cluster.

	Exits with status 1 for a transition type this script cannot plot
	(previously such a file crashed later with a NameError).
	"""
	if transitionType == 0:
		return numClusters
	if transitionType == 1:
		return 2 * (numClusters - 1)
	if transitionType in (3, 4, 6):
		return numPortals
	if transitionType == 5:
		return 2 * numPortals
	print("Error: unsupported inter-cluster transition type " + str(transitionType) + ".")
	sys.exit(1)


def _read_embedding(ebdfilename):
	"""Parse the embedding file.

	Layout as consumed here: five header lines whose second token is an
	integer (cluster count, song count, dimensionality, bias flag, transition
	type), one more such line (portal count) for transition types 3-6, one
	line whose tokens 1..numSongs are the cluster id of each song, and then,
	per cluster, one skipped line (two when the bias flag is set) followed by
	that cluster's coordinate rows.

	Returns (transitionType, numPortals, clusters, clusterCoords) where
	clusters is a list with one cluster id per song and clusterCoords holds
	one (rows x 2) array per cluster, non-song rows first.
	"""
	with open(ebdfilename, 'r') as f:
		numClusters = _header_value(f)
		numSongs = _header_value(f)
		dimensions = _header_value(f)
		if dimensions != 2:
			print("Error: can only plot embedding file with dimensionality 2.")
			sys.exit(1)
		biasEnabled = _header_value(f)
		transitionType = _header_value(f)
		numPortals = 0
		if transitionType in (3, 4, 5, 6):
			numPortals = _header_value(f)
		extraRows = _extra_rows_per_cluster(transitionType, numClusters, numPortals)

		# Cluster membership of every song; token 0 of this line is not a song.
		tokens = f.readline().split(' ')
		clusters = [int(tokens[i]) for i in range(1, numSongs + 1)]
		# Plain ints (not a float array) so the counts can be used with range().
		clusterSizes = [0] * numClusters
		for cluster in clusters:
			clusterSizes[cluster] += 1

		clusterCoords = []
		for i in range(numClusters):
			f.readline()  # per-cluster line that is not needed for plotting
			if biasEnabled:
				f.readline()  # additional line present only when bias is enabled
			rowCount = clusterSizes[i] + extraRows
			block = numpy.zeros((rowCount, 2))
			for j in range(rowCount):
				tokens = f.readline().split(' ')
				block[j, 0] = float(tokens[0])
				block[j, 1] = float(tokens[1])
			clusterCoords.append(block)
	return transitionType, numPortals, clusters, clusterCoords


def _read_top_tags(tagFilename, tagHashFilename, clusters, numClusters, topTagCount):
	"""Return, for every cluster, the names of its most frequent tags.

	The tag hash file supplies one tag name per line (the text after the
	first ', '); the tag file supplies one line of tag indices per song, with
	'#' marking a song without tags. At most topTagCount names are returned
	per cluster, most frequent first.
	"""
	with open(tagHashFilename, 'r') as f:
		tagHash = [line.split(', ')[1].strip('\n') for line in f]

	numTags = len(tagHash)
	tagCountPerCluster = numpy.zeros((numTags, numClusters))
	with open(tagFilename, 'r') as f:
		for cluster in clusters:
			line = f.readline()
			if line.strip() == '#':
				continue
			# Whitespace split tolerates trailing blanks and a missing final
			# newline; every token on the line is a tag index.
			for token in line.split():
				tagCountPerCluster[int(token), cluster] += 1

	# Never ask for more tags than exist: the old negative indices wrapped
	# around and produced duplicated tag names.
	topTagCount = max(0, min(topTagCount, numTags))
	topTags = []
	for i in range(numClusters):
		order = numpy.argsort(tagCountPerCluster[:, i])
		topTags.append([tagHash[int(t)] for t in order[::-1][:topTagCount]])
	return topTags


def _set_hold(state):
	"""Call pylab's hold(state) when it exists (matplotlib 3.0 removed it)."""
	holdFn = globals().get('hold')
	if holdFn is not None:
		holdFn(state)


def _plot_cluster(i, block, numClusters, numPortals, transitionType, songcolor, entrycolor, exitcolor):
	"""Scatter the points of cluster i into the current axes.

	block is the cluster's (rows x 2) coordinate array; its leading rows are
	the non-song points, the remaining rows are songs.
	"""
	X = block[:, 0]
	Y = block[:, 1]
	area = numpy.pi * 4 ** 2

	if transitionType == 0:
		scatter(X[numClusters:], Y[numClusters:], s=area, marker='.', c=songcolor, lw=0)
		# Zero-size markers: these points are only made visible by the labels.
		scatter(X[:numClusters], Y[:numClusters], s=0, marker='^', c='b')
		if plot_portals:
			for k in range(numClusters):
				color = entrycolor if k == i else exitcolor
				annotate(str(k), xy=(X[k], Y[k]), textcoords="data", color=color, size=14)
	elif transitionType == 1:
		half = numClusters - 1
		scatter(X[2 * half:], Y[2 * half:], s=area, marker='.', c=songcolor, lw=0)
		scatter(X[:2 * half], Y[:2 * half], s=0, marker='^', c='b')
		if plot_portals:
			# Labels are rebuilt in each loop; a shared map() object would be
			# exhausted after the first loop under Python 3.
			for k in range(half):
				annotate(str(self_omit(k, i)), xy=(X[k], Y[k]), textcoords="data", color=entrycolor, size=14)
			for k in range(half):
				annotate(str(self_omit(k, i)), xy=(X[half + k], Y[half + k]), textcoords="data", color=exitcolor, size=14)
	elif transitionType in (3, 4, 6):
		scatter(X[numPortals:], Y[numPortals:], s=area, marker='.', c=songcolor, lw=0)
		scatter(X[:numPortals], Y[:numPortals], s=20, marker='.', c='b', edgecolor='b')
	elif transitionType == 5:
		scatter(X[2 * numPortals:], Y[2 * numPortals:], s=area, marker='.', c=songcolor, lw=0)
		scatter(X[:numPortals], Y[:numPortals], s=area, marker='.', c=entrycolor, edgecolor=entrycolor)
		scatter(X[numPortals:2 * numPortals], Y[numPortals:2 * numPortals], s=area, marker='.', c=exitcolor, edgecolor=exitcolor)


def main():
	"""Plot a 2-D embedding file, one subplot per cluster.

	Command line: embedding_file [tag_file tag_hash_file [topTagCount]].
	When the two tag files are given, each subplot title also lists the most
	frequent tags of that cluster (3 unless topTagCount is supplied as the
	fourth argument).

	Exits with status 1 on a usage error, on an embedding that is not
	two-dimensional, or on an unsupported inter-cluster transition type.
	"""
	argc = len(sys.argv)
	if argc < 2:
		_usage_error("not enough arguments.")
	if argc == 3:
		# Previously this crashed with IndexError on sys.argv[3].
		_usage_error("tag_file and tag_hash_file must be given together.")
	doTags = argc > 2
	topTagCount = int(sys.argv[4]) if argc == 5 else 3

	transitionType, numPortals, clusters, clusterCoords = _read_embedding(sys.argv[1])
	numClusters = len(clusterCoords)
	topTags = None
	if doTags:
		topTags = _read_top_tags(sys.argv[2], sys.argv[3], clusters, numClusters, topTagCount)

	cc = ColorConverter()
	songcolor = cc.to_rgb("#BBBBBB")
	entrycolor = cc.to_rgb('b')
	exitcolor = cc.to_rgb("#FF7F00")

	figure()
	# Near-square integer grid; at least one row so zero clusters cannot
	# divide by zero, and float() keeps the division exact under Python 2.
	rows = max(1, int(numpy.floor(numpy.sqrt(numClusters))))
	cols = int(numpy.ceil(numClusters / float(rows)))
	for i, block in enumerate(clusterCoords):
		subplot(rows, cols, i + 1)
		titleStr = 'Cluster ' + str(i)
		if doTags:
			titleStr += ', Top ' + str(len(topTags[i])) + ' tags:\n' + ', '.join(topTags[i])
		title(titleStr, fontsize=10)
		_set_hold(True)
		_plot_cluster(i, block, numClusters, numPortals, transitionType, songcolor, entrycolor, exitcolor)
		_set_hold(False)
	subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=0.5)
	show()

# Run the plotting routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
	main()
