#include "list.h"

#include <stdint.h>
#include <stdlib.h>
#include <pthread.h>

/* One link of the singly linked list. The element pointer is stored as given
 * (list_add does not copy what it points to). */
typedef struct list_node {
	void * element ;            /* payload supplied by the caller */
	struct list_node * next ;   /* following node, NULL for the tail */
} list_node_t ;

/*
 * List header. list_add appends new nodes through last. count is incremented
 * by list_add, decremented by list_iterator_remove and reset by
 * list_remove_all. first/last/count are meant to be accessed only while
 * holding read_write_lock, which list_iterator_new acquires and
 * list_iterator_free releases.
 */
struct list {
	list_node_t * first ;   /* head node, NULL when the list is empty */
	list_node_t * last ;    /* tail node, NULL when the list is empty */
	size_t count ;          /* number of nodes currently linked */
	pthread_rwlock_t read_write_lock ;   /* guards the three fields above */
} ;

/*
 * Cursor over a list. It holds the list's read or write lock (per mode) from
 * list_iterator_new until list_iterator_free.
 */
struct list_iterator {
	list_t * list ;
	list_node_t * previous ;   /* node preceding current (NULL when current is the head); used to relink the list when current is removed */
	list_node_t * current ;    /* node reached by the last list_iterator_next; NULL before the first step and after a removal */
	list_node_t * next ;       /* node following current; lets iteration continue after current has been removed and freed */
	list_iterator_mode_t mode ;   /* lock mode; list_iterator_remove requires LIST_WRITE */
} ;

/*
 * Allocate an empty list with its own read/write lock.
 * Returns NULL if either the allocation or the lock initialisation fails.
 */
list_t * list_new (void) {
	list_t * created = malloc (sizeof * created) ;

	if (created == NULL)
		return NULL ;

	created->first = NULL ;
	created->last = NULL ;
	created->count = 0 ;

	/* without a working lock the list cannot be used: undo the allocation */
	if (pthread_rwlock_init (& created->read_write_lock, NULL) != 0) {
		free (created) ;
		return NULL ;
	}

	return created ;
}

/*
 * Destroy a list: release all of its nodes, its lock and the header itself.
 * The elements are NOT freed (list_free_all does that).
 * Returns 0 on success, -1 for a NULL list, and non-zero if removing the
 * nodes or destroying the lock failed.
 */
int list_free (list_t * list) {
	int return_code = 0 ;

	/* every other entry point rejects a NULL list; do the same instead of
	 * dereferencing it in pthread_rwlock_destroy */
	if (list == NULL)
		return -1 ;

	return_code |= list_remove_all (list) ;
	return_code |= pthread_rwlock_destroy (& list->read_write_lock) ;
	free (list) ;
	return return_code ;
}

/*
 * Free the list together with every element it holds.
 * Each element is unlinked under the write lock and then released with
 * free_element, or with free() when free_element is NULL.
 * Returns 0 on success and non-zero if any step failed. It returns -1 for a
 * NULL list, or when the list could not be locked (the elements are then
 * not freed).
 */
int list_free_all (list_t * list, void (* free_element) (void *)) {
	int return_code = 0 ;
	list_iterator_t * iterator ;
	void * element ;

	if (list == NULL)
		return -1 ;

	iterator = list_iterator_new (list, LIST_WRITE) ;
	if (iterator == NULL) {
		/* report the failure instead of silently leaking every element */
		return_code = -1 ;
	} else {
		while (list_iterator_next (iterator) == 0) {
			element = list_iterator_get (iterator) ;
			return_code |= list_iterator_remove (iterator) ;
			if (free_element == NULL)
				free (element) ;
			else
				free_element (element) ;
		}
		return_code |= list_iterator_free (iterator) ;
	}

	/* releases any remaining nodes, destroys the lock and frees the header */
	return_code |= list_free (list) ;

	return return_code ;
}

/*
 * Compare two elements. With a comparator, return its result. Without one
 * (compare == NULL), order by address and return -1, 0 or 1.
 * Callers in this file only test the result against 0 (equality).
 * Addresses are compared as uintptr_t because applying < or > to pointers
 * that do not point into the same object is undefined behaviour in C.
 */
