precision highp float;

// Texture coordinate of the current fragment.
varying vec2 vUv;

// Image being adjusted.
uniform sampler2D inputTex;
// Adjustment curve lookup. getShift() samples it at (normalized lightness, 0.0)
// and uses the red channel as the shift; main() treats a green value below
// -10.0 (sampled at vUv) as "curve reset, no adjustment".
uniform sampler2D adjustmentCurveTex;
// uniform sampler2D adjustmentMapTex;

// Only referenced by the disabled adjustment-map code in getShift(); currently unused.
uniform vec2 xScale;

/**
 * Sampled Planckian locus, expressed in the same normalized Lab chroma space
 * as `uv` (see rgb2lab). Only the first `dataSize` entries (at most 302) are read.
 */
uniform int dataSize;
// Locus sample positions.
uniform vec2 plCurveArray[302];
// Per-sample normal directions used to offset points off the locus
// (presumably unit length, since distToLine() and the offsets assume it; confirm on the host side).
uniform vec2 plCurveNormArray[302];

//uniform float maxLineCurveDist;
// A point on the border between the line-based and curve-based regions (copied into `b` in main()).
uniform vec2 bPassed;
// Sign multiplier applied to the signed border distance in blend()
// (presumably +1.0 or -1.0; confirm against the host code).
uniform float lineSide;

// Line data: straight segment from the first to the last locus sample (set in main()).
vec2 lineStart;
vec2 lineEnd;

// Blending data
// Width parameter of the line/curve cross-fade band in blend().
const float width = 0.2;
// Point on the blend border (from bPassed) and the border's unit direction
// (normalize(lineEnd - lineStart), set in main()).
vec2 b;
vec2 boarderDirection;

// Midpoint of the locus segment chosen in calcProjCurve(); written but not read by any function here.
vec2 uvOnCurve;
// Closest point to uv on the lineStart..lineEnd segment (calcLineTemp()).
vec2 uvOnLine;

vec2 uv; // original uv: the fragment's normalized Lab chroma (a, b) from rgb2lab
float illuminance; // normalized lightness (L* / 100) of the fragment
float lineProjLength; // length of projection on line, normalized: 0 at lineStart, 1 at lineEnd
vec2 lineVector; // original uv - nearest uv on line
float curveProjLength; // length of projection on curve: normalized sample index in [0, 1]
float curveDistance; // offset of uv from the locus along the interpolated normals (see calcProjCurve)

// Locus point interpolated in shiftCurve(); written but not read by any function here.
vec2 nearestPointOnCurve;

// CIE Lab epsilon (216 / 24389 ~= 0.008856) separating the cube-root and linear segments.
const float CIE_E = 216.0 / 24389.0;
// D65 reference white in XYZ, scaled so that Y = 1.
const vec3 LAB_REF_WHITE = vec3(95.047, 100, 108.883) / 100.0;

/**
 * Convert a gamma-encoded sRGB colour to CIE XYZ (D65, white has Y = 1).
 *
 * Each channel is linearized with the sRGB transfer curve. Then the standard
 * linear-sRGB -> XYZ matrix is applied. The mat3 constructor fills columns, so
 * `linear * M` dots the input with each row of coefficients as written below.
 *
 * @param c  sRGB colour (callers clamp it to [0, 1]).
 * @return   XYZ tristimulus values.
 */
vec3 rgb2xyz(vec3 c) {
    vec3 tmp;
    // Statements are terminated by ';'. The y line previously ended in ',',
    // which silently merged the y and z assignments into one comma expression.
    tmp.x = ( c.r > 0.04045 ) ? pow( ( c.r + 0.055 ) / 1.055, 2.4 ) : c.r / 12.92;
    tmp.y = ( c.g > 0.04045 ) ? pow( ( c.g + 0.055 ) / 1.055, 2.4 ) : c.g / 12.92;
    tmp.z = ( c.b > 0.04045 ) ? pow( ( c.b + 0.055 ) / 1.055, 2.4 ) : c.b / 12.92;
    return tmp * mat3(
        0.4124, 0.3576, 0.1805,
        0.2126, 0.7152, 0.0722,
        0.0193, 0.1192, 0.9505
    );
}

/**
 * Forward CIE Lab companding f(t): cube root above CIE_E, otherwise the
 * linear segment 7.787 * t + 16 / 116.
 */
float labCompand(float t) {
    if (t > CIE_E) {
        return pow(t, 1.0 / 3.0);
    }
    return 7.787 * t + 16.0 / 116.0;
}

