import numpy as np
import pdb
import pylab as pl
import matplotlib.image as mpimg

# Author: Mevlana C. Gemici, Cornell CS4786/5786, Spring 2015

###########################  IMAGE/PLOT Helpers ###########################

def showImage(image, FigSize=(7, 7)):
    """Display an image in a new pylab figure, then close that figure.

    image: array-like of shape (m, n) (drawn with a gray colormap) or
        (m, n, 3) (drawn as RGB), with values in either 0-1 or 0-255.
        If every value is <= 1 the image is treated as 0-1 and multiplied
        by 255 (so a 0-255 image whose values are all 0 or 1 is also
        rescaled). Values are cast to uint8 before drawing.
    FigSize: (width, height) passed to pl.figure.

    Prints 'Bad Image Range' or 'Bad Image Shape' and returns None without
    creating a figure when the input fails validation.
    """
    image = np.asarray(image)
    if np.any(image < 0) or np.any(image > 255):
        print('Bad Image Range')
        return
    is_gray = image.ndim == 2
    is_rgb = image.ndim == 3 and image.shape[2] == 3
    if not (is_gray or is_rgb):
        print('Bad Image Shape')
        return
    # `<= 1` (not `< 1`) so 0-1 images containing pure white (1.0) are rescaled.
    if np.all(image <= 1):
        image = image * 255
    pixels = image.astype(np.uint8)
    # Validation happens before this point so bad input never leaves a figure open.
    fig = pl.figure(figsize=FigSize)
    if is_gray:
        pl.imshow(pixels, cmap='gray', interpolation='None')
    else:
        pl.imshow(pixels, interpolation='None')
    # Hide ticks and tick labels; booleans instead of the deprecated 'off' strings.
    pl.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    pl.tick_params(axis='y', which='both', right=False, left=False, labelleft=False)
    pl.show(block=False)
    pl.close(fig)

def saveImage(image, address='./currentfig.png'):
    """Write an image's pixels to `address` with matplotlib.image.imsave.

    image: array-like of shape (m, n) (saved with a gray colormap) or
        (m, n, 3) (saved as RGB), with values in either 0-1 or 0-255.
        If every value is <= 1 the image is treated as 0-1 and multiplied
        by 255.
    address: output path passed to mpimg.imsave.

    Prints 'Bad Image Range' or 'Bad Image Shape' and returns None without
    writing anything when the input fails validation.
    """
    image = np.asarray(image)
    if np.any(image < 0) or np.any(image > 255):
        print('Bad Image Range')
        return
    # `<= 1` (not `< 1`) so 0-1 images containing pure white (1.0) are rescaled.
    if np.all(image <= 1):
        image = image * 255
    if image.ndim == 2:
        # A 2-D array is mapped through the gray colormap; imsave scales the
        # colormap between the array's own min and max by default.
        mpimg.imsave(address, image, cmap='gray')
    elif image.ndim == 3 and image.shape[2] == 3:
        # imsave only accepts float RGB data in 0..1, so hand it 0-255 as uint8.
        mpimg.imsave(address, image.astype(np.uint8))
    else:
        print('Bad Image Shape')
        return

def saveImage2(image, address='./currentfig.png'):
    """Render an image into a matplotlib figure and save that figure as PNG.

    Unlike saveImage, which writes the array's pixels directly, this draws
    the image with pl.imshow on a fresh figure (ticks hidden) and saves the
    whole figure with pl.savefig. The result therefore usually looks
    different from saveImage's output.

    image: array-like of shape (m, n) (drawn with a gray colormap) or
        (m, n, 3) (drawn as RGB), with values in either 0-1 or 0-255.
        If every value is <= 1 the image is treated as 0-1 and multiplied
        by 255. Values are cast to uint8 before drawing.
    address: output path passed to pl.savefig.

    Prints 'Bad Image Range' or 'Bad Image Shape' and returns None without
    drawing when the input fails validation.
    """
    image = np.asarray(image)
    if np.any(image < 0) or np.any(image > 255):
        print('Bad Image Range')
        return
    is_gray = image.ndim == 2
    is_rgb = image.ndim == 3 and image.shape[2] == 3
    if not (is_gray or is_rgb):
        print('Bad Image Shape')
        return
    # `<= 1` (not `< 1`) so 0-1 images containing pure white (1.0) are rescaled.
    if np.all(image <= 1):
        image = image * 255
    # imshow clips float RGB data to 0..1, so pass 0-255 values as uint8.
    pixels = image.astype(np.uint8)
    # Draw on a dedicated figure so the caller's current figure is untouched.
    fig = pl.figure()
    if is_gray:
        pl.imshow(pixels, cmap='gray', interpolation='None')
    else:
        pl.imshow(pixels, interpolation='None')
    # Hide ticks and tick labels; booleans instead of the deprecated 'off' strings.
    pl.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    pl.tick_params(axis='y', which='both', right=False, left=False, labelleft=False)
    pl.savefig(address, format='png')
    pl.close(fig)

def saveStemPlot(array, x=None, address='./currentfigstemplot.png'):
    """Save a blue stem plot of `array` to `address` as a PNG.

    array: array-like of y values; its first axis gives the number of stems.
    x: optional stem positions; defaults to 0 .. len(array) - 1.
    address: output path passed to pl.savefig.
    """
    array = np.asarray(array)
    if x is None:
        x = np.arange(array.shape[0])
    # Draw on a dedicated figure instead of clearing the caller's current one.
    fig = pl.figure()
    markerline, stemlines, _baseline = pl.stem(x, array)
    # `setp` is not imported into this module by name; use pylab's.
    pl.setp(markerline, 'markerfacecolor', 'b')
    pl.setp(stemlines, 'color', 'b', 'linewidth', 2)
    pl.savefig(address, format='png')
    pl.close(fig)

def savePlot(array, address='./currentfigplot.png'):
    """Save a line plot of `array` against its index to `address` as a PNG.

    array: array-like of y values, plotted against 0 .. len(array) - 1 on
        a 10x10-inch figure.
    address: output path passed to pl.savefig.
    """
    array = np.asarray(array)
    # `matplotlib` itself is never bound in this module, so use pylab's figure.
    # The old pl.clf() call is dropped: it left an extra figure open per call.
    fig = pl.figure(figsize=(10.0, 10.0))
    x = np.arange(array.shape[0])
    pl.plot(x, array)
    pl.savefig(address, format='png')
    pl.close(fig)
