#include "PNGReader.h"

// Error callback handed to libpng by PNGReader::Open: turns a libpng error
// message into an ImageException. Returns without throwing when no png
// struct is supplied.
void png_cexcept_error(png_structp png_ptr, png_const_charp msg) {
	if (!png_ptr)
		return;

	throw ImageException(msg);
}

void PNGReader::Open(const std::string &fileName) {
	Close();

	_fp = fopen(fileName.c_str(), "rb");
	if (!_fp)
		throw ImageException(std::string("Could not open ") + fileName + std::string(" for reading"));

	char header[8];
	if (fread(header, 1, 8, _fp) != 8)
	  throw ImageException(fileName + std::string(" is not recognized as a PNG file"));
	if (png_sig_cmp((png_bytep)header, 0, 8))
		throw ImageException(fileName + std::string(" is not recognized as a PNG file"));

	_png_ptr = png_create_read_struct(PNG_LIBPNG_VER_STRING, NULL, (png_error_ptr)png_cexcept_error, NULL);

	if (!_png_ptr)
		throw ImageException("png_create_read_struct failed");

	_info_ptr = png_create_info_struct(_png_ptr);
	if (!_info_ptr)
		throw ImageException("png_create_info_struct failed");

	png_init_io(_png_ptr, _fp);
	png_set_sig_bytes(_png_ptr, 8);
	png_read_info(_png_ptr, _info_ptr);
	_width = png_get_image_width(_png_ptr, _info_ptr);
	_height = png_get_image_height(_png_ptr, _info_ptr);

	png_byte bitDepth = png_get_bit_depth(_png_ptr, _info_ptr);
	if ((bitDepth!=8) && (bitDepth!=16))
		throw ImageException("Only 8 and 16 bit images are supported");
	_channelDepth = bitDepth;

	// might also be PNG_COLOR_TYPE_GRAY
	png_byte colorType = png_get_color_type(_png_ptr, _info_ptr);  
	if ((colorType!=PNG_COLOR_TYPE_RGB) && (colorType!=PNG_COLOR_TYPE_RGBA))
		throw ImageException("Only RGB and RGBA color channel images are supported");
	_channelCount = (colorType == PNG_COLOR_TYPE_RGB)?3:4;

	png_read_update_info(_png_ptr, _info_ptr);
}


/**
 * Releases the libpng read/info structures and closes the file handle.
 * Members that are already null are skipped; afterwards _png_ptr, _info_ptr
 * and _fp are all null.
 */
void PNGReader::Close() {
	if (_png_ptr) {
		// Pass the info struct too: with NULL in its place the struct
		// allocated by png_create_info_struct would never be freed.
		png_destroy_read_struct(&_png_ptr, &_info_ptr, NULL);
	}

	if (_fp)
		fclose(_fp);

	_png_ptr = 0;
	_info_ptr = 0;
	_fp = 0;
}


// Reads the next row of the image into `row` via png_read_row.
// `row` is replaced with a fresh row from GetImageRow() when its width,
// channel count or channel depth does not match the reader's values.
void PNGReader::ReadNextRow(ImageRow &row) {
	const bool rowIsCompatible =
		(row.GetWidth() == _width) &&
		(row.GetChannelCount() == _channelCount) &&
		(row.GetChannelDepth() == _channelDepth);

	// keep the caller's buffer whenever it already fits
	if (!rowIsCompatible)
		row = GetImageRow();

	png_read_row(_png_ptr, row.Get(), NULL);
}