/**
 * Convert CIE XYZ to CIE L*a*b* relative to LAB_REF_WHITE.
 *
 * @param c  XYZ tristimulus values (white has Y = 1).
 * @return   vec3(L*, a*, b*); L* is 100 at the reference white.
 */
vec3 xyz2lab(vec3 c) {
    vec3 ratio = c / LAB_REF_WHITE;
    float fx = labCompand(ratio.x);
    float fy = labCompand(ratio.y);
    float fz = labCompand(ratio.z);

    float lightness = 116.0 * fy - 16.0;
    float greenRed = 500.0 * ( fx - fy );
    float blueYellow = 200.0 * ( fy - fz );
    return vec3(lightness, greenRed, blueYellow);
}

/**
 * Convert sRGB to the normalized Lab encoding used throughout this shader.
 *
 * @return vec3(L* / 100, a', b'), where a' and b' map a* and b* linearly so
 *         that 0 -> 0.5 and +/-127 -> 0.5 +/- 0.5.
 */
vec3 rgb2lab(vec3 c) {
    vec3 xyz = rgb2xyz(c);
    vec3 lab = xyz2lab(xyz);
    vec2 chroma = 0.5 + 0.5 * ( lab.yz / 127.0 );
    return vec3(lab.x / 100.0, chroma);
}

/**
 * Inverse of xyz2lab: convert CIE L*a*b* back to XYZ relative to LAB_REF_WHITE.
 *
 * @param c  vec3(L*, a*, b*) with L* on the 0..100 scale.
 * @return   XYZ tristimulus values (white has Y = 1).
 */
vec3 lab2xyz(vec3 c) {
    float fy = ( c.x + 16.0 ) / 116.0;
    float fx = c.y / 500.0 + fy;
    float fz = fy - c.z / 200.0;

    // Cube by multiplication. GLSL leaves pow(x, y) undefined for x < 0, and
    // fx / fz go negative for large negative a* / large positive b* relative
    // to lightness. The negative cube then correctly selects the linear branch.
    float fx3 = fx * fx * fx;
    float fy3 = fy * fy * fy;
    float fz3 = fz * fz * fz;

    return LAB_REF_WHITE * vec3(
        ( fx3 > CIE_E ) ? fx3 : ( fx - 16.0 / 116.0 ) / 7.787,
        ( fy3 > CIE_E ) ? fy3 : ( fy - 16.0 / 116.0 ) / 7.787,
        ( fz3 > CIE_E ) ? fz3 : ( fz - 16.0 / 116.0 ) / 7.787
    );
}

/**
 * sRGB transfer curve (linear -> gamma-encoded) for one channel:
 * linear segment at or below 0.0031308, power segment above it.
 */
float srgbEncode(float value) {
    if (value > 0.0031308) {
        return ( 1.055 * pow( value, ( 1.0 / 2.4 ) ) ) - 0.055;
    }
    return 12.92 * value;
}

/**
 * Convert CIE XYZ (D65, white has Y = 1) to gamma-encoded sRGB.
 * The result is not clamped; out-of-gamut inputs give values outside [0, 1].
 */
vec3 xyz2rgb(vec3 c) {
    // Written row by row; mat3 fills columns, so `c * M` dots c with each row.
    mat3 xyzToLinearRgb = mat3(
        3.2406, -1.5372, -0.4986,
        -0.9689, 1.8758, 0.0415,
        0.0557, -0.2040, 1.0570
    );
    vec3 linearRgb = c * xyzToLinearRgb;
    return vec3(
        srgbEncode(linearRgb.r),
        srgbEncode(linearRgb.g),
        srgbEncode(linearRgb.b)
    );
}

/**
 * Inverse of rgb2lab: decode the normalized Lab encoding
 * (L* / 100, a*/b* centred on 0.5) and convert to gamma-encoded sRGB.
 */
vec3 lab2rgb(vec3 c) {
    float lightness = 100.0 * c.x;
    vec2 ab = 2.0 * 127.0 * ( c.yz - 0.5 );
    vec3 xyz = lab2xyz(vec3(lightness, ab));
    return xyz2rgb(xyz);
}

/**
 * Project the fragment's chroma `uv` onto the segment lineStart -> lineEnd.
 *
 * Writes globals:
 *   uvOnLine       - closest point on the segment,
 *   lineProjLength - position of that point along the segment, in [0, 1]
 *                    (0 at lineStart, 1 at lineEnd),
 *   lineVector     - offset from uvOnLine back to uv.
 */
