// Written by Matt Conover in December 2004
#include <assert.h>
#include "sock.h"

// Counter that hands out SOCK::id values: each socket successfully created by
// sock_get_socket() or accepted by sock_accept_client() takes the next value.
// It is incremented without any locking.
int id = 0;
/*
 * Initialise a table of SOCK slots so that every entry is free.
 * A free slot is fully zeroed, holds INVALID_SOCKET and carries id -1.
 *
 * socks      - array of slots to initialise
 * socks_size - number of entries in socks
 */
void sock_new_sockets(SOCK socks[], unsigned int socks_size)
{
	int count = (int)socks_size;
	int k;

	for (k = 0; k < count; ++k)
	{
		SOCK *slot = &socks[k];

		memset(slot, 0, sizeof(SOCK));
		slot->id = -1;
		slot->socket = INVALID_SOCKET;
	}
}

/*
 * Create a new IPv4 socket in the first free slot of socks.
 *
 * use_udp - non-zero for a datagram (UDP) socket, zero for a stream (TCP) one
 *
 * Returns the slot index on success, socks_size when no slot is free, or
 * INVALID_SOCKET if socket() fails. On success the slot receives the next
 * value of the global id counter.
 */
int sock_get_socket(SOCK socks[], unsigned int socks_size, int use_udp)
{
	int limit = (int)socks_size;
	int slot = 0;
	SOCK *s;

	/* Skip past slots that already hold an open socket. */
	while (slot < limit && socks[slot].socket != INVALID_SOCKET)
		slot++;

	if (slot == limit)
		return socks_size; /* table is full */

	s = &socks[slot];
	s->is_udp = use_udp;
	s->socket = socket(AF_INET, use_udp ? SOCK_DGRAM : SOCK_STREAM, 0);

	if (s->socket == INVALID_SOCKET)
		return INVALID_SOCKET;

	s->id = id++;
	return slot;
}

/*
 * Switch an existing socket into non-blocking mode via ioctlsocket(FIONBIO).
 *
 * Returns 0 on success. On failure it prints the error code to stderr and
 * returns INVALID_SOCKET.
 */
int sock_convert_to_nonblocking_socket(SOCK *sock)
{
	unsigned long enable_nonblocking = 1;
	int rc = ioctlsocket(sock->socket, FIONBIO, &enable_nonblocking);

	if (rc != SOCKET_ERROR)
		return 0;

	fprintf(stderr, "ioctlsocket failed: error code 0x%08lx\n", GetLastError());
	return INVALID_SOCKET;
}

/*
 * Create a new socket in the first free slot (see sock_get_socket) and switch
 * it to non-blocking mode.
 *
 * use_udp - non-zero for UDP, zero for TCP
 *
 * Returns the slot index on success, socks_size when no slot is free, or
 * INVALID_SOCKET if either creating the socket or switching its mode fails.
 * If the mode switch fails, the freshly created socket is closed and its slot
 * is reset to the free state, so no descriptor is leaked.
 */
int sock_get_nonblocking_socket(SOCK socks[], unsigned int socks_size, int use_udp)
{
	int i = sock_get_socket(socks, socks_size, use_udp);
	if (i == INVALID_SOCKET || i == (int)socks_size) return i;

	if (sock_convert_to_nonblocking_socket(&socks[i]) != 0)
	{
		/* The caller only receives INVALID_SOCKET and cannot locate this slot,
		   so release the descriptor here and mark the slot free again. */
		closesocket(socks[i].socket);
		memset(&socks[i], 0, sizeof(SOCK));
		socks[i].socket = INVALID_SOCKET;
		socks[i].id = -1;
		return INVALID_SOCKET;
	}
	return i;
}
	
/*
 * Allocate a socket (see sock_get_socket) and flag its slot as listening.
 * Returns whatever sock_get_socket returned: slot index, socks_size when the
 * table is full, or INVALID_SOCKET on failure.
 */
int sock_get_listening_socket(SOCK socks[], unsigned int socks_size, int use_udp)
{
	int slot = sock_get_socket(socks, socks_size, use_udp);

	/* Only tag a slot that was actually allocated. */
	if (slot != INVALID_SOCKET && slot != (int)socks_size)
		socks[slot].is_listening = 1;

	return slot;
}

