import torch
import torch.nn as nn
import torch.sparse
from torch.autograd import Variable
import numpy as np
import sys
from torch.autograd import Function
import math
from scipy import misc 
###############################################################################
# Functions
###############################################################################

# def weights_init(m):
#     classname = m.__class__.__name__
#     if classname.find('Conv') != -1:
#         m.weight.data.normal_(0.0, 0.02)
#     elif classname.find('BatchNorm2d') != -1 or  classname.find('InstanceNorm2d') != -1:
#         m.weight.data.normal_(1.0, 0.02)
#         m.bias.data.fill_(0)

# def get_norm_layer(norm_type):
#     if norm_type == 'batch':
#         norm_layer = nn.BatchNorm2d
#     elif norm_type == 'instance':
#         norm_layer = nn.InstanceNorm2d
#     else:
#         print('normalization layer [%s] is not found' % norm)
#     return norm_layer

# def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, gpu_ids=[]):
#     netG = None
#     use_gpu = len(gpu_ids) > 0
#     norm_layer = get_norm_layer(norm_type=norm)

#     if use_gpu:
#         assert(torch.cuda.is_available())

#     if which_model_netG == 'resnet_9blocks':
#         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, gpu_ids=gpu_ids)
#     elif which_model_netG == 'resnet_6blocks':
#         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, gpu_ids=gpu_ids)
#     elif which_model_netG == 'unet_128':
#         netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
#     elif which_model_netG == 'unet_256':
#         # netG = SingleUnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
#         # netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
#         netG = MultiUnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout, gpu_ids=gpu_ids)
#     else:
#         print('Generator model name [%s] is not recognized' % which_model_netG)
#     if len(gpu_ids) > 0:
#         netG.cuda(device_id=gpu_ids[0])
#     netG.apply(weights_init)
#     return netG


# def define_D(input_nc, ndf, which_model_netD,
#              n_layers_D=3, norm='batch', use_sigmoid=False, gpu_ids=[]):
#     netD = None
#     use_gpu = len(gpu_ids) > 0
#     norm_layer = get_norm_layer(norm_type=norm)

#     if use_gpu:
#         assert(torch.cuda.is_available())
#     if which_model_netD == 'basic':
#         netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
#     elif which_model_netD == 'n_layers':
#         netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid, gpu_ids=gpu_ids)
#     else:
#         print('Discriminator model name [%s] is not recognized' %
#               which_model_netD)
#     if use_gpu:
#         netD.cuda(device_id=gpu_ids[0])
#     netD.apply(weights_init)
#     return netD


def print_network(net):
    """Print ``net`` and the total element count of its parameters.

    Args:
        net: Any object exposing ``parameters()`` that yields items with a
            ``numel()`` method (e.g. an ``nn.Module``).
    """
    total = sum(p.numel() for p in net.parameters())
    print(net)
    print('Total number of parameters: %d' % total)


##############################################################################
# Classes
##############################################################################


