// Simple driver that can be dynamically loaded and unloaded; it accepts OCaml
// bytecode from user mode through IOCTLs and runs it with an in-kernel interpreter.

#include "ntddk.h"
#include "ocamldrv.h"

// NT name of the device object created in DriverEntry, and the \DosDevices
// alias that DriverEntry links to it so Win32 callers can open the device.
#define NT_DEVICE_NAME      L"\\Device\\OcamlDrv"
#define DOS_DEVICE_NAME     L"\\DosDevices\\OCAMLDRV"

// Interpreter-slot synchronization, initialized in tlsIntialize():
//   InsertSemaphore - limit/count 1; held around insertThread() so the slot
//                     table is updated by one thread at a time.
//   IsFullSemaphore - starts at MAXTHREAD; acquired before a slot is claimed
//                     and released after it is given back, bounding the
//                     number of concurrent interpreter runs to MAXTHREAD.
static KSEMAPHORE InsertSemaphore;
static KSEMAPHORE IsFullSemaphore;

// Bytecode runtime entry points, defined outside this file (presumably the
// bundled OCaml runtime -- confirm). caml_main receives the IOCTL buffer and
// returns the code size later passed to interpreter(); a return of 0 is
// treated as failure by OcamlDrvByteCodeInterpreter.
int __cdecl caml_main(
    IN PVOID      IoBuffer
    );

int __cdecl interpreter(
    IN int      code_size
    );

// Resets the per-slot interpreter state and initializes the semaphores above.
void tlsIntialize(void);

// Driver routines defined later in this file.
NTSTATUS
OcamlDrvDispatch(
    IN PDEVICE_OBJECT DeviceObject,
    IN PIRP Irp
    );

VOID
OcamlDrvUnload(
    IN PDRIVER_OBJECT DriverObject
    );

NTSTATUS
OcamlDrvByteCodeFeed(
    IN OUT PVOID      IoBuffer
    );

NTSTATUS
OcamlDrvByteCodeInterpreter(
    IN OUT PVOID      IoBuffer
    );

NTSTATUS
OcamlDrvByteCodeAllocate(
    IN OUT PVOID      IoBuffer
    );

VOID
OcamlDrvCleanup(
	IN OUT PVOID      IoBuffer
	);

NTSTATUS
DriverEntry(
    IN PDRIVER_OBJECT DriverObject,
    IN PUNICODE_STRING RegistryPath
    )
/*++

Routine Description:

    Driver entry point. Creates the \Device\OcamlDrv device object, installs
    OcamlDrvDispatch for create/close/device-control and OcamlDrvUnload as the
    unload routine, links \DosDevices\OCAMLDRV to the device, and finally
    initializes the interpreter slot table via tlsIntialize().

Arguments:

    DriverObject - driver object supplied by the I/O manager.
    RegistryPath - service registry key; not used by this driver.

Return Value:

    STATUS_SUCCESS, or the failure status from IoCreateDevice or
    IoCreateSymbolicLink. If the symbolic link cannot be created, the device
    object is deleted again before returning.

--*/
{
    PDEVICE_OBJECT deviceObject = NULL;
    NTSTATUS status;
    UNICODE_STRING uniNtNameString;
    UNICODE_STRING uniWin32NameString;

    UNREFERENCED_PARAMETER(RegistryPath);

    DbgPrint("OCAMLDRV.SYS: Entered the Ocaml driver!\n");

    //
    // Create counted string version of our device name.
    //
    RtlInitUnicodeString( &uniNtNameString, NT_DEVICE_NAME );

    //
    // Create the device object
    //
    status = IoCreateDevice(
                 DriverObject,
                 0,                     // We don't use a device extension
                 &uniNtNameString,
                 FILE_DEVICE_UNKNOWN,
                 0,                     // No standard device characteristics
                 FALSE,                 // This isn't an exclusive device
                 &deviceObject
                 );

    if ( NT_SUCCESS(status) )
    {
        //
        // Create dispatch points for create/open, close, device control
        // and unload.
        //
        DriverObject->MajorFunction[IRP_MJ_CREATE]         =
        DriverObject->MajorFunction[IRP_MJ_CLOSE]          =
        DriverObject->MajorFunction[IRP_MJ_DEVICE_CONTROL] = OcamlDrvDispatch;
        DriverObject->DriverUnload = OcamlDrvUnload;

        DbgPrint("OCAMLDRV.SYS: just about ready!\n");

        //
        // Create a link from our device name to a name in the Win32 namespace.
        //
        RtlInitUnicodeString( &uniWin32NameString, DOS_DEVICE_NAME );

        status = IoCreateSymbolicLink( &uniWin32NameString, &uniNtNameString );

        if (!NT_SUCCESS(status))
        {
            DbgPrint("OCAMLDRV.SYS: Couldn't create the symbolic link\n");

            // Undo IoCreateDevice; the driver fails to load.
            IoDeleteDevice( deviceObject );
        }
        else
        {
            // initialize tls struct
            tlsIntialize();
        }
    }
    else
    {
        DbgPrint("OCAMLDRV.SYS: Couldn't create the device\n");
    }

    return status;
}


