#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <libshred.h>


// Verbose per-transfer logging; set by the "-d/--debug" command line flag.
static int DoDebug = 0;


// Protocol tunables.
enum
{
	TFTP_PORT         = 69,		// default listening port
	TFTP_TIMEOUT      = 800,	// retransmission timeout, in ms.
	TFTP_RESEND_LIMIT = 8,		// retransmissions tolerated before a session is dropped
	TFTP_BLOCK_SIZE   = 512		// payload bytes per DATA packet
};


// Packet opcodes, carried big-endian in the first two bytes of every packet.
typedef enum
{
	TFTP_PKT_TYPE_RRQ   = 1,	// read request
	TFTP_PKT_TYPE_WRQ   = 2,	// write request
	TFTP_PKT_TYPE_DATA  = 3,	// data block
	TFTP_PKT_TYPE_ACK   = 4,	// acknowledgment of a data block
	TFTP_PKT_TYPE_ERROR = 5		// error report
}TTFTPPacketTypes;


// Error codes carried in ERROR packets.
typedef enum
{
	TFTP_ERR_TYPE_UNKNOWN  = 0,	// not defined
	TFTP_ERR_TYPE_NOTFOUND = 1,	// file not found
	TFTP_ERR_TYPE_PERM     = 2,	// access violation
	TFTP_ERR_TYPE_DISKFULL = 3,	// disk full or allocation exceeded
	TFTP_ERR_TYPE_ILLEGAL  = 4,	// illegal TFTP operation
	TFTP_ERR_TYPE_BADTID   = 5,	// unknown transfer ID
	TFTP_ERR_TYPE_EXIST    = 6,	// file already exists
	TFTP_ERR_TYPE_USER     = 7	// no such user
}TTFTPErrorTypes;


// State of one outgoing (read request) transfer.
typedef struct TTFTPSession_s
{
	int		Socket;			// per-session UDP socket used to talk to the client
	uint32	dwIP;			// client address, as reported by Net_RecvFrom()
	uint16	wPort;			// client port, as reported by Net_RecvFrom()
	char	Host[32];		// printable client address, for logging
	char	Filename[256];	// requested file name (possibly truncated copy)
	FILE *	File;			// requested file, opened "rb"
	int		szFile;			// file size in bytes when the request was accepted
	int		NextTickout;	// tick used as the retransmission deadline

	int		nBlockAbs;		// 0-based index of the block in flight (offset = nBlockAbs * TFTP_BLOCK_SIZE)
	int		nBlock;			// 16-bit on-wire number of the block in flight
	int		nResend;		// retransmissions of the block in flight so far

	int		nSentBytes;		// bytes handed to the network (headers included), for statistics
	int		TickStart;		// tick at which the request was accepted
}TTFTPSession;


// Active sessions: a heap array of szSessions entries, grown on each accepted
// request and compacted by TFTP_SessionClose(). File-scope for simplicity.
static TTFTPSession *Sessions = NULL;
static int szSessions = 0;


// follow TFTP byte order conventions
// Writes u16 into Buf most significant byte first (network byte order, as TFTP requires).
static void TFTP_Uint16ToBuf(uint16 u16, uint8 Buf[2])
{
	uint8 Hi = (uint8)(u16 >> 8);
	uint8 Lo = (uint8)(u16 & 0xff);

	Buf[0] = Hi;
	Buf[1] = Lo;
}


/*
 * Reads a 16-bit value stored most significant byte first in Buf.
 * The value is returned and, when pu16 is not NULL, also stored in *pu16.
 */
static uint16 TFTP_BufToUint16(uint8 Buf[2], uint16 *pu16)
{
	uint16 Value = (uint16)(((unsigned)Buf[0] << 8) | (unsigned)Buf[1]);

	if(pu16 != NULL)
		*pu16 = Value;

	return Value;
}


/*
 * Sends an ERROR packet to dwIP:wPort through Socket.
 * Layout: opcode (2 bytes) | ErrorID (2 bytes) | NUL-terminated message.
 * The message is the standard text for ErrorID (a generic text for codes
 * outside the known range). A short or failed send is only logged.
 */