/*
 * Allocate a non-blocking socket (see sock_get_nonblocking_socket) and flag
 * its slot as listening. Returns the slot index, socks_size when the table is
 * full, or INVALID_SOCKET on failure.
 */
int sock_get_nonblocking_listening_socket(SOCK socks[], unsigned int socks_size, int use_udp)
{
	int slot = sock_get_nonblocking_socket(socks, socks_size, use_udp);

	/* Only tag a slot that was actually allocated. */
	if (slot != INVALID_SOCKET && slot != (int)socks_size)
		socks[slot].is_listening = 1;

	return slot;
}

/*
 * Allocate a TCP socket (see sock_get_socket with use_udp = 0) and flag its
 * slot as connecting. Returns the slot index, socks_size when the table is
 * full, or INVALID_SOCKET on failure.
 */
int sock_get_connecting_socket(SOCK socks[], unsigned int socks_size)
{
	int slot = sock_get_socket(socks, socks_size, 0);

	/* Only tag a slot that was actually allocated. */
	if (slot != INVALID_SOCKET && slot != (int)socks_size)
		socks[slot].is_connecting = 1;

	return slot;
}

/*
 * Allocate a non-blocking TCP socket (see sock_get_nonblocking_socket with
 * use_udp = 0) and flag its slot as connecting. Returns the slot index,
 * socks_size when the table is full, or INVALID_SOCKET on failure.
 */
int sock_get_nonblocking_connecting_socket(SOCK socks[], unsigned int socks_size)
{
	int slot = sock_get_nonblocking_socket(socks, socks_size, 0);

	/* Only tag a slot that was actually allocated. */
	if (slot != INVALID_SOCKET && slot != (int)socks_size)
		socks[slot].is_connecting = 1;

	return slot;
}

/*
 * Connect sock to the IPv4 address in dst_addr.
 *
 * On success the destination is recorded in the SOCK: dst_addr keeps the
 * address as stored in sin_addr.s_addr, and dst_port holds the port converted
 * to host byte order. Returns 0 on success, or SOCKET_ERROR after printing the
 * error to stderr.
 *
 * NOTE(review): on a non-blocking socket connect() usually fails with a
 * would-block / in-progress code, and this function reports that as an error
 * too. Confirm callers handle that case.
 */
int sock_connect(SOCK *sock, SOCKADDR_IN *dst_addr)
{
	/* Pass the size of the address object actually supplied (a SOCKADDR_IN). */
	if (connect(sock->socket, (SOCKADDR *)dst_addr, sizeof(*dst_addr)) == SOCKET_ERROR)
	{
		/* Capture the error before inet_ntoa()/printf can disturb it. The cast
		   makes the value match the %lx conversion on every platform. */
		unsigned long err = (unsigned long)GetLastError();
		fprintf(stderr, "Error connecting to %s:%d: error code 0x%08lx\n", inet_ntoa(dst_addr->sin_addr), ntohs(dst_addr->sin_port), err);
		return SOCKET_ERROR;
	}

	sock->dst_addr = dst_addr->sin_addr.s_addr;
	sock->dst_port = ntohs(dst_addr->sin_port);
	return 0;
}

/*
 * Accept a pending connection on listen_sock into the first free slot.
 *
 * client_addr - receives the peer address; may be NULL if the caller does not
 *               need it (a local buffer is used instead)
 *
 * Returns the slot index on success, socks_size when no slot is free, or
 * INVALID_SOCKET if accept() fails. In the failure case the error is printed
 * to stderr and the slot is left in the free state (INVALID_SOCKET, id -1).
 * On success the slot is marked connected, receives the next id, and records
 * the peer address and port (port in host byte order).
 */