NTSTATUS
OcamlDrvDispatch(
    IN PDEVICE_OBJECT DeviceObject,
    IN PIRP Irp
    )
/*++

Routine Description:

    Single dispatch routine for IRP_MJ_CREATE, IRP_MJ_CLOSE and
    IRP_MJ_DEVICE_CONTROL. Create and close always succeed. Device-control
    requests hand the system buffer to the allocate / feed / interpret /
    cleanup helpers, which treat it as an OCAML_BYTECODE_INFO. Every IRP is
    completed synchronously before this routine returns.

Arguments:

    DeviceObject - unused.
    Irp          - the request to process and complete.

Return Value:

    The final Irp->IoStatus.Status. Supported IOCTLs whose input buffer is
    missing or too small to hold the bytecode-info structure fail with
    STATUS_BUFFER_TOO_SMALL; unknown IOCTLs fail with
    STATUS_INVALID_PARAMETER.

--*/
{
    PIO_STACK_LOCATION irpStack;
    PVOID ioBuffer;
    ULONG ioControlCode;
    BOOLEAN haveBytecodeInfo;
    NTSTATUS ntStatus;

    UNREFERENCED_PARAMETER(DeviceObject);

    //
    // Default: success, no data returned. The handlers below override this.
    //
    Irp->IoStatus.Status = STATUS_SUCCESS;
    Irp->IoStatus.Information = 0;

    //
    // Input/output buffer and the current stack location, which holds the
    // function codes and parameters.
    //
    ioBuffer = Irp->AssociatedIrp.SystemBuffer;
    irpStack = IoGetCurrentIrpStackLocation(Irp);

    switch (irpStack->MajorFunction)
    {
    case IRP_MJ_CREATE:

        DbgPrint("OCAMLDRV.SYS: IRP_MJ_CREATE\n");
        break;

    case IRP_MJ_CLOSE:

        DbgPrint("OCAMLDRV.SYS: IRP_MJ_CLOSE\n");
        break;

    case IRP_MJ_DEVICE_CONTROL:

        ioControlCode = irpStack->Parameters.DeviceIoControl.IoControlCode;

        //
        // The buffer size comes from the caller. Every supported handler
        // reads (and some write) it as a bytecode-info structure, so it must
        // exist and be at least that large before any handler touches it.
        //
        haveBytecodeInfo = (BOOLEAN)(ioBuffer != NULL &&
            irpStack->Parameters.DeviceIoControl.InputBufferLength >=
                sizeof(*(POCAML_BYTECODE_INFO)ioBuffer));

        switch (ioControlCode)
        {
        case IOCTL_OCAML_BYTECODE_SIZE:

            if (!haveBytecodeInfo)
            {
                Irp->IoStatus.Status = STATUS_BUFFER_TOO_SMALL;
                break;
            }

            Irp->IoStatus.Status = OcamlDrvByteCodeAllocate (ioBuffer);

            if (NT_SUCCESS(Irp->IoStatus.Status))
            {
                // Success! Set the following to sizeof(PVOID) to
                //     indicate we're passing valid data back.
                Irp->IoStatus.Information = sizeof(PVOID);
#if DBG
                DbgPrint("OCAMLDRV.SYS: OcamlDrvByteCodeAllocate() successful.\n");
#endif
            }
            break;

        case IOCTL_OCAML_BYTECODE_FEED:

            if (!haveBytecodeInfo)
            {
                Irp->IoStatus.Status = STATUS_BUFFER_TOO_SMALL;
                break;
            }

            Irp->IoStatus.Status = OcamlDrvByteCodeFeed (ioBuffer);

            if (NT_SUCCESS(Irp->IoStatus.Status))
            {
                // Success! Set the following to sizeof(PVOID) to
                //     indicate we're passing valid data back.
                Irp->IoStatus.Information = sizeof(PVOID);
            }
            else
            {
                switch (Irp->IoStatus.Status)
                {
                case STATUS_NOT_OCAML_BYTECODE:

                    DbgPrint("OCAMLDRV.SYS: it is NOT Ocaml-1.07 bytecode!\n");
                    break;

                default:

                    // Collapse any other failure into a single status.
                    Irp->IoStatus.Status = STATUS_INVALID_PARAMETER;
                    DbgPrint("OCAMLDRV.SYS: bytecode interpreter failed!\n");
                    break;
                }
            }
            break;

        case IOCTL_OCAML_BYTECODE_INTERPRETE:

            if (!haveBytecodeInfo)
            {
                Irp->IoStatus.Status = STATUS_BUFFER_TOO_SMALL;
                break;
            }

            DbgPrint("OCAMLDRV.SYS: IOCTL_OCAML_BYTECODE_INTERPRETE\n");

            Irp->IoStatus.Status = OcamlDrvByteCodeInterpreter (ioBuffer);
            break;

        case IOCTL_OCAML_BYTECODE_CLEAN:

            if (!haveBytecodeInfo)
            {
                Irp->IoStatus.Status = STATUS_BUFFER_TOO_SMALL;
                break;
            }

            OcamlDrvCleanup(ioBuffer);
            break;

        default:

            DbgPrint("OCAMLDRV.SYS: unknown IRP_MJ_DEVICE_CONTROL\n");

            Irp->IoStatus.Status = STATUS_INVALID_PARAMETER;
            break;
        }
        break;

    default:
        // Only the three major functions above are registered in DriverEntry.
        break;
    }

    //
    // DON'T get cute and try to use the status field of
    // the irp in the return status.  That IRP IS GONE as
    // soon as you call IoCompleteRequest.
    //
    ntStatus = Irp->IoStatus.Status;

    IoCompleteRequest(Irp, IO_NO_INCREMENT);

    //
    // We never have pending operation so always return the status code.
    //
    return ntStatus;
}