static void TFTP_SendError(int Socket, uint32 dwIP, uint16 wPort, uint16 ErrorID)
{
	// indexed by TTFTPErrorTypes
	static const char *const Messages[] = {
		"Undefined error.",
		"File not found.",
		"Access violation.",
		"Disk full or allocation exceeded.",
		"Illegal TFTP operation.",
		"Unknown transfer ID.",
		"File already exists.",
		"No such user."
	};
	uint8 Buffer[4 + 128];
	const char *Msg;
	size_t lMsg;
	int szWrite, szToWrite;

	Msg = (ErrorID < sizeof(Messages) / sizeof(Messages[0])) ? Messages[ErrorID] : Messages[0];

	// bounded copy: never write past Buffer, always keep the terminating NUL
	lMsg = strlen(Msg);
	if(lMsg > sizeof(Buffer) - 4 - 1)
		lMsg = sizeof(Buffer) - 4 - 1;

	TFTP_Uint16ToBuf(TFTP_PKT_TYPE_ERROR, &Buffer[0]);
	TFTP_Uint16ToBuf(ErrorID, &Buffer[2]);
	memcpy(&Buffer[4], Msg, lMsg);
	Buffer[4 + lMsg] = 0;

	szToWrite = (int)(4 + lMsg + 1);
	szWrite = Net_SendTo(Socket, Buffer, szToWrite, dwIP, wPort);
	if(szWrite != szToWrite)
		LOG_DEBUG("warning: sendto truncated: %d / %d", szWrite, szToWrite);
}


/*
 * Fills in the 4-byte DATA header (opcode, block number pSession->nBlock)
 * in front of the szData payload bytes the caller already stored at
 * Buffer[4], and sends the packet to the session peer.
 * Buffer must hold at least 4 + szData bytes.
 */
static void _TFTP_SendBlock(TTFTPSession *pSession, uint8 *Buffer, int szData)
{
	int szWrite;

	// this is called for every block, not only the last one
	if(DoDebug)
		LOG_DEBUG("sending block %d (%d bytes)", pSession->nBlockAbs, szData);

	TFTP_Uint16ToBuf(TFTP_PKT_TYPE_DATA, &Buffer[0]);
	TFTP_Uint16ToBuf((uint16)pSession->nBlock, &Buffer[2]);

	szWrite = Net_SendTo(pSession->Socket, Buffer, 4 + szData, pSession->dwIP, pSession->wPort);
	if(szWrite != 4 + szData)
		LOG_DEBUG("warning: sendto truncated: %d / %d", szWrite, 4 + szData);

	// a failed send reports a non-positive count: keep it out of the statistics
	if(szWrite > 0)
		pSession->nSentBytes += szWrite;
}


/*
 * Sends the DATA packet for the block at index pSession->nBlockAbs, or
 * reports that the transfer is over.
 *
 * Block k covers file bytes [k * TFTP_BLOCK_SIZE, (k+1) * TFTP_BLOCK_SIZE).
 * The first block shorter than TFTP_BLOCK_SIZE ends the transfer. That block
 * is empty when the file size is a multiple of the block size, including an
 * empty file. Once the final block has been acknowledged (nBlockAbs moved
 * past it), nothing more is sent.
 *
 * The payload is re-read from the file on every call. Calling again with an
 * unchanged nBlockAbs therefore retransmits exactly the same block, including
 * a short or empty final block.
 *
 * Returns  1 if a block was sent (keep the session and wait for its ACK),
 *          0 if the transfer is complete,
 *         -1 on error.
 */