int sock_accept_client(SOCK socks[], unsigned int socks_size, SOCKET listen_sock, SOCKADDR_IN *client_addr)
{
	SOCKADDR_IN local_addr;
	SOCKADDR_IN *addr = client_addr ? client_addr : &local_addr;
	int i, socksize = sizeof(*addr);

	for (i = 0; i < (int)socks_size && socks[i].socket != INVALID_SOCKET; i++); // find end
	if (i == (int)socks_size) return socks_size; // no free socks

	memset(socks+i, 0, sizeof(SOCK));
	socks[i].socket = accept(listen_sock, (SOCKADDR *)addr, &socksize);
	if (socks[i].socket == INVALID_SOCKET) 
	{
		fprintf(stderr, "Error with accept: error code 0x%08lx\n", (unsigned long)GetLastError());
		socks[i].id = -1; /* restore the free-slot marker cleared by memset */
		return INVALID_SOCKET;
	}
	socks[i].id = id++;
	socks[i].is_listening = 0;
	socks[i].is_connected = 1;
	socks[i].dst_addr = addr->sin_addr.s_addr;
	socks[i].dst_port = ntohs(addr->sin_port);
	return i;
}

/*
 * Accept a pending connection on listen_sock into the first free slot and
 * switch the accepted socket to non-blocking mode.
 *
 * client_addr - receives the peer address; may be NULL if the caller does not
 *               need it (a local buffer is used instead)
 *
 * Returns the slot index on success, socks_size when no slot is free, or
 * INVALID_SOCKET when:
 *   - accept() would block (no pending connection; nothing is printed),
 *   - accept() fails for another reason (error printed to stderr), or
 *   - switching the accepted socket to non-blocking mode fails (the accepted
 *     socket is closed so it is not leaked).
 * In every failure case the slot is left free (INVALID_SOCKET, id -1).
 * On success the slot gets the next id, is marked connected, and records the
 * peer address and port (port in host byte order), matching
 * sock_accept_client().
 */
int sock_accept_nonblocking_client(SOCK socks[], unsigned int socks_size, SOCKET listen_sock, SOCKADDR_IN *client_addr)
{
	SOCKADDR_IN local_addr;
	SOCKADDR_IN *addr = client_addr ? client_addr : &local_addr;
	int i, socksize = sizeof(*addr);

	for (i = 0; i < (int)socks_size && socks[i].socket != INVALID_SOCKET; i++); // find end
	if (i == (int)socks_size) return socks_size; // no free socks

	memset(socks+i, 0, sizeof(SOCK));
	socks[i].socket = accept(listen_sock, (SOCKADDR *)addr, &socksize);
	if (socks[i].socket == INVALID_SOCKET) 
	{
		/* "No pending connection" on a non-blocking listener is not an error. */
		if (GetLastSocketError() != EWOULDBLOCK)
			fprintf(stderr, "Error with accept: error code 0x%08lx\n", (unsigned long)GetLastError());
		socks[i].id = -1; /* restore the free-slot marker cleared by memset */
		return INVALID_SOCKET;
	}

	if (sock_convert_to_nonblocking_socket(&socks[i]) != 0)
	{
		/* The caller cannot locate this slot from INVALID_SOCKET, so release
		   the accepted descriptor here instead of leaking it. */
		closesocket(socks[i].socket);
		socks[i].socket = INVALID_SOCKET;
		socks[i].id = -1;
		return INVALID_SOCKET;
	}

	socks[i].id = id++;
	socks[i].is_connected = 1;
	socks[i].dst_addr = addr->sin_addr.s_addr;
	socks[i].dst_port = ntohs(addr->sin_port);
	return i;
}

/*
 * Return the numerically largest open socket handle in socks, or
 * INVALID_SOCKET if no slot holds an open socket.
 */
SOCKET sock_get_max_socket(SOCK socks[], unsigned int socks_size)
{
	SOCKET highest = INVALID_SOCKET;
	int have_any = 0;
	int count = (int)socks_size;
	int k;

	for (k = 0; k < count; k++)
	{
		SOCKET candidate = socks[k].socket;

		if (candidate == INVALID_SOCKET)
			continue;

		/* The first open socket seeds the maximum; later ones must beat it. */
		if (!have_any || highest < candidate)
		{
			highest = candidate;
			have_any = 1;
		}
	}

	return highest;
}