VOID
OcamlDrvUnload(
    IN PDRIVER_OBJECT DriverObject
    )
/*++

Routine Description:

    Unload routine. Tears down the two objects DriverEntry created: first
    the \DosDevices\OCAMLDRV symbolic link, then the device object itself.

Arguments:

    DriverObject - the driver object being unloaded.

--*/
{
    PDEVICE_OBJECT device = DriverObject->DeviceObject;
    UNICODE_STRING dosDeviceName;

    DbgPrint("OCAMLDRV.SYS: Unloading!!\n");

    // Remove the Win32-visible name before the object it points at.
    RtlInitUnicodeString(&dosDeviceName, DOS_DEVICE_NAME);
    IoDeleteSymbolicLink(&dosDeviceName);

    IoDeleteDevice(device);
}


NTSTATUS
OcamlDrvByteCodeAllocate(
    IN OUT PVOID      IoBuffer
    )
/*++

Routine Description:

    Handles IOCTL_OCAML_BYTECODE_SIZE. Allocates a non-paged buffer of
    pobi->ByteCodeSize + 1 bytes and zeroes its last byte. The buffer's
    address is stored in pobi->KByteCode and is also written into the first
    pointer-sized slot of IoBuffer; the dispatch routine then reports
    sizeof(PVOID) bytes of output.

Arguments:

    IoBuffer - system buffer holding the caller's bytecode-info structure.

Return Value:

    STATUS_SUCCESS; STATUS_INVALID_PARAMETER if ByteCodeSize + 1 does not
    fit in a SIZE_T; STATUS_INSUFFICIENT_RESOURCES if the allocation fails.

--*/
{
    POCAML_BYTECODE_INFO pobi = (POCAML_BYTECODE_INFO) IoBuffer;
    SIZE_T allocSize;
    PUCHAR kbuf;

    //
    // ByteCodeSize is caller-controlled. If the "+ 1" for the terminator
    // wraps to 0, the allocation would be tiny while the terminating write
    // would still index ByteCodeSize: an out-of-bounds kernel write.
    //
    allocSize = (SIZE_T)pobi->ByteCodeSize + 1;
    if (allocSize == 0)
        return STATUS_INVALID_PARAMETER;

    // Tag 'lmaO' appears as "Oaml" in pool-tag tools; ExFreePool frees it.
    kbuf = (PUCHAR) ExAllocatePoolWithTag(NonPagedPool, allocSize, 'lmaO');
    if (kbuf == NULL)
        return STATUS_INSUFFICIENT_RESOURCES;

    // terminate with 0 for testing
    kbuf[allocSize - 1] = 0;

    pobi->KByteCode = kbuf;

#if DBG
    DbgPrint("OCAMLDRV.SYS: pobi->KByteCode = 0x%p\n", kbuf);
#endif

    //
    // Hand the address back in the first slot of the buffer. This store may
    // overwrite the leading fields of *pobi, so every read of pobi happens
    // above it.
    //
    *((PVOID *) IoBuffer) = kbuf;

    return STATUS_SUCCESS;
}