static int TFTP_SendBlock(TTFTPSession *pSession)
{
	uint8 Buffer[4 + TFTP_BLOCK_SIZE];
	long Offset;
	int szToRead, szRead = 0;

	if(pSession->szFile < 0)
	{
		LOG_ERROR("error: unknown size for file %s", pSession->Filename);
		return -1;
	}

	Offset = (long)pSession->nBlockAbs * TFTP_BLOCK_SIZE;

	// past the end: the final (short or empty) block has been acknowledged
	if(Offset > pSession->szFile)
	{
		int Tick;

		Tick = Time_GetElapsedTick(pSession->TickStart);
		if(Tick <= 0)
			Tick = 1;
		if(DoDebug)
			LOG_DEBUG("file sent to %s: %d bytes, %d ms., %d KB/s", pSession->Host, pSession->nSentBytes, Tick, pSession->nSentBytes / Tick);
		return 0;
	}

	szToRead = pSession->szFile - (int)Offset;
	if(szToRead > TFTP_BLOCK_SIZE)
		szToRead = TFTP_BLOCK_SIZE;

	// szToRead == 0 means the terminating empty block: nothing to read.
	// Deliberately no feof() test: a previous short read leaves EOF set, and
	// retransmitting that block must resend its data, not an empty packet.
	if(szToRead > 0)
	{
		if(fseek(pSession->File, Offset, SEEK_SET) != 0)
		{
			LOG_ERROR("error: fseek(): %d", System_GetLastErrorID());
			return -1;
		}

		szRead = (int)fread(&Buffer[4], 1, (size_t)szToRead, pSession->File);
		if(szRead <= 0)
		{
			LOG_ERROR("error: fread(): %d", System_GetLastErrorID());
			return -1;
		}
	}

	_TFTP_SendBlock(pSession, Buffer, szRead);

	return 1;
}


/*
 * Ends a transfer. Closes the session socket and file, then removes the entry
 * from the Sessions array by shifting the later entries down one slot.
 *
 * pSession must point into Sessions. After the call it (and any pointer to a
 * later entry) is invalid, and the next session, if any, occupies the same
 * index. A NULL pSession is ignored.
 */
static void TFTP_SessionClose(TTFTPSession *pSession)
{
	int Index, nAfter;

	if(!pSession)
		return;

	// close handles
	Net_CloseSocket(&pSession->Socket);
	if(pSession->File)
		fclose(pSession->File);

	// shift down the entries behind the closed one (count in elements, not bytes)
	Index = (int)(pSession - Sessions);
	nAfter = szSessions - Index - 1;
	if(nAfter > 0)
		memmove(&Sessions[Index], &Sessions[Index + 1], (size_t)nAfter * sizeof(TTFTPSession));
	szSessions--;

	if(szSessions == 0)
	{
		// realloc(p, 0) is implementation-defined: release explicitly
		free(Sessions);
		Sessions = NULL;
	}
	else
	{
		// shrinking: if it fails the old, larger block is still valid, so keep it
		TTFTPSession *Shrunk = realloc(Sessions, (size_t)szSessions * sizeof(TTFTPSession));
		if(Shrunk)
			Sessions = Shrunk;
	}
}


// <signal.h> is ISO C. sig_atomic_t is needed on every platform; signal()
// itself is only used on non-Windows builds.
#include <signal.h>

// Non-zero when running in debug (console) mode; set by Service_Start().
static int _DoDebug = 0;
// Raised by the exit-signal handler and polled by the main loop. It is written
// asynchronously, so it must be volatile sig_atomic_t: the only object type C
// allows a signal handler to assign.
static volatile sig_atomic_t _DoExit = 0;

#ifndef WIN32
/*
 * Handler for the exit signals (INT, TERM, QUIT): only raises the flag that
 * Service_ExitRequested() polls from the main loop.
 *
 * No logging here on purpose. The logger writes through stdio (stdout), which
 * is not async-signal-safe. Logging from a handler can deadlock or corrupt
 * output if the signal interrupts another log call.
 */
static void _Service_HandleSignalExit(int SignalNo)
{
	(void)SignalNo;
	_DoExit = 1;
}
#endif

/*
 * Sets up the default logger and, outside debug mode on non-Windows systems,
 * daemonizes the process:
 *  - debug mode (always the case on Windows): log to stdout with file/line info;
 *  - service mode with a log file other than "/dev/null": daemonize and log
 *    to LogFile with rotation (16 KB, 2 files);
 *  - service mode without a usable log file: daemonize with logging disabled.
 * Also records DoDebug for Service_PrepareExitCondition()/Service_ExitRequested().
 *
 * NOTE(review): Log lives on this function's stack and is passed by address to
 * Log_SetDefault(). That is only safe if Log_SetDefault() copies it. Confirm in libshred.
 */