static int list_compare (const void * element1, const void * element2, list_compare_f compare) {
	uintptr_t address1 ;
	uintptr_t address2 ;

	if (compare != NULL)
		return compare (element1, element2) ;

	address1 = (uintptr_t) element1 ;
	address2 = (uintptr_t) element2 ;
	if (address1 < address2)
		return -1 ;
	if (address1 > address2)
		return 1 ;
	return 0 ;
}

/*
 * Append element to the tail of the list unless an equal element (according
 * to compare, or pointer identity when compare is NULL) is already present.
 * The pointer is stored as-is; nothing is copied.
 * Returns 0 on success, and -1 on a duplicate, an allocation failure or a
 * lock/unlock failure.
 */
int list_add (list_t * list, const void * element, list_compare_f compare) {
	list_iterator_t * it ;
	list_node_t * fresh ;

	/* a write iterator keeps the list exclusively locked for the whole call */
	it = list_iterator_new (list, LIST_WRITE) ;
	if (it == NULL)
		return -1 ;

	/* refuse duplicates */
	while (list_iterator_next (it) == 0) {
		if (list_compare (it->current->element, element, compare) == 0)
			goto fail ;
	}

	fresh = malloc (sizeof * fresh) ;
	if (fresh == NULL)
		goto fail ;
	fresh->element = (void *) element ;
	fresh->next = NULL ;

	/* link after the current tail, or make it the head of an empty list */
	if (list->first != NULL)
		list->last->next = fresh ;
	else
		list->first = fresh ;
	list->last = fresh ;
	++ list->count ;

	/* releasing the lock is the last step that can fail */
	return list_iterator_free (it) ? -1 : 0 ;

fail :
	(void) list_iterator_free (it) ;
	return -1 ;
}

/*
 * Unlink the first node whose element equals element (according to compare,
 * or pointer identity when compare is NULL) and hand its element back.
 * The element itself is not freed.
 * Returns the removed element. Returns NULL if nothing matched, the list could
 * not be locked, or unlocking failed.
 */
void * list_remove (list_t * list, const void * element, list_compare_f compare) {
	list_iterator_t * cursor = list_iterator_new (list, LIST_WRITE) ;
	void * removed = NULL ;

	if (cursor == NULL)
		return NULL ;

	while (list_iterator_next (cursor) == 0) {
		void * candidate = cursor->current->element ;

		if (list_compare (candidate, element, compare) != 0)
			continue ;

		removed = candidate ;
		(void) list_iterator_remove (cursor) ;
		break ;
	}

	return list_iterator_free (cursor) ? NULL : removed ;
}

/*
 * Free every node of the list and leave it empty. The elements are not freed.
 * Returns 0 on success, and -1 if the list could not be locked or unlocked.
 */
int list_remove_all (list_t * list) {
	list_iterator_t * iterator ;
	list_node_t * node ;
	list_node_t * following ;

	/* the iterator is only used to hold the write lock */
	iterator = list_iterator_new (list, LIST_WRITE) ;
	if (iterator == NULL)
		return -1 ;

	/* walk the nodes directly: saving the successor before free() means no
	 * freed pointer is ever read or copied */
	node = list->first ;
	while (node != NULL) {
		following = node->next ;
		free (node) ;
		node = following ;
	}

	/* reset the header while the write lock is still held; doing it after
	 * unlocking would race with other threads */
	list->first = NULL ;
	list->last = NULL ;
	list->count = 0 ;

	if (list_iterator_free (iterator))
		return -1 ;

	return 0 ;
}

/*
 * Look for an element equal to element (according to compare, or pointer
 * identity when compare is NULL) under a shared read lock.
 * Returns 0 if it is present, and -1 if it is absent or locking/unlocking
 * failed.
 */
int list_contains (list_t * list, const void * element, list_compare_f compare) {
	list_iterator_t * reader = list_iterator_new (list, LIST_READ) ;
	int found = 0 ;

	if (reader == NULL)
		return -1 ;

	/* stop as soon as a match is seen */
	while (! found && list_iterator_next (reader) == 0)
		found = list_compare (reader->current->element, element, compare) == 0 ;

	if (list_iterator_free (reader) != 0)
		return -1 ;

	return found ? 0 : -1 ;
}

