#include <stdio.h>
#include <windows.h>
#include <assert.h>
#include "sock.h"
#include "heap.h"
/* Maximum length passed to sock_recv() when main() reads client data into a client's buffer. */
#define MAX_SIZE 8192

/*
 * NOTE(review): neither constant below is referenced by the code visible in this file.
 * They look like fixed addresses inside the process environment block of older
 * 32-bit Windows (whose PEB was commonly mapped at 0x7ffdf000) -- confirm intent or remove.
 */
#define PEB_SPACE        0x7ffdf154
#define PEB_LOCK_ROUTINE 0x7ffdf020

/*
 * Prints the one-line command-line synopsis to stdout and terminates the
 * process with exit status -1. Never returns.
 */
void Usage(void)
{
	puts("Usage: vulnprog listen_port");
	exit(-1);
}

/* Running count of accepted connections; main() assigns each newly accepted client's id from it and then increments it. */
DWORD connection_count = 0;

/*
 * Hook invoked by main() after data has been received into a client's buffer.
 * Currently a no-op: it neither reads nor modifies the buffer.
 *
 * buf      - the client's receive buffer
 * buf_size - size argument supplied by the caller
 */
void ProcessNewPacket(char *buf, int buf_size)
{
	(void)buf;      /* intentionally unused until packet handling is implemented */
	(void)buf_size;
}

/*
 * Writes the first `len` bytes of `buf` to `out` as a C array definition named
 * `varname`, emitting one "\xNN" escape per byte and starting a new string
 * literal line after every 16 bytes.
 *
 * Writes nothing when any pointer is NULL or when `len` is not positive.
 */
static void dump_buffer_to(FILE *out, const char *varname, const char *buf, int len)
{
	const unsigned char *bytes = (const unsigned char *)buf; /* print raw byte values, independent of char signedness */
	int i;

	if (!out || !varname || !buf || len <= 0) return;

	fprintf(out, "char %s[] = // %d bytes\n", varname, len);
	fputs("\t\"", out);

	for (i = 0; i < len; i++)
	{
		fprintf(out, "\\x%02x", (unsigned int)bytes[i]);

		/* Break the literal after every 16th byte, but not after the final one. */
		if (i + 1 < len && (i + 1) % 16 == 0) fputs("\"\n\t\"", out);
	}
	fputs("\";\n", out);
}

/*
 * Dumps the first `len` bytes of `buf` to stdout as a C array definition named
 * `varname`. Does nothing when either pointer is NULL or `len` is not positive.
 */
void DumpBuffer(char *varname, char *buf, int len)
{
	dump_buffer_to(stdout, varname, buf, len);
}

