#include <iostream>
#include "GPUSupport.h"


// Host-side stand-ins for the OpenCL work-item queries.
// global_id holds the coordinates of the work-item currently being emulated;
// it is threadprivate, so every OpenMP thread has its own copy.
int global_id[3];
#pragma omp threadprivate(global_id)

// global_size holds the extent of the emulated index space. It is shared by
// all threads; Render() fills it in before entering its parallel loop.
int global_size[3];

// Map the OpenCL built-in names onto the arrays above.
#define get_global_id(dim) global_id[dim]
#define get_global_size(dim) global_size[dim]

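// Porting note: OpenCL vector literals written with cast syntax, e.g. (int2)(...) or
// (float3)(...), must be rewritten as plain constructor calls: int2(...), float3(...).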



// Upper bound on the fold/scale iterations run by DistanceEstimation().
// A point that has not escaped after this many iterations is reported at
// distance 0.
__constant int maxIterations = 50;

// Eye position: every ray traced by Mandelbox() starts here.
__constant float3 eyePos = float3(0.0f, 0.0f, -15.0f);

// Maximum ray length from the eye: marching stops once the accumulated
// distance travelled exceeds this value.
__constant float maxRayLength = 100;  

// Maximum number of steps to take along the ray (ray-march loop bound).
__constant int maxSteps = 100;

// Light parameters: `ambient` is added to the diffuse term of every hit pixel,
// `lightPos` is the position the diffuse light direction is computed towards.
__constant float3 ambient = float3(0.1f, 0.1f, 0.1f);
__constant float3 lightPos = float3(5.0f, 5.0f, -10.0f);

// Distance estimation constants.
// fixedRadius2 / minRadius2 are compared against the SQUARED length of the
// iterated point in the sphere-fold step of DistanceEstimation().
__constant float fixedRadius2 = 1.0f * 1.0f;
__constant float minRadius2 = 0.5f * 0.5f;
// Colour added each time a single component is box-folded.
__constant float3 boxFoldColor = float3(0.1f, 0.0f, 0.0f);
// Colour added each time an iteration applies the sphere fold.
__constant float3 sphereFoldColor = float3(0.0f, 0.1f, 0.0f);

// Result of a single DistanceEstimation() call.
struct DEInfo {
	// Estimated distance to the surface; DistanceEstimation() sets it to 0 when
	// the iteration limit is reached without the point escaping.
	float distance;
	// Colour accumulated from the box/sphere folds, clamped to [0, 1].
	float3 color;
};


// Evaluates the Mandelbox distance estimator at a point.
//
//   z0    - point to evaluate; also used as the constant added after every
//           scaling step
//   scale - Mandelbox scale factor
//
// Returns a DEInfo whose `distance` is |z| / |factor| from the last iteration
// executed (or 0 if the point never escaped within maxIterations), and whose
// `color` is the sum of the fold colours collected on the way, clamped to [0, 1].
DEInfo DistanceEstimation(float3 z0, float scale) {
	DEInfo result;
	result.distance = 0.0f;
	result.color = 0;

	float3 offset = z0;   // re-added after each scaling step
	float3 p = z0;        // the iterated point
	float deriv = scale;  // running factor the final radius is divided by

	bool escaped = false;
	for (int iter = 0; iter < maxIterations && !escaped; ++iter) {
		// Box fold: reflect any component lying outside [-1, 1] back across
		// the nearer face, tinting the colour once per folded component.
		if (p.x > 1.0f || p.x < -1.0f) {
			p.x = (p.x > 1.0f ? 2.0f : -2.0f) - p.x;
			result.color += boxFoldColor;
		}
		if (p.y > 1.0f || p.y < -1.0f) {
			p.y = (p.y > 1.0f ? 2.0f : -2.0f) - p.y;
			result.color += boxFoldColor;
		}
		if (p.z > 1.0f || p.z < -1.0f) {
			p.z = (p.z > 1.0f ? 2.0f : -2.0f) - p.z;
			result.color += boxFoldColor;
		}

		// Sphere fold: inside the minimum radius scale by a fixed ratio,
		// between the minimum and fixed radius scale by fixedRadius2 / r^2.
		float radius = length(p);
		float radiusSq = radius*radius;
		bool belowMin = radiusSq < minRadius2;
		if (belowMin || radiusSq < fixedRadius2) {
			float divisor = belowMin ? minRadius2 : radiusSq;
			p = (p * fixedRadius2) / divisor;
			deriv = (deriv * fixedRadius2) / divisor;
			result.color += sphereFoldColor;
		}

		// Scale and translate, tracking the accumulated factor.
		p = (p * scale) + offset;
		deriv *= scale;
		radius = length(p);

		result.distance = radius / fabs(deriv);
		// Stop iterating once the point has left the bailout radius.
		escaped = radius > 1024;
	}

	result.color = clamp(result.color, 0, 1);

	// Running out of iterations without escaping is reported as distance 0.
	if (!escaped)
		result.distance = 0;
	return result;
}


// Approximates the unit surface normal at z by central differences of the
// distance estimate.
//
//   z     - point at which the normal is wanted
//   e     - half-width of the finite-difference stencil along each axis
//   scale - Mandelbox scale factor, forwarded to DistanceEstimation()
//
// Returns the normalised gradient of the distance estimate.
float3 Normal(float3 z, float e, float scale) {
	float3 stepX = float3(e, 0.0f, 0.0f);
	float3 stepY = float3(0.0f, e, 0.0f);
	float3 stepZ = float3(0.0f, 0.0f, e);

	// Forward-minus-backward distance along each axis.
	float diffX = DistanceEstimation(z + stepX, scale).distance - DistanceEstimation(z - stepX, scale).distance;
	float diffY = DistanceEstimation(z + stepY, scale).distance - DistanceEstimation(z - stepY, scale).distance;
	float diffZ = DistanceEstimation(z + stepZ, scale).distance - DistanceEstimation(z - stepZ, scale).distance;

	float3 gradient = float3(diffX, diffY, diffZ);
	return normalize(gradient / (2*e));
}


// Emulated kernel: ray-marches the Mandelbox for a single pixel and writes the
// resulting colour into dest_img.
//
// The pixel coordinates are read from get_global_id(0/1) and the image
// dimensions from get_global_size(0/1), so the caller must set those before
// each invocation.
//
//   scale    - Mandelbox scale factor, forwarded to DistanceEstimation()/Normal()
//   dest_img - image that receives one texel at (x, y) via write_imageui
__kernel void Mandelbox(__read_only float scale, __write_only image2d_t dest_img) {
	// Pixel handled by this invocation and the size of the whole image.
	int px = get_global_id(0);
	int py = get_global_id(1);
	int imgW = get_global_size(0);
	int imgH = get_global_size(1);

	// Map the pixel onto the view plane, which spans (-5,-5,-5)..(5,5,-5)
	// in world space.
	float3 planePoint = float3(10*(px/(float)imgW)-5, 10*(py/(float)imgH)-5, -5);

	// The ray leaves the eye along the unit vector through that point.
	float3 marchDir = normalize(planePoint-eyePos);
	float3 pos = eyePos;

	const float EPSILON = 1.0E-6f;
	float hitThreshold = EPSILON;  // grows with the distance travelled
	float travelled = 0;
	bool hit = false;
	DEInfo sample;

	int step = 0;
	while (step < maxSteps) {
		// Advance by the estimated distance to the surface.
		sample = DistanceEstimation(pos, scale);
		pos += marchDir*sample.distance;
		travelled += sample.distance;

		// Give up once the ray is too long; this check takes priority
		// over the hit test below.
		if (travelled > maxRayLength)
			break;

		// Close enough to the surface: count it as a hit.
		if (sample.distance < hitThreshold) {
			hit = true;
			break;
		}

		// Loosen the hit threshold in proportion to the distance travelled.
		hitThreshold = max(EPSILON, (1.0f / 1024) * travelled);
		++step;
	}

	// Pixels that miss keep an all-zero texel (black, alpha 0).
	uint4 texel = 0;
	if (hit) {
		// Surface normal at the intersection point.
		float3 surfaceNormal = Normal(pos, hitThreshold/2, scale);

		// Diffuse shading from a single white light of intensity 1.
		float3 toLight = normalize(lightPos-pos);
		float lambert = dot(surfaceNormal, toLight);

		float3 diffuse = 0;
		if (lambert > 0)
			diffuse = sample.color * lambert;

		float3 shaded = clamp(diffuse + ambient, 0, 1);

		// Channels are stored in b, g, r order, so x (r) and z (b) swap places.
		texel = uint4((uint)(shaded.z * 255), (uint)(shaded.y * 255), (uint)(shaded.x * 255), 255);
	}

	write_imageui(dest_img, int2(px,py), texel);
}



// Renders the Mandelbox into img by invoking the emulated Mandelbox kernel
// once per pixel, with image rows distributed dynamically over OpenMP threads.
//
//   img    - destination image handed to every kernel invocation
//   width  - number of pixels per row    (becomes get_global_size(0))
//   height - number of rows              (becomes get_global_size(1))
//   scale  - Mandelbox scale factor forwarded to the kernel
//
// Every 100th row index is printed to std::cout as a progress indicator.
// Rows are scheduled dynamically, so the numbers may appear out of order.
void Render(image2d_t img, int width, int height, float scale) {
	// The index-space extent is shared and read-only inside the parallel
	// region, so it is written once up front.
	global_size[0] = width;
	global_size[1] = height;
#pragma omp parallel for schedule(dynamic)
	for (int y=0; y<height; y++) {
		// global_id is threadprivate: each thread publishes its own row/column.
		global_id[1] = y;
		if (y%100 == 0) {
			// Serialise the progress line: without this, the number and the
			// newline written by different threads can interleave on std::cout.
#pragma omp critical(render_progress)
			std::cout << y << std::endl;
		}
		for (int x=0; x<width; x++) {
			global_id[0] = x;
			Mandelbox(scale, img);
		}
	}
}