/*
 * Create an iterator positioned before the first node. It acquires the
 * list's read lock (LIST_READ) or write lock (LIST_WRITE). The lock is held
 * until list_iterator_free.
 * Returns NULL for a NULL list, an unknown mode, an allocation failure or a
 * locking failure. Nothing is leaked and no lock is held in any of those
 * cases.
 */
list_iterator_t * list_iterator_new (list_t * list, list_iterator_mode_t mode) {
	list_iterator_t * iterator ;
	int lock_error ;

	if (list == NULL)
		return NULL ;

	iterator = malloc (sizeof * iterator) ;
	if (iterator == NULL)
		return NULL ;

	switch (mode) {
		case LIST_READ :
			lock_error = pthread_rwlock_rdlock (& list->read_write_lock) ;
			break ;
		case LIST_WRITE :
			lock_error = pthread_rwlock_wrlock (& list->read_write_lock) ;
			break ;
		default :
			/* unknown mode: no lock would be taken, and list_iterator_free
			 * would then unlock a lock that is not held (undefined behaviour) */
			lock_error = -1 ;
			break ;
	}

	if (lock_error) {
		/* previously the iterator leaked on this path */
		free (iterator) ;
		return NULL ;
	}

	iterator->list = list ;
	iterator->mode = mode ;
	list_iterator_reset (iterator) ;
	return iterator ;
}

/*
 * Rewind the iterator: no node has been visited yet, and the list's current
 * head is the next node to visit.
 */
void list_iterator_reset (list_iterator_t * iterator) {
	iterator->next = iterator->list->first ;
	iterator->current = NULL ;
	iterator->previous = NULL ;
}

/*
 * Release the lock taken by list_iterator_new, then free the iterator.
 * Returns -1 for a NULL iterator, otherwise pthread_rwlock_unlock's result
 * (0 on success).
 */
int list_iterator_free (list_iterator_t * iterator) {
	int unlock_status ;

	if (iterator == NULL)
		return -1 ;

	unlock_status = pthread_rwlock_unlock (& iterator->list->read_write_lock) ;
	free (iterator) ;
	return unlock_status ;
}

/*
 * Advance the iterator to the following node.
 * Returns 0 on success, and -1 when there is no further node.
 *
 * current == NULL means one of two things: nothing has been visited yet
 * (previous is NULL and next is the head), or current was just removed by
 * list_iterator_remove (previous is still the node before the removed one,
 * and next is its successor). In both cases the node to visit is next.
 * The old code re-read list->first whenever current was NULL. After a removal
 * that restarted the iteration from the head (revisiting nodes) and reset
 * previous to NULL.
 */
int list_iterator_next (list_iterator_t * iterator) {
	if (! list_iterator_has_next (iterator))
		return -1 ;

	/* previous only advances past a node that is still linked */
	if (iterator->current != NULL)
		iterator->previous = iterator->current ;
	iterator->current = iterator->next ;
	iterator->next = iterator->current->next ;

	return 0 ;
}

/* Return 1 if list_iterator_next can advance to another node, 0 otherwise. */
int list_iterator_has_next (const list_iterator_t * iterator) {
	if (iterator->next == NULL)
		return 0 ;
	return 1 ;
}

/*
 * Return the element of the node the iterator is on. Returns NULL before the
 * first list_iterator_next, or after the current node has been removed.
 */
void * list_iterator_get (list_iterator_t * iterator) {
	const list_node_t * node = iterator->current ;

	return (node != NULL) ? node->element : NULL ;
}

/*
 * Unlink and free the node the iterator is on. The element it carries is not
 * freed. Requires a LIST_WRITE iterator positioned on a node.
 * Returns 0 on success, and -1 otherwise.
 * Afterwards current is NULL, so the same node cannot be removed twice.
 * previous and next are left untouched.
 */
int list_iterator_remove (list_iterator_t * iterator) {
	list_t * owner = iterator->list ;
	list_node_t * victim = iterator->current ;

	if (iterator->mode != LIST_WRITE || victim == NULL)
		return -1 ;

	/* bypass the victim: its predecessor, or the list head, now points past it */
	if (iterator->previous != NULL)
		iterator->previous->next = iterator->next ;
	else
		owner->first = iterator->next ;

	/* removing the tail makes the predecessor the new tail */
	if (iterator->next == NULL)
		owner->last = iterator->previous ;

	owner->count -- ;
	free (victim) ;
	iterator->current = NULL ;

	return 0 ;
}