NTSTATUS
OcamlDrvByteCodeFeed(
    IN OUT PVOID      IoBuffer
    )
/*++

Routine Description:

    Handles IOCTL_OCAML_BYTECODE_FEED. Copies pobi->ByteCodeSize bytes from
    the caller's pobi->ByteCode to pobi->KByteCode + pobi->ByteCodeOffset.

Arguments:

    IoBuffer - system buffer holding the caller's bytecode-info structure.

Return Value:

    STATUS_SUCCESS, or the exception code (for example
    STATUS_ACCESS_VIOLATION) if, for a user-mode request, the source range
    is not readable user memory.

--*/
{
    POCAML_BYTECODE_INFO pobi = (POCAML_BYTECODE_INFO) IoBuffer;
    NTSTATUS status = STATUS_SUCCESS;

    //
    // NOTE(review): KByteCode and ByteCodeOffset are also taken from the
    // caller unchecked, so the destination range is not validated here; the
    // driver keeps no record of the allocation size to check it against.
    //
    __try
    {
        //
        // ByteCode is a pointer embedded in the request, so the I/O manager
        // has not validated it. For user-mode requests, make sure the whole
        // source range lies in user space (this also rejects sizes that
        // overflow), and let the handler catch faults during the copy.
        //
        if (ExGetPreviousMode() != KernelMode)
        {
            ProbeForRead(pobi->ByteCode, pobi->ByteCodeSize, sizeof(UCHAR));
        }

        memcpy (pobi->KByteCode + pobi->ByteCodeOffset,
            pobi->ByteCode, pobi->ByteCodeSize);
    }
    __except (EXCEPTION_EXECUTE_HANDLER)
    {
        status = GetExceptionCode();
    }

    return status;
}