void calcLineTemp() {
    vec2 segment = lineEnd - lineStart;
    vec2 toUv = uv - lineStart;
    float segLenSq = dot(segment, segment);

    // Use the clamped projection parameter directly. The previous
    // `proj.x / segment.x` divided by zero for a vertical segment. A zero-length
    // segment is mapped to t = 0 instead of producing NaN.
    float t = ( segLenSq > 0.0 ) ? clamp(dot(toUv, segment) / segLenSq, 0.0, 1.0) : 0.0;

    uvOnLine = lineStart + t * segment;
    lineProjLength = t;
    lineVector = uv - uvOnLine;
}

/**
 * Locate `uv` relative to the locus segment between samples i0 and i1
 * (positions q0, q1; normals n0, n1 read from plCurveNormArray).
 *
 * Solves for the offset d such that uv lies on the line through
 * p0 = q0 + n0*d and p1 = q1 + n1*d. The collinearity condition
 * cross(uv - p0, p1 - p0) = 0 expands to a*d^2 + b*d + c = 0 with the
 * coefficients below.
 *
 * Writes globals:
 *   uvOnCurve       - midpoint of q0 and q1,
 *   curveDistance   - the offset d clamped to [-1, 1]
 *                     (or the plain distance |uv - q0| when i0 == i1),
 *   curveProjLength - (i0 + t) / (dataSize - 1), where t in [0, 1] is the
 *                     relative distance of uv from p0 toward p1.
 */
void calcProjCurve(int i0, int i1, vec2 q0, vec2 q1) {
    uvOnCurve = (q0 + q1)/2.0;
    vec2 p = uv;
    vec2 n0 = plCurveNormArray[i0];
    vec2 n1 = plCurveNormArray[i1];

    // Degenerate pair, e.g. when the caller clamped a neighbour index at an array end.
    // NOTE(review): this stores an unsigned distance, while the general path
    // below stores a signed offset along the normals; confirm this is intended.
    if (i0 == i1) {
        curveProjLength = float(i0)/float(dataSize-1);
        curveDistance = distance(p, q0);
        return;
    }

    // Quadratic coefficients (the local `b` shadows the global blend point `b`).
    float a = n0.y * n1.x - n0.x * n1.y;
    float b = n0.y * (q1.x - p.x) + n1.x * (q0.y - p.y) - n0.x * (q1.y - p.y) - n1.y * (q0.x - p.x);
    float c = q0.y * (q1.x - p.x) - p.y * q1.x - q0.x * (q1.y - p.y) + p.x * q1.y;

    // NOTE(review): there is no guard for delta < 0 (sqrt -> NaN) or a == 0
    // (parallel normals -> division by zero). The "-sqrt" root is always taken
    // regardless of the sign of b; confirm it is the root nearest the locus.
    float delta = b * b - 4.0 * a * c;
    float dist = (-b - sqrt(delta)) / 2.0 / a;

    curveDistance = clamp(dist, -1.0, 1.0);
    vec2 p0 = q0 + n0 * curveDistance;
    vec2 p1 = q1 + n1 * curveDistance;
    // Relative position of p between the two offset points, clamped to [0, 1].
    float t = clamp(distance(p, p0)/distance(p1, p0), 0.0, 1.0);

    curveProjLength = (float(i0)+t)/float(dataSize-1);
}

/**
 * Distance from point p to the line through s with direction v.
 * The result is scaled by |v|, so it is a true distance only for unit-length v.
 */
float distToLine(vec2 p, vec2 s, vec2 v) {
    vec2 rel = p - s;
    // 2D cross product of rel and v == dot(rel, perpendicular(v)).
    return abs(rel.x * v.y - rel.y * v.x);
}

/**
 * Find where `uv` sits relative to the sampled locus, then delegate to
 * calcProjCurve() to compute curveDistance / curveProjLength.
 *
 * Picks the sample whose normal line (the line through the sample along its
 * normal) is closest to uv. It is paired with whichever neighbour's normal
 * line is closer.
 */