int main(int argc, char **argv)
{
	SOCK socks[MAX_SOCKETS];
	SOCKADDR_IN dst_addr, listen_addr;
	int src_port;
	WSADATA wsadata;
	SOCKADDR_IN client_addr;
	int do_abort = 0, client_count = 0;
	int i, new_client, bytes_to_read, bytes_read;
	fd_set readfds;
	HEAP *pHeap;

	pHeap = HeapCreate(2, 0x10000, 0);
	printf("vulnprog heap base is 0x%08lx\n", pHeap);
	pHeap->ForceFlags = HEAP_GROWABLE;
	FD_ZERO(&readfds);

	if (argc != 2) Usage();
	src_port = atoi(argv[1]);
	if (!src_port) Usage();

	WSAStartup(MAKEWORD(1, 1,), &wsadata);
	memset(&dst_addr, 0, sizeof(dst_addr));
	memset(&listen_addr, 0, sizeof(listen_addr));
	sock_new_sockets(socks, MAX_SOCKETS);
	listen_addr.sin_addr.s_addr = INADDR_ANY;
	listen_addr.sin_family = AF_INET;
	listen_addr.sin_port = htons((unsigned short)src_port);

	i = sock_get_nonblocking_listening_socket(socks, MAX_SOCKETS, 0);
	if (i == SOCKET_ERROR)
	{
		printf("Error opening listening socket: error code 0x%08lx\n", GetLastError());
		return -1;
	}

	printf("Binding to port %d\n", src_port); fflush(stdout);
	if (bind(socks[i].socket, (SOCKADDR *)&listen_addr, sizeof(listen_addr)) == SOCKET_ERROR)
	{
		printf("Error with bind: error code 0x%08lx\n", GetLastError());
		closesocket(socks[i].socket);
		WSACleanup();
		return -1;
	}
	printf("Listening on port %d\n", src_port);
	if (listen(socks[i].socket, 0) == SOCKET_ERROR)
	{
		printf("Error with listen: error code 0x%08lx\n", GetLastError());
		closesocket(socks[i].socket);
		WSACleanup();
		return -1;
	}

	while (1)
	{
		Sleep(500);
		for (i = 0; i < MAX_SOCKETS; i++)
		{
			if (socks[i].socket == INVALID_SOCKET) continue;
			assert(socks[i].is_connected || socks[i].is_listening);

			if (socks[i].is_listening) // Handle new connection
			{
				new_client = sock_accept_nonblocking_client(socks, MAX_SOCKETS, socks[i].socket, &client_addr);
				if (new_client == MAX_SOCKETS) 
				{
					printf("Error: All sockets are used up, igoring client from %s:%d\n", inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));
					continue;
				}
				else if (new_client == INVALID_SOCKET)
				{
					if (GetLastSocketError() != EWOULDBLOCK && !client_count) return -1;
				}
				else
				{
					printf("\n[Client %d.%d] Accepted from %s:%d\n", socks[new_client].id, socks[new_client].socket, inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));
					client_count++;
					socks[new_client].is_connected = 1;
					socks[new_client].id = connection_count++;
				}
			}

			else if (socks[i].is_connected)
			{
				bytes_to_read = 0;
				printf("[Client %d.%d] Checking for data\n", socks[i].id, socks[i].socket);

				if (!socks[i].buf_size)
				{
					bytes_read = sock_recv(&socks[i], (char *)&bytes_to_read, sizeof(bytes_to_read), 1);
					if (bytes_read <= 0 && GetLastSocketError() == EWOULDBLOCK) continue;
					if (bytes_read <= 0)
					{
						printf("[Client %d.%d] Client disconnected\n", socks[i].id, socks[i].socket);
						do_abort = 1;
						goto abort;
					}

					printf("[Client %d.%d] Allocating %d bytes at client's request\n", socks[i].id, socks[i].socket, bytes_to_read); fflush(stdout);
					socks[i].buf_size = bytes_to_read;
					socks[i].buf = HeapAlloc(pHeap, 0, socks[i].buf_size);
					if (!socks[i].buf)
					{
						printf("Error allocating %d bytes (error code 0x%08lx)\n", socks[i].buf_size, GetLastError());
						do_abort = 1;
						goto abort;
					}
					printf("\tBuffer allocated at 0x%08lx\n", socks[i].buf);
					memset(socks[i].buf, 0, socks[i].buf_size);
				}
				else
				{
					// Data waiting to be read
					bytes_read = sock_recv(&socks[i], socks[i].buf, MAX_SIZE, 0);
					if (bytes_read <= 0 && GetLastSocketError() == EWOULDBLOCK) continue;
					printf("[Client %d.%d] Received %d bytes into buffer at 0x%08lx\n", socks[i].id, socks[i].socket, bytes_read, socks[i].buf);
					if (bytes_read <= 0)
					{
						printf("[Client %d.%d] Client disconnected\n", socks[i].id, socks[i].socket);
						do_abort = 1;
						goto abort;
					}
					ProcessNewPacket(socks[i].buf, socks[i].buf_size);
				}
				assert(!do_abort);				
abort:
				if (do_abort)
				{
					if (socks[i].buf)
					{
						printf("[Client %d.%d] Freeing client buffer at 0x%08lx (%d bytes)\n", socks[i].id, socks[i].socket, socks[i].buf, socks[i].buf_size); fflush(stdout);
						HeapFree(pHeap, 0, socks[i].buf); socks[i].buf = NULL; socks[i].buf_size = 0;
						//DumpFreeLists(stdout, pHeap, FALSE);
						//DumpLookasideLists(stdout, pHeap, FALSE);
					}
					sock_close_socket(&socks[i]);
					assert(!socks[i].buf && !socks[i].buf_size);
					do_abort = 0;
				}
			}
		}
	}

	sock_close_sockets(socks, MAX_SOCKETS);
	WSACleanup();
	return -1; // this shouldn't be reached unless there was a problem
}