NTSTATUS
OcamlDrvByteCodeInterpreter(
    IN OUT PVOID      IoBuffer
    )
{
	//
	// prepare heap memory.
	//

    POCAML_BYTECODE_INFO pobi = (POCAML_BYTECODE_INFO) IoBuffer;
	int code_size, i;
	int threadId;

	KeWaitForSingleObject(&IsFullSemaphore,
		Executive,
		KernelMode,
		1,
		NULL);
	KeWaitForSingleObject(&InsertSemaphore,
		Executive,
		KernelMode,
		1,
		NULL);
	threadId = insertThread((PVOID)PsGetCurrentThread());
	KeReleaseSemaphore(&InsertSemaphore, NULL, 1, 0);

	//
	// prepare environment and then execute.
	//
	tls[threadId].loc = 0;

	if ((code_size = caml_main(IoBuffer)) == 0) return 1;

	// copy heap address
	tls[threadId].mem_alloc[0] = tls[threadId].mem_alloc[2];

	// initialize memory array 
	for (i = 1 ; i < MAXMEM - 1; i++)
		tls[threadId].mem_alloc[i] = 0;
	
	tls[threadId].loc = 1;

	interpreter(code_size);

	for (i = 1; i < MAXMEM; i++)
	{
		if (tls[threadId].mem_alloc[i] == 0) continue;
#ifdef DBG
		DbgPrint ("ExFreePool(%x), threadId = %i\n", tls[threadId].mem_alloc[i] - 2, threadId);
#endif
		ExFreePool(tls[threadId].mem_alloc[i] - 2);
	}
	
	// copy to the last mem address for next cycle cleanup
	tls[threadId].mem_alloc[MAXMEM - 1] = tls[threadId].mem_alloc[0];

	DbgPrint("OCAMLDRV.SYS: ThreadId = %i terminated.\n", removeThread((PVOID)PsGetCurrentThread()));
	KeReleaseSemaphore(&IsFullSemaphore, NULL, 1, 0);

	return STATUS_SUCCESS;
}


VOID
OcamlDrvCleanup(
	IN OUT PVOID       IoBuffer
	)
/*++

Routine Description:

    Handles IOCTL_OCAML_BYTECODE_CLEAN. Frees the kernel buffer whose address
    the caller passes in pobi->KByteCode (presumably the one returned by
    IOCTL_OCAML_BYTECODE_SIZE -- confirm against the user-mode client).

Arguments:

    IoBuffer - system buffer holding the caller's bytecode-info structure.

--*/
{
    POCAML_BYTECODE_INFO pobi = (POCAML_BYTECODE_INFO) IoBuffer;

    //
    // Unlike free(), ExFreePool must not be given NULL, so a cleanup request
    // for a buffer that was never allocated is simply ignored.
    // NOTE(review): any other KByteCode value is trusted as-is; the driver
    // keeps no record of its own allocations to validate it against.
    //
    if (pobi->KByteCode == NULL)
        return;

    ExFreePool(pobi->KByteCode);

#if DBG
    DbgPrint("OCAMLDRV.SYS: ExFreePool(pobi->KByteCode) success.\n");
#endif
}