static void Service_Start(int DoDebug, char *LogFile)
{
	TLog Log;

	_DoDebug = DoDebug;

#ifndef WIN32
	if(!DoDebug)
	{
		int UseLogFile = (LogFile != NULL) && (strcmp(LogFile, "/dev/null") != 0);

		Log_Reset(&Log);

		if(!UseLogFile)
		{
			// disable log
			Log_SetFile(&Log, stdout);
			Log.Disabled = 1;
			Log_SetDefault(&Log);

			System_Daemonize();
			return;
		}

		// log with rotation
		Log_SetMode(&Log, LOG_AUTO_FLUSH | LOG_SHOW_DATE);
		Log_SetFile(&Log, stdout);
		Log_SetPadding(&Log, 25);
		Log.szOutMax = 16*1024;
		Log.nRotateFile = 2;
		Log.Filename = LogFile;
		Log_SetDefault(&Log);

		System_Daemonize();
		System_DupOutputToFile(LogFile, 1, 0);
		return;
	}
#endif

	// log to stdout
	Log_Reset(&Log);
	Log_SetMode(&Log, LOG_AUTO_FLUSH | LOG_SHOW_DATE | LOG_SHOW_FILE | LOG_SHOW_LINE);
	Log_SetFile(&Log, stdout);
	Log_SetPadding(&Log, 50);
	Log_SetDefault(&Log);
}


static void Service_PrepareExitCondition()
{
	if(_DoDebug)
	{
		Console_CatchCTRLCEvent();
		LOG_DEBUG("press ctrl+c to exit program");
	}
	else
	{
#ifndef WIN32
		signal(SIGINT, _Service_HandleSignalExit);
		signal(SIGTERM, _Service_HandleSignalExit);
		signal(SIGQUIT, _Service_HandleSignalExit);
		LOG_DEBUG("registered signals INT, TERM and QUIT as exit signals");
#endif
	}
}


/*
 * Polls for a pending shutdown request.
 * Returns 1 once exit was requested (console CTRL+C in debug mode, an exit
 * signal otherwise), and 0 while the service should keep running.
 */
static int Service_ExitRequested(void)
{
	int Requested;

	if(!_DoDebug)
		Requested = (_DoExit != 0);
	else
		Requested = (Console_GetCTRLCEvent() != 0);

	return Requested;
}