class JointLoss(nn.Module):
    """Collection of depth-supervision losses plus a multi-scale combined loss.

    Calling an instance evaluates the combined loss: a masked data term at
    full resolution plus masked gradient terms at four resolutions, averaged
    over the batch samples that are not flagged ``has_ordinal``.  The other
    public methods (``Ordinal_Loss``, ``sky_loss``, ``Ordinal_Loss_AUTO``,
    ``Data_Loss_test``) are standalone losses that the combined loss does not
    invoke.

    All tensors created or moved by this class follow the device of the
    prediction tensor: they are moved with ``.cuda()`` when the prediction is
    a CUDA tensor and left where they are otherwise.

    NOTE(review): predictions appear to be log-depth maps (they are compared
    against ``log(gt)`` and exponentiated before ordinal comparisons) --
    confirm against the network that produces them.
    """

    # Floor for the valid-pixel count, so an all-zero mask produces a zero
    # loss instead of 0/0 = NaN.  Far below 1, so it never alters the count
    # of a non-empty binary mask.
    _MIN_VALID_COUNT = 1e-8

    def __init__(self):
        super(JointLoss, self).__init__()
        # Term weights.  Only w_data and w_grad are read inside this class;
        # the remaining weights are stored but not referenced here.
        self.w_data = 1.0
        self.w_grad = 0.5
        self.w_sm = 2.0
        self.w_od = 0.5
        self.w_od_auto = 0.2
        self.w_sky = 0.1
        # Differentiable result of the most recent __call__ (None before).
        self.total_loss = None

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _on_device_of(tensor, reference):
        """Return ``tensor`` as a no-grad Variable on ``reference``'s device kind."""
        if reference.is_cuda:
            tensor = tensor.cuda()
        return Variable(tensor, requires_grad=False)

    @classmethod
    def _zero_scalar(cls, reference):
        """Return a one-element float zero accumulator matching ``reference``'s device kind."""
        return cls._on_device_of(torch.zeros(1), reference)

    @classmethod
    def _valid_count(cls, mask):
        """Return ``sum(mask)`` clamped from below so it is safe as a divisor."""
        return torch.clamp(torch.sum(mask), min=cls._MIN_VALID_COUNT)

    @staticmethod
    def _first_value(tensor):
        """Return the first element of ``tensor``'s data, also for 0-dim tensors."""
        # Plain ``.data[0]`` raises on 0-dim tensors in current PyTorch;
        # flattening first works for both 0-dim and 1-D inputs.
        return tensor.data.view(-1)[0]

    # ------------------------------------------------------------------
    # Ordinal (pairwise ranking) losses
    # ------------------------------------------------------------------

    def Ordinal_Loss(self, prediction_d, targets, input_images):
        """Mean pairwise ordinal loss over a batch.

        For every sample ``i`` the exponentiated prediction is sampled at the
        point pairs ``(y_A, x_A)`` / ``(y_B, x_B)`` and each pair contributes
        ``log(1 + exp(-ordinal * (z_A - z_B)))``.

        Args:
            prediction_d: Batched prediction, indexed as ``[i, :, :]``.
            targets: Mapping with ``'ordinal'``, ``'x_A'``, ``'y_A'``,
                ``'x_B'``, ``'y_B'``; each is indexed by sample.  Point
                indices are zero-based.
            input_images: Unused; kept for interface compatibility.

        Returns:
            A one-element tensor holding the summed loss divided by the total
            number of point pairs, or the plain number ``5`` as soon as the
            first pair of any sample has an inner term above 5.

        Note:
            Terminates the process via ``sys.exit()`` if the accumulated loss
            is NaN.
        """
        total_loss = self._zero_scalar(prediction_d)
        n_point_total = 0

        batch_input = torch.exp(prediction_d)
        ground_truth_arr = self._on_device_of(targets['ordinal'], prediction_d)

        for i in range(prediction_d.size(0)):
            gt_d = ground_truth_arr[i]
            n_point_total += gt_d.size(0)

            inputs = batch_input[i, :, :]
            z_A_arr = inputs[targets['y_A'][i], targets['x_A'][i]]
            z_B_arr = inputs[targets['y_B'][i], targets['x_B'][i]]

            inner_loss = torch.mul(-gt_d, z_A_arr - z_B_arr)

            # Only the first pair is inspected; a large value aborts the whole
            # batch with a constant instead of a differentiable loss.
            if self._first_value(inner_loss) > 5:
                print('DIW difference is too large !!!!')
                return 5

            total_loss = total_loss + torch.sum(torch.log(1 + torch.exp(inner_loss)))

        loss_value = total_loss.data[0]
        if loss_value != loss_value:  # NaN is the only value unequal to itself
            print("SOMETHING WRONG !!!!!!!!!!", loss_value)
            sys.exit()

        return total_loss / n_point_total

    def sky_loss(self, prediction_d, targets, i):
        """Ordinal loss between a sky point and a depth point of sample ``i``.

        Args:
            prediction_d: Single-sample prediction, indexed as ``[y, x]``
                after exponentiation.
            targets: Mapping with ``'sky_x'``, ``'sky_y'``, ``'depth_x'``,
                ``'depth_y'``, each indexed as ``[i, 0]``.
            i: Sample index into the ``targets`` entries.

        Returns:
            The summed ``log(1 + exp(inner))`` term, where ``inner`` is
            ``-(z_sky - z_depth)``, or the same difference taken on logarithms
            when the first value of ``inner`` exceeds the threshold of 4.
        """
        tau = 4
        gt_d = 1
        # The original author's comment labels this quantity "inverse depth".
        inputs = torch.exp(prediction_d)

        z_A_arr = inputs[targets['sky_y'][i, 0], targets['sky_x'][i, 0]]
        z_B_arr = inputs[targets['depth_y'][i, 0], targets['depth_x'][i, 0]]

        inner_loss = -gt_d * (z_A_arr - z_B_arr)

        if self._first_value(inner_loss) > tau:
            print("sky prediction reverse")
            # Switch to the log-domain difference for large violations.
            inner_loss = -gt_d * (torch.log(z_A_arr) - torch.log(z_B_arr))

        return torch.sum(torch.log(1 + torch.exp(inner_loss)))

    def Ordinal_Loss_AUTO(self, prediction_d, targets, i):
        """Ordinal loss for sample ``i`` with ratio-based outlier rejection.

        Args:
            prediction_d: Single-sample prediction, indexed as ``[y, x]``
                after exponentiation.
            targets: Mapping with ``'ordinal'``, ``'x_A'``, ``'y_A'``,
                ``'x_B'``, ``'y_B'``, each indexed as ``[i, 0]``.
            i: Sample index into the ``targets`` entries.

        Returns:
            The constant ``1.3873`` when the first value of ``z_A / z_B``
            exceeds 1.2 (treated as an outlier), otherwise the summed
            ``log(1 + exp(-ordinal * (z_A - z_B)))``.
        """
        tau = 1.2

        inputs = torch.exp(prediction_d)
        gt_d = targets['ordinal'][i, 0]

        z_A_arr = inputs[targets['y_A'][i, 0], targets['x_A'][i, 0]]
        z_B_arr = inputs[targets['y_B'][i, 0], targets['x_B'][i, 0]]

        # Original author's note: A is close, B is further away.
        inner_loss = -gt_d * (z_A_arr - z_B_arr)
        ratio = torch.div(z_A_arr, z_B_arr)

        if self._first_value(ratio) > tau:
            print("DIFFERNCE IS TOO LARGE, REMOVE OUTLIERS!!!!!!")
            return 1.3873

        return torch.sum(torch.log(1 + torch.exp(inner_loss)))

    # ------------------------------------------------------------------
    # Dense (masked) losses
    # ------------------------------------------------------------------

    def GradientLoss(self, log_prediction_d, mask, log_gt):
        """Masked gradient-matching loss on the log-depth residual.

        The residual ``(log_prediction_d - log_gt) * mask`` is differenced
        between positions two apart along each of its two axes; a difference
        counts only when both positions are set in ``mask``.

        Args:
            log_prediction_d: 2-D predicted map.
            mask: 2-D validity mask with the same shape.
            log_gt: 2-D target map with the same shape.

        Returns:
            Sum of absolute differences along both axes divided by
            ``sum(mask)``; zero when the mask is all zeros.
        """
        N = self._valid_count(mask)
        log_d_diff = torch.mul(log_prediction_d - log_gt, mask)

        v_gradient = torch.abs(log_d_diff[0:-2, :] - log_d_diff[2:, :])
        v_mask = torch.mul(mask[0:-2, :], mask[2:, :])
        v_gradient = torch.mul(v_gradient, v_mask)

        h_gradient = torch.abs(log_d_diff[:, 0:-2] - log_d_diff[:, 2:])
        h_mask = torch.mul(mask[:, 0:-2], mask[:, 2:])
        h_gradient = torch.mul(h_gradient, h_mask)

        return (torch.sum(h_gradient) + torch.sum(v_gradient)) / N

    def Data_Loss(self, log_prediction_d, mask, log_gt):
        """Masked variance-style loss on the log-depth residual.

        With ``d = (log_prediction_d - log_gt) * mask`` and ``N = sum(mask)``
        this returns ``sum(d^2) / N - sum(d)^2 / N^2``, which is unchanged
        when a constant is added to the prediction over the masked region.

        Args:
            log_prediction_d: Predicted map.
            mask: Validity mask with the same shape.
            log_gt: Target map with the same shape.

        Returns:
            The scalar loss; zero when the mask is all zeros.
        """
        N = self._valid_count(mask)
        log_d_diff = torch.mul(log_prediction_d - log_gt, mask)
        s1 = torch.sum(torch.pow(log_d_diff, 2)) / N
        s2 = torch.pow(torch.sum(log_d_diff), 2) / (N * N)
        return s1 - s2

    def Data_Loss_test(self, prediction_d, targets):
        """Batch mean of ``Data_Loss`` against ``log(targets['gt'])``.

        Each sample is scored with its own mask slice and its own count of
        valid pixels.

        Args:
            prediction_d: Batched prediction, indexed as ``[i, :, :]``.
            targets: Mapping with ``'mask'`` and ``'gt'``, both indexed as
                ``[i, :, :]``; the logarithm of ``'gt'`` is taken here.

        Returns:
            A one-element tensor: the per-sample data losses averaged over
            ``mask.size(0)``.
        """
        mask = self._on_device_of(targets['mask'], prediction_d)
        d_gt = self._on_device_of(targets['gt'], prediction_d)
        total_loss = self._zero_scalar(prediction_d)

        batch_size = mask.size(0)
        for i in range(batch_size):
            d_log_gt = torch.log(d_gt[i, :, :])
            total_loss = total_loss + self.Data_Loss(
                prediction_d[i, :, :], mask[i, :, :], d_log_gt)

        return total_loss / batch_size

    # ------------------------------------------------------------------
    # Combined loss
    # ------------------------------------------------------------------

    def __call__(self, input_images, prediction_d, targets, is_DIW, current_epoch):
        """Evaluate the combined multi-scale loss for a batch.

        The prediction is subsampled by taking every second row and column
        three times, giving four scales that are paired with
        ``targets['mask_0'..'mask_3']`` and ``log(targets['gt_0'..'gt_3'])``.
        For every sample whose ``targets['has_ordinal'][i, 0]`` is not above
        0.1 the loss adds ``w_data * Data_Loss`` at scale 0 and
        ``w_grad * GradientLoss`` at each of the four scales.  The sum is
        divided by the number of contributing samples (by 1 if none).

        This method is defined as ``__call__`` directly, so it replaces
        ``nn.Module.__call__`` for this class.

        Args:
            input_images: Unused; kept for interface compatibility.
            prediction_d: Batched prediction, indexed as ``[i, :, :]``.
            targets: Mapping with the keys listed above.
            is_DIW: Unused; kept for interface compatibility.
            current_epoch: Unused; kept for interface compatibility.

        Returns:
            ``total_loss.data[0]``, the value of the loss detached from the
            graph.  The differentiable one-element tensor is stored on
            ``self.total_loss`` and available through ``get_loss_var()``.
        """
        num_scales = 4

        # Scale s + 1 keeps every second row/column of scale s.
        predictions = [prediction_d]
        for _ in range(num_scales - 1):
            predictions.append(predictions[-1][:, ::2, ::2])

        masks = []
        log_gts = []
        for scale in range(num_scales):
            masks.append(self._on_device_of(targets['mask_%d' % scale], prediction_d))
            log_gts.append(torch.log(
                self._on_device_of(targets['gt_%d' % scale], prediction_d)))

        total_loss = self._zero_scalar(prediction_d)
        count = 0

        for i in range(masks[0].size(0)):
            # Samples flagged as having ordinal annotations are skipped here.
            if targets['has_ordinal'][i, 0] > 0.1:
                continue

            total_loss = total_loss + self.w_data * self.Data_Loss(
                predictions[0][i, :, :], masks[0][i, :, :], log_gts[0][i, :, :])
            for pred, mask, log_gt in zip(predictions, masks, log_gts):
                total_loss = total_loss + self.w_grad * self.GradientLoss(
                    pred[i, :, :], mask[i, :, :], log_gt[i, :, :])
            count += 1

        # Avoid dividing by zero when every sample was skipped.
        total_loss = total_loss / max(count, 1)

        self.total_loss = total_loss

        return total_loss.data[0]

    def get_loss_var(self):
        """Return the differentiable loss stored by the most recent call (or None)."""
        return self.total_loss