void calcCurveTemp() {
    // Seed the search with sample 0 so minIndex is always a valid index. The
    // previous seed (distance 1.0) left minIndex uninitialized when every
    // distance was >= 1.0. Whenever that seed found something, the result is unchanged.
    int minIndex = 0;
    float minDistance = distToLine(uv, plCurveArray[0], plCurveNormArray[0]);
    for (int i = 1; i < dataSize; i++) {
        float dist = distToLine(uv, plCurveArray[i], plCurveNormArray[i]);
        if (dist < minDistance) {
            minDistance = dist;
            minIndex = i;
        }
    }

    int leftIdx = max(0, minIndex - 1);
    int rightIdx = min(dataSize - 1, minIndex + 1);

    vec2 nearestUv = plCurveArray[minIndex];
    vec2 leftUv = plCurveArray[leftIdx];
    vec2 rightUv = plCurveArray[rightIdx];

    vec2 leftNorm = plCurveNormArray[leftIdx];
    vec2 rightNorm = plCurveNormArray[rightIdx];

    if (distToLine(uv, leftUv, leftNorm) < distToLine(uv, rightUv, rightNorm)) {
        calcProjCurve(leftIdx, minIndex, leftUv, nearestUv);
    } else {
        calcProjCurve(minIndex, rightIdx, nearestUv, rightUv);
    }
}

/**
 * Linearly map x from [lo, hi] to [0, 1], clamped.
 *
 * Parameters are named lo/hi rather than min/max so they do not hide the
 * built-in min()/max() functions. A degenerate range (hi == lo) returns
 * step(lo, x) instead of dividing by zero.
 */
float rescale(float x, float lo, float hi) {
    float range = hi - lo;
    if (range == 0.0) {
        return step(lo, x);
    }
    return clamp((x - lo) / range, 0.0, 1.0);
}

/** Mean of the red, green and blue channels; alpha is ignored. */
float rgbAverage(vec4 color) {
    float channelSum = color.r + color.g + color.b;
    return channelSum / 3.0;
}

/**
 * Amount to shift the chroma along the line/curve for this fragment.
 *
 * Samples adjustmentCurveTex at (illuminance, 0.0) and returns the red channel.
 * A disabled variant instead drove the lookup with the rescaled RGB average of
 * adjustmentMapTex at vUv:
 *   // vec4 adjustmentPixel = texture2D(adjustmentMapTex, vUv);
 *   // float adjustmentAverage = rescale(rgbAverage(adjustmentPixel), xScale.x, xScale.y);
 */
float getShift() {
    vec2 lookup = vec2(illuminance, 0.0);
    vec4 curveSample = texture2D(adjustmentCurveTex, lookup);
    return curveSample.x;
}

/**
 * Slide the line projection by `shift` (in normalized segment units, clamped
 * to the segment) and re-apply the original perpendicular offset lineVector.
 */
vec2 shiftLine(float shift) {
    float pos = clamp(lineProjLength + shift, 0.0, 1.0);
    vec2 onLine = lineStart * (1.0 - pos) + lineEnd * pos;
    return onLine + lineVector;
}

/**
 * Move along the sampled locus by `shift` (in normalized sample-index units)
 * from curveProjLength. The offset curveDistance along the linearly
 * interpolated normals is kept.
 *
 * Writes nearestPointOnCurve (the interpolated locus point) and returns the
 * shifted chroma.
 */
vec2 shiftCurve(float shift) {
    // Clamp like shiftLine() so the sample indices below stay in [0, dataSize - 1].
    float t = clamp(curveProjLength + shift, 0.0, 1.0);
    float tScaled = t * float(dataSize - 1);

    // Bracket indices come from the scaled position. Using the normalized t
    // always gave l = 0, r = 1 and extrapolated along the first segment.
    int l = int(floor(tScaled));
    int r = min(dataSize - 1, int(ceil(tScaled)));
    float alpha = tScaled - float(l);

    vec2 q0 = plCurveArray[l];
    vec2 q1 = plCurveArray[r];
    nearestPointOnCurve = q0 * (1.0 - alpha) + q1 * alpha;

    vec2 p0 = q0 + curveDistance * plCurveNormArray[l];
    vec2 p1 = q1 + curveDistance * plCurveNormArray[r];
    return p0 * (1.0 - alpha) + p1 * alpha;
}


/**
 * Cross-fade between the line-based and curve-based shifted chroma.
 *
 * The border is the line through `b` with direction `boarderDirection`. The
 * signed distance of `uv` from it has its sign from the 2D cross product and
 * is flipped by `lineSide`. It maps to a line weight that is 1 at or beyond
 * the border on the non-negative side and falls linearly to 0 at a signed
 * distance of -2 * width.
 */