int main(int nArgs, char **Args)
{
	char *RootPath = NULL;
	int Port = TFTP_PORT;
	char *LogFile = NULL;
	TCliParserOption Options[] = {
		{"-r", "--root", "export files only from the specified subtree", "/path/to/tftproot", &RootPath, ARG_TYPE_STRING, 1},
		{"-p", "--port", "port to listen to", "69", &Port, ARG_TYPE_INT, 0},
		{"-l", "--log", "log file", "tftpd.log", &LogFile, ARG_TYPE_STRING, 0},
		{"-d", "--debug", "debug mode", "", &DoDebug, ARG_TYPE_FLAG, 0},
		{NULL,	NULL,	NULL, NULL, 0, 0}
	};
	int Socket = -1;
	uint8 Buffer[1024];
	
	
	if(!CliParser_DoOptions(Options, Args + 1, nArgs - 1))
		return 0;
	
#ifndef WIN32
	if(!System_IsRoot())
	{
		LOG_ERROR("error: must be root");
		return 0;
	}
#endif
	
	Service_Start(DoDebug, LogFile);
	
	Net_Init();
	
	LOG_DEBUG("root path: %s", RootPath);
	LOG_DEBUG("port:      %d", Port);
	
	//
	// prepare main socket
	//
	Socket = Net_OpenUDPSocket();
	if(Socket < 0)
	{
		LOG_ERROR("error: Net_OpenUDPSocket()");
		goto _end_;
	}
	
	if(!Net_BindSocketEx(Socket, INADDR_ANY, Net_ToNet16(Port), 0))
	{
		LOG_ERROR("error: Net_BindSocketEx()");
		goto _end_;
	}
	
	// always set non nBlock for UDP sockets
	Net_Option_SetBlock(Socket, 0);
	
	
	//
	// chdir
	//
	if(!FileSystem_FolderExists(RootPath))
	{
		LOG_ERROR("error: root path not found: %s", RootPath);
		goto _end_;
	}
	
#ifdef WIN32
// ugly fix
//#include <io.h>
_CRTIMP int __cdecl _chdir(const char*);
#define chdir(p)	_chdir(p)
#endif
	if(chdir(RootPath) != 0)
	{
		LOG_ERROR("error: chdir(): %d", System_GetLastErrorID());
		goto _end_;
	}
	
	Service_PrepareExitCondition();
	
	//
	// main loop
	//
	while(1)
	{
		TNetSelect NetSelect;
		int i, Result, Now;
		
		
		if(Service_ExitRequested())
			break;
		
		Net_Select_Reset(&NetSelect);
		
		// add main socket
		Net_Select_Add(&NetSelect, Socket, NET_OP_READ);
		// add session sockets
		for(i=0;i<szSessions;i++)
		{
			TTFTPSession *pSession = &Sessions[i];
			Net_Select_Add(&NetSelect, pSession->Socket, NET_OP_READ);
		}
		
		Result = Net_Select(&NetSelect, (szSessions > 0) ? 250 : 2000);	// speed up when we have a session
		if(Result < 0)
		{
			LOG_ERROR("error: Net_Select()");
			break;
		}
		
		if(Result == 0)
		{
			// handle timeout
			continue;
		}
		
		// main socket: connect-like
		Now = Time_GetTick();
		if(Net_Select_Check(&NetSelect, Socket))
		{
			char Host[32];
			uint32 dwIP;
			uint16 wPort;
			int lBuffer, PktType;
			char *Filename, *Mode;
			int lFilename;
			FILE *File;
			TTFTPSession *pSession;
			int szFile;
			
			
			lBuffer = Net_RecvFrom(Socket, Buffer, sizeof(Buffer), &dwIP, &wPort);
			if(lBuffer <= 0)
			{
				LOG_ERROR("error: Net_RecvFrom()");
				continue;
			}
			
			Net_AddrToString(dwIP, wPort, Host, sizeof(Host));
			
			if(lBuffer < 2)
			{
				LOG_ERROR("error: missing packet header");
				continue;
			}
			
			PktType = TFTP_BufToUint16(Buffer, NULL);
			//LOG_DEBUG("packet type: %d", PktType);
			
			if(PktType != TFTP_PKT_TYPE_RRQ)
			{
				LOG_DEBUG("warning: ignore invalid packet type: %d / %d", PktType, TFTP_PKT_TYPE_RRQ);
				continue;
			}
			
			Filename = (char*)&Buffer[2];
			lFilename = strlen(Filename);
			Mode = (char*)&Buffer[2 + lFilename + 1];
			
			if(DoDebug)
				LOG_DEBUG("new read request from %s on file %s, mode %s", Host, Filename, Mode);
			
			// security
			if(strstr(Filename, "/../"))
			{
				LOG_DEBUG("warning: trying to access file out of chroot: %s", Filename);
				TFTP_SendError(Socket, dwIP, wPort, TFTP_ERR_TYPE_NOTFOUND);
				continue;
			}
			
			if(!FileSystem_FileExists(Filename))
			{
				LOG_DEBUG("warning: file not found: %s", Filename);
				TFTP_SendError(Socket, dwIP, wPort, TFTP_ERR_TYPE_NOTFOUND);
				continue;
			}
			szFile = FileSystem_FileSize(Filename);
			
			File = fopen(Filename, "rb");
			if(!File)
			{
				LOG_DEBUG("error: opening file: %s", Filename);
				TFTP_SendError(Socket, dwIP, wPort, TFTP_ERR_TYPE_NOTFOUND);
				continue;
			}
			
			Sessions = realloc(Sessions, (szSessions + 1) * sizeof(TTFTPSession));
			pSession = &Sessions[szSessions++];
			
			memset(pSession, 0, sizeof(TTFTPSession));
			
			// fixme: check socket
			pSession->Socket = Net_OpenUDPSocket();
			
			// always set non nBlock for UDP sockets
			Net_Option_SetBlock(pSession->Socket, 0);
			pSession->dwIP = dwIP;
			pSession->wPort = wPort;
			snprintf(pSession->Host, sizeof(pSession->Host) - 1, "%s", Host);
			snprintf(pSession->Filename, sizeof(pSession->Filename) - 1, "%s", Filename);
			pSession->File = File;
			pSession->NextTickout = Now + TFTP_TIMEOUT;
			pSession->nBlockAbs = 0;
			pSession->nBlock = 1;
			pSession->nResend = 0;
			pSession->szFile = szFile;
			
			pSession->TickStart = Now;
			
			if(TFTP_SendBlock(pSession) <= 0)
				TFTP_SessionClose(pSession);
			
			continue;
		}
		
		// add session sockets
		Now = Time_GetTick();
		for(i=0;i<szSessions;i++)
		{
			TTFTPSession *pSession = &Sessions[i];
			
			
			if(Net_Select_Check(&NetSelect, pSession->Socket))
			{
				// we only accept ACKS
				uint32 dwIP;
				uint16 wPort;
				int lBuffer, PktType, BlockID;
				
				
				lBuffer = Net_RecvFrom(pSession->Socket, Buffer, sizeof(Buffer), &dwIP, &wPort);
				if(lBuffer <= 0)
				{
					LOG_ERROR("error: Net_RecvFrom()");
					continue;
				}
				
				if(pSession->dwIP != dwIP || pSession->wPort != wPort)
				{
					LOG_ERROR("error: host mismatch");
					continue;
				}
				
				if(lBuffer < 2)
				{
					LOG_ERROR("error: missing packet header");
					continue;
				}
				
				PktType = TFTP_BufToUint16(Buffer, NULL);
				//LOG_DEBUG("packet type: %d", PktType);
				
				if(PktType != TFTP_PKT_TYPE_ACK)
				{
					LOG_DEBUG("warning: ignore invalid packet type: %d / %d", PktType, TFTP_PKT_TYPE_ACK);
					continue;
				}
				
				BlockID = TFTP_BufToUint16(Buffer + 2, NULL);
				//LOG_DEBUG("requested nBlock id: %d", BlockID);
				
				pSession->nBlock = BlockID + 1;
				if(pSession->nBlock > 65535)
					pSession->nBlock = 0;
				
				pSession->nBlockAbs++;
				pSession->NextTickout = Now + TFTP_TIMEOUT;
				pSession->nResend = 0;
				
				if(TFTP_SendBlock(pSession) <= 0)
				{
					TFTP_SessionClose(pSession);
					i--;
				}
			}
			else
			{
				if(Time_Diff32(Now, pSession->NextTickout) > TFTP_TIMEOUT)
				{
					if(pSession->nResend > TFTP_RESEND_LIMIT)
					{
						LOG_DEBUG("warning: too many timeouts, closing connection to %s", pSession->Host);
						TFTP_SendError(Socket, pSession->dwIP, pSession->wPort, TFTP_ERR_TYPE_UNKNOWN);
						TFTP_SessionClose(pSession);
						i--;
						continue;
					}
					
					pSession->NextTickout = Now + TFTP_TIMEOUT;
					pSession->nResend++;
					
					if(TFTP_SendBlock(pSession) <= 0)
					{
						TFTP_SessionClose(pSession);
						i--;
					}
				}
			}
		}
	}
	
_end_:
	{
		int i;
		for(i=0;i<szSessions;i++)
		{
			TTFTPSession *pSession = &Sessions[i];
			TFTP_SessionClose(pSession);
			i--;
		}
	}
	Net_CloseSocket(&Socket);
	
	Net_Cleanup();
	
	return 0;
}