/*
 * Receive data from sock into buf.
 *
 * buf_size   - capacity of buf in bytes
 * recv_exact - if non-zero, keep calling recv() until buf_size bytes have been
 *              read; if zero, return after the first successful recv()
 *
 * Returns:
 *   - the number of bytes read (buf_size when recv_exact is set),
 *   - 0 if recv() would block (EWOULDBLOCK),
 *   - -1 if the peer closed the connection or recv() failed.
 *
 * NOTE(review): with recv_exact on a non-blocking socket, a would-block after
 * a partial read returns 0, and the bytes already consumed into buf are not
 * reported to the caller. Confirm recv_exact is only used on blocking sockets.
 */
int sock_recv(SOCK *sock, char *buf, unsigned int buf_size, int recv_exact)
{
	char *p;
	int bytes_left, bytes_read;

	for (p = buf, bytes_left = buf_size; bytes_left != 0; p += bytes_read, bytes_left -= bytes_read)
	{
		bytes_read = recv(sock->socket, p, bytes_left, 0);
		if (bytes_read == 0)
		{
			/* Orderly shutdown by the peer. recv() does not set the socket error
			   in this case, so it must not be consulted: a stale EWOULDBLOCK
			   would otherwise make a closed connection look like "no data yet". */
			printf("[Client %d.%d] Connection closed by peer\n", sock->id, (int)sock->socket);
			return -1;
		}
		if (bytes_read < 0)
		{
			if (GetLastSocketError() == EWOULDBLOCK) return 0;
			printf("[Client %d.%d] Error with recv: read %d bytes (error code 0x%08lx)\n", sock->id, (int)sock->socket, bytes_read, (unsigned long)GetLastError());
			return -1;
		}

		printf("[Client %d.%d] Read %d bytes\n", sock->id, (int)sock->socket, bytes_read);
		if (!recv_exact) return bytes_read;
	}

	return buf_size;
}

/*
 * Send all buf_size bytes of buf on sock, looping over partial send()s.
 *
 * Returns buf_size once everything has been sent. If a send() call returns
 * 0 or fails, it reports how many bytes were sent before the failure and
 * returns -1.
 */
int sock_send(SOCK *sock, char *buf, unsigned int buf_size)
{
	char *p;
	int bytes_left, bytes_sent;

	for (p = buf, bytes_left = buf_size; bytes_left > 0; p += bytes_sent, bytes_left -= bytes_sent)
	{
		bytes_sent = send(sock->socket, p, bytes_left, 0);
		if (bytes_sent <= 0)
		{
			/* Capture the error code before calling printf. */
			unsigned long err = (unsigned long)GetLastError();
			/* Report the total delivered so far out of the full size. The
			   failing call's own return value (<= 0) would be misleading here. */
			printf("[Client %d.%d] Error with send: sent only %d of %d bytes (error code 0x%08lx)\n", sock->id, (int)sock->socket, (int)buf_size - bytes_left, (int)buf_size, err);
			return -1;
		}
		
		printf("[Client %d.%d] Sent %d of %d bytes\n", sock->id, (int)sock->socket, bytes_sent, (int)buf_size);
	}

	return buf_size;
}

/*
 * Close sock and reset its slot to the free state (zeroed, INVALID_SOCKET,
 * id -1). If the slot is linked to a forwarding partner, that partner is
 * unlinked and closed as well. A slot that is already free is left untouched.
 */
void sock_close_socket(SOCK *sock)
{
	if (sock->socket != INVALID_SOCKET)
	{
		closesocket(sock->socket);

		if (sock->forward_socket != NULL)
		{
			/* The link must be symmetric. Clear the partner's back-pointer
			   first so the recursive close does not walk back into this slot. */
			assert(sock->forward_socket->forward_socket == sock);
			sock->forward_socket->forward_socket = NULL;
			sock_close_socket(sock->forward_socket);
		}

		memset(sock, 0, sizeof(SOCK));
		sock->id = -1;
		sock->socket = INVALID_SOCKET;
	}
}

/*
 * Close every slot in socks via sock_close_socket(); free slots are skipped
 * by that function.
 */
void sock_close_sockets(SOCK socks[], unsigned int socks_size)
{
	int remaining = (int)socks_size;
	SOCK *cur = socks;

	for (; remaining > 0; remaining--, cur++)
		sock_close_socket(cur);
}