//
// Resets every one of the MAXTHREAD entries of the global tls[] table to its
// initial state, clears peThread[] for each slot, and initializes the two
// slot semaphores (InsertSemaphore: count/limit 1; IsFullSemaphore:
// count/limit MAXTHREAD). Called from DriverEntry once the symbolic link
// exists.
//
// Despite the name, "tls" is not OS thread-local storage: it is a table
// indexed by the slot number that insertThread() hands out in
// OcamlDrvByteCodeInterpreter. peThread[] is presumably the slot -> owning
// thread map used by insertThread/removeThread (confirm).
//
void tlsIntialize(void)
{
	int threadId;
	// Per-slot templates, copied into every tls[] entry below.
	opcode_t callback1_code[] = {ACC1, APPLY1, POP, 1, STOP};
	opcode_t callback2_code[] = {ACC2, APPLY2, POP, 1, STOP};
	opcode_t callback3_code[] = {ACC3, APPLY3, POP, 1, STOP};
	struct named_value * named_value_table[Named_value_size] = { NULL, };
	struct {
		value filler1;
		header_t h;
		value first_bp;
		value filler2;
	} sentinel = {0, Make_header (0, 0, Blue), 0, 0};
	// Bitmaps of printable characters; the literals include a trailing NUL,
	// which sizeof() counts in the memcpy calls below.
	unsigned char printable_chars_ascii[] = /* 0x20-0x7E */
  "\000\000\000\000\377\377\377\377\377\377\377\377\377\377\377\177\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000";
	unsigned char printable_chars_iso[] = /* 0x20-0x7E 0xA1-0xFF */
  "\000\000\000\000\377\377\377\377\377\377\377\377\377\377\377\177\000\000\000\000\376\377\377\377\377\377\377\377\377\377\377\377";

	// reset every slot to its initial state
	for (threadId = 0 ; threadId < MAXTHREAD; threadId++)
	{
		// Last mem_alloc entry carries a heap address between runs (see
		// OcamlDrvByteCodeInterpreter); start with nothing to free.
		tls[threadId].mem_alloc[MAXMEM - 1] = 0;

		// NOTE(review): these three copies use sizeof(destination) while the
		// source arrays hold exactly 5 opcodes; this assumes the tls fields
		// are declared with 5 elements -- confirm in ocamldrv.h.
		tls[threadId].callback_depth = 0;
		memcpy(tls[threadId].callback1_code, callback1_code, sizeof(tls[threadId].callback1_code));
		memcpy(tls[threadId].callback2_code, callback2_code, sizeof(tls[threadId].callback2_code));
		memcpy(tls[threadId].callback3_code, callback3_code, sizeof(tls[threadId].callback3_code));
		tls[threadId].callback_code_threaded = 0;
		memcpy(tls[threadId].named_value_table, named_value_table, sizeof(named_value_table));
		tls[threadId].initial_ofs = 1;
		tls[threadId].extern_table = NULL;

		tls[threadId].dbg_socket = -1;

		// NOTE(review): copies sizeof(local anonymous struct) into
		// tls[].sentinel; assumes both have the same layout -- confirm.
		memcpy(&tls[threadId].sentinel, &sentinel, sizeof(sentinel));
		tls[threadId].fl_prev = NULL;
		tls[threadId].fl_last = NULL;
		tls[threadId].fl_merge = NULL;
		tls[threadId].fl_cur_size = 0;
		
		tls[threadId].stat_minor_words = 0;
		tls[threadId].stat_promoted_words = 0;
		tls[threadId].stat_major_words = 0;
		tls[threadId].stat_minor_collections = 0;
		tls[threadId].stat_major_collections = 0;
		tls[threadId].stat_heap_size = 0;           /* bytes */
		tls[threadId].stat_compactions = 0;
		
		tls[threadId].young_start = NULL;
		tls[threadId].young_end = NULL;
		tls[threadId].young_ptr = NULL;
		tls[threadId].young_limit = NULL;
		tls[threadId].ref_table = NULL;
		tls[threadId].ref_table_ptr = NULL;
		tls[threadId].in_minor_collection = 0;

		tls[threadId].seed = 0x12345;

		tls[threadId].async_signal_mode = 0;
		tls[threadId].pending_signal = 0;
		tls[threadId].something_to_do = 0;
		tls[threadId].force_major_slice = 0;
		tls[threadId].signal_handlers = 0;
		
		tls[threadId].verbose_init = 0;
		
		tls[threadId].icount = 0;
		tls[threadId].trace_flag = 0;

		memcpy(tls[threadId].printable_chars_ascii, printable_chars_ascii, sizeof(printable_chars_ascii));
		memcpy(tls[threadId].printable_chars_iso, printable_chars_iso, sizeof(printable_chars_iso));
		
		tls[threadId].local_roots = NULL;
		tls[threadId].global_roots = NULL;
		
		tls[threadId].weak_list_head = 0;

		// Mark the slot as unowned.
		peThread[threadId] = 0;
	}
	// InsertSemaphore acts as a mutex over the slot table; IsFullSemaphore
	// counts free slots (all MAXTHREAD initially).
	KeInitializeSemaphore(&InsertSemaphore, 1, 1);
	KeInitializeSemaphore(&IsFullSemaphore, MAXTHREAD, MAXTHREAD);
	
	DbgPrint("OCAMLDRV.SYS: All tls initialized!\n");
}