vec2 blend(vec2 uvLine, vec2 uvCurve) {
    // Foot of the perpendicular from uv onto the border line.
    vec2 foot = dot(boarderDirection, (uv-b))*boarderDirection + b;
    vec2 offset = uv - foot;
    // z component of cross(vec3(offset, 0), vec3(boarderDirection, 0)).
    float crossZ = offset.x * boarderDirection.y - offset.y * boarderDirection.x;
    float signedDist = sign(crossZ) * lineSide * distance(uv, foot);

    float weight = clamp(signedDist/width/2.0 + 1.0, 0.0, 1.0);
    return weight*uvLine + (1.0-weight)*uvCurve;
}

/**
 * Two-argument arctangent: angle of the vector (x, y), in [-pi, pi].
 *
 * Evaluates atan(y, x) directly when |x| > |y|. Otherwise it evaluates the
 * reflected form pi/2 - atan(x, y), which keeps the quotient handed to the
 * built-in small when x is near 0. As with the built-in, the result is
 * undefined for x == y == 0.
 */
float atan2(float y, float x) {
    const float kPi = 3.14159265358979;
    if (abs(x) > abs(y)) {
        return atan(y, x);
    }
    // pi/2 - atan(x, y) equals atan(y, x) up to a 2*pi wrap, which only occurs
    // in the third quadrant (result in (pi, 3*pi/2]). Wrap it back into range.
    // Uses the full-precision pi; the previous 3.14 was off by ~1.6e-3.
    float angle = 0.5 * kPi - atan(x, y);
    return ( angle > kPi ) ? angle - 2.0 * kPi : angle;
}

/**
 * Convert a texture coordinate to polar form about the centre (0.5, 0.5).
 *
 * @param uv  texture coordinate (the parameter shadows the global `uv`).
 * @return    vec2(radius, angle). Radius is 1 at the edge midpoints of the
 *            unit square; angle is atan(y, x) in [-pi, pi] and undefined at
 *            the exact centre.
 */
vec2 tex2polar(vec2 uv) {
    vec2 centered = (uv - 0.5) * 2.0;
    // length() instead of sqrt(pow(x, 2.0) + pow(y, 2.0)): GLSL pow() is
    // undefined for negative bases, and centered components are negative
    // on the left/bottom half.
    return vec2(length(centered), atan(centered.y, centered.x));
}

/**
 * One gamut-projection step: convert to the normalized Lab encoding, force the
 * lightness to `illuminance` (keeping the chroma), convert back, and clamp to
 * the RGB unit cube. The parameter shadows the global `illuminance`.
 */
vec3 projectToGamut(vec3 rgb, float illuminance){
    vec3 lab = rgb2lab(rgb);
    vec3 relit = vec3(illuminance, lab.yz);
    vec3 converted = lab2rgb(relit);
    return clamp(converted, vec3(0.0), vec3(1.0));
}


// Entry point: shift each pixel's Lab chroma along a blend of the straight
// line and the sampled Planckian locus, by an amount read from the adjustment
// curve. Lightness is kept fixed and alpha is passed through.
void main() {
    const int GAMUT_PASSES = 25;

    // Clamp to avoid artifacts in the uv calculation.
    vec4 texel = clamp(texture2D(inputTex, vUv), vec4(0.0), vec4(1.0));

    // The spline is "reset" (no adjustment needed) when its green channel is below -10.
    bool curveIsReset = texture2D(adjustmentCurveTex, vUv).y < -10.0;
    if (curveIsReset) {
        gl_FragColor = texel;
        return;
    }

    // Per-fragment globals read by the helper functions.
    lineStart = plCurveArray[0];
    lineEnd = plCurveArray[dataSize-1];
    b = bPassed;
    boarderDirection = normalize(lineEnd - lineStart);

    vec3 srcLab = rgb2lab(texel.rgb);
    illuminance = srcLab.x;
    uv = srcLab.yz;

    calcLineTemp();
    calcCurveTemp();
    float amount = getShift();
    vec2 lineTarget = shiftLine(amount);
    vec2 curveTarget = shiftCurve(amount);
    vec2 blended = blend(lineTarget, curveTarget);

    vec3 outRgb = clamp(lab2rgb(vec3(illuminance, blended)), vec3(0.0), vec3(1.0));
    // Repeatedly restore the original lightness and re-clamp into the RGB cube.
    for (int iter = 0; iter < GAMUT_PASSES; iter++) {
        outRgb = projectToGamut(outRgb, illuminance);
    }

    gl_FragColor = vec4(outRgb, texel.a);
}
