/*
 * Binary module version check utility.
 *
 * Based on script/mod/modpost.c
 * Copyright 2003       Kai Germaschewski
 * Copyright 2002-2004  Rusty Russell, IBM Corporation
 * Copyright 2006-2008  Sam Ravnborg
 *
 * and based on kernel/module.c
 * Copyright (C) 2002 Richard Henderson
 * Copyright (C) 2001 Rusty Russell, 2002 Rusty Russell IBM.
 *
 * This software may be used and distributed according to the terms
 * of the GNU General Public License, incorporated herein by reference.
 */

#define _GNU_SOURCE
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <unistd.h>
#include <fcntl.h>
#include <ctype.h>
#include <sys/mman.h>
#include <elf.h>

#include "elfconfig.h"
#include "modcheck.h"

int quiet = 0;

enum export {
	export_plain,      export_unused,     export_gpl,
	export_unused_gpl, export_gpl_future, export_unknown
};

/*
 * Helpers.
 */

static void fatal(const char *fmt, ...)
{
	va_list arglist;

	if (quiet)
		goto exit;

	fprintf(stderr, "FATAL: ");

	va_start(arglist, fmt);
	vfprintf(stderr, fmt, arglist);
	va_end(arglist);

exit:
	exit(1);
}

static void warn(const char *fmt, ...)
{
	va_list arglist;

	if (quiet)
		return;

	fprintf(stderr, "WARNING: ");

	va_start(arglist, fmt);
	vfprintf(stderr, fmt, arglist);
	va_end(arglist);
}

static void merror(const char *fmt, ...)
{
	va_list arglist;

	if (quiet)
		return;

	fprintf(stderr, "ERROR: ");

	va_start(arglist, fmt);
	vfprintf(stderr, fmt, arglist);
	va_end(arglist);
}

#define NOFAIL(ptr)	do_nofail((ptr), #ptr)

void *do_nofail(void *ptr, const char *expr)
{
	if (!ptr)
		fatal("modpost: Memory allocation failure: %s.\n", expr);

	return ptr;
}

static int is_vmlinux(const char *modname)
{
	const char *myname;

	myname = strrchr(modname, '/');
	if (myname)
		myname++;
	else
		myname = modname;

	return (strcmp(myname, "vmlinux") == 0) ||
	       (strcmp(myname, "vmlinux.o") == 0);
}

/*
 * A list of all modules we processed.
 */

static struct module *modules;

static struct module *find_module(char *modname)
{
	struct module *mod;

	for (mod = modules; mod; mod = mod->next)
		if (strcmp(mod->name, modname) == 0)
			break;
	return mod;
}

static struct module *new_module(char *modname)
{
	struct module *mod;
	char *p, *s;

	mod = NOFAIL(malloc(sizeof(*mod)));
	memset(mod, 0, sizeof(*mod));
	p = NOFAIL(strdup(modname));

	/* strip trailing .o */
	s = strrchr(p, '.');
	if (s != NULL)
		if (strcmp(s, ".o") == 0)
			*s = '\0';

	/* add to list */
	mod->name = p;
	mod->gpl_compatible = -1;
	mod->next = modules;
	modules = mod;

	return mod;
}

/*
 * Hash of all exported symbols.
 */

#define SYMBOL_HASH_SIZE 1024

struct symbol {
	struct symbol *next;
	struct module *module;
	unsigned int crc;
	int crc_valid;
	unsigned int vmlinux:1;    /* 1 if symbol is defined in vmlinux */
	enum export  export;       /* Type of export */
	char name[0];
};

static struct symbol *symbolhash[SYMBOL_HASH_SIZE];

/* This is based on the hash agorithm from gdbm, via tdb */
static inline unsigned int tdb_hash(const char *name)
{
	unsigned value;	/* Used to compute the hash value.  */
	unsigned   i;	/* Used to cycle through random values. */

	/* Set the initial value from the key size. */
	for (value = 0x238F13AF * strlen(name), i = 0; name[i]; i++)
		value = (value + (((unsigned char *)name)[i] << (i*5 % 24)));

	return (1103515243 * value + 12345);
}

/**
 * Allocate a new symbols for use in the hash of exported symbols or
 * the list of unresolved symbols per module
 **/
static struct symbol *alloc_symbol(const char *name, struct symbol *next)
{
	struct symbol *s = NOFAIL(malloc(sizeof(*s) + strlen(name) + 1));

	memset(s, 0, sizeof(*s));
	strcpy(s->name, name);
	s->next = next;
	return s;
}

/* For the hash of exported symbols */
static struct symbol *new_symbol(const char *name, struct module *module,
				 enum export export)
{
	unsigned int hash;
	struct symbol *new;

	hash = tdb_hash(name) % SYMBOL_HASH_SIZE;
	new = symbolhash[hash] = alloc_symbol(name, symbolhash[hash]);
	new->module = module;
	new->export = export;
	new->vmlinux = is_vmlinux(module->name);

	return new;
}

static struct symbol *find_symbol(const char *name)
{
	struct symbol *s;

	/* For our purposes, .foo matches foo.  PPC64 needs this. */
	if (name[0] == '.')
		name++;

	for (s = symbolhash[tdb_hash(name) % SYMBOL_HASH_SIZE]; s; s = s->next) {
		if (strcmp(s->name, name) == 0)
			return s;
	}
	return NULL;
}

static struct {
	const char *str;
	enum export export;
} export_list[] = {
	{ .str = "EXPORT_SYMBOL",            .export = export_plain },
	{ .str = "EXPORT_UNUSED_SYMBOL",     .export = export_unused },
	{ .str = "EXPORT_SYMBOL_GPL",        .export = export_gpl },
	{ .str = "EXPORT_UNUSED_SYMBOL_GPL", .export = export_unused_gpl },
	{ .str = "EXPORT_SYMBOL_GPL_FUTURE", .export = export_gpl_future },
	{ .str = "(unknown)",                .export = export_unknown },
};

static enum export export_no(const char *s)
{
	int i;

	if (!s)
		return export_unknown;
	for (i = 0; export_list[i].export != export_unknown; i++) {
		if (strcmp(export_list[i].str, s) == 0)
			return export_list[i].export;
	}
	return export_unknown;
}

/**
 * Add an exported symbol - it may have already been added without a
 * CRC, in this case just update the CRC
 **/
static struct symbol *sym_add_exported(const char *name, struct module *mod,
				       enum export export)
{
	struct symbol *s = find_symbol(name);

	if (s) {
		warn("%s: '%s' exported twice. Previous export was in %s%s\n",
		     mod->name, name, s->module->name,
		     is_vmlinux(s->module->name) ?"":".ko");
		return NULL;
	}
	return new_symbol(name, mod, export);
}

static void sym_update_crc(const char *name, struct module *mod,
			   unsigned int crc, enum export export)
{
	struct symbol *s = find_symbol(name);

	if (!s)
		s = new_symbol(name, mod, export);
	s->crc = crc;
	s->crc_valid = 1;
}

/*
 * File mapping.
 */

static void *grab_file(const char *filename, unsigned long *size)
{
	struct stat st;
	void *map;
	int fd;

	fd = open(filename, O_RDONLY);
	if (fd < 0 || fstat(fd, &st) != 0)
		return NULL;

	*size = st.st_size;
	map = mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0);
	close(fd);

	if (map == MAP_FAILED)
		return NULL;
	return map;
}

static void release_file(void *file, unsigned long size)
{
	munmap(file, size);
}

/*
 * Functions to parse symbol dump files.
 */

/*
 * Return a copy of the next line in a mmap'ed file.
 * spaces in the beginning of the line is trimmed away.
 * Return a pointer to a static buffer.
 */
char *get_next_line(unsigned long *pos, void *file, unsigned long size)
{
	static char line[4096];
	int skip = 1;
	size_t len = 0;
	signed char *p = (signed char *)file + *pos;
	char *s = line;

	for (; *pos < size ; (*pos)++) {
		if (skip && isspace(*p)) {
			p++;
			continue;
		}
		skip = 0;
		if (*p != '\n' && (*pos < size)) {
			len++;
			*s++ = *p++;
			if (len > 4095)
				break; /* Too long, stop */
		} else {
			/* End of string */
			*s = '\0';
			return line;
		}
	}
	/* End of buffer */
	return NULL;
}

#define SYMBOL_HASH_SIZE 1024

/* parse Module.symvers file. line format:
 * 0x12345678<tab>symbol<tab>module[[<tab>export]<tab>something]
 **/
static int read_dump(const char *fname)
{
	unsigned long size, pos = 0;
	void *file = grab_file(fname, &size);
	char *line;

	if (!file) {
		perror(fname);
		goto err;
	}

	while ((line = get_next_line(&pos, file, size))) {
		char *symname, *modname, *d, *export, *end;
		unsigned int crc;
		struct module *mod;
		struct symbol *s;

		if (!(symname = strchr(line, '\t')))
			goto err_parse;
		*symname++ = '\0';
		if (!(modname = strchr(symname, '\t')))
			goto err_parse;
		*modname++ = '\0';
		if ((export = strchr(modname, '\t')) != NULL)
			*export++ = '\0';
		if (export && ((end = strchr(export, '\t')) != NULL))
			*end = '\0';
		crc = strtoul(line, &d, 16);
		if (*symname == '\0' || *modname == '\0' || *d != '\0')
			goto err_parse;
		mod = find_module(modname);
		if (!mod)
			mod = new_module(modname);

		s = sym_add_exported(symname, mod, export_no(export));
		sym_update_crc(symname, mod, crc, export_no(export));
	}

	return 0;

err_parse:
	merror("parse error in symbol dump file\n");
	release_file(file, size);
err:
	return -1;
}

/*
 * Elf functions.
 */

#define MODULE_NAME_LEN (64 - sizeof(unsigned long))

struct modversion_info
{
	unsigned int crc;
	char name[MODULE_NAME_LEN];
};

static int check_symbol_version(Elf_Shdr *sechdrs, unsigned int versindex)
{
	struct modversion_info *versions = (void *) sechdrs[versindex].sh_addr;
	unsigned int num_versions =
		sechdrs[versindex].sh_size / sizeof(struct modversion_info);
	int i;

	for (i = 0; i < num_versions; i++) {
		struct symbol *sym = find_symbol(versions[i].name);

		if (!sym) {
			merror("'%s' no found in kernel symbol dump\n",
			       versions[i].name);
			return -1;
		}
		if (!sym->crc_valid || versions[i].crc != sym->crc) {
			merror("'%s' crc mismatch: %s:%#8x differ from %#8x\n",
			       sym->module->name, sym->name, sym->crc,
			       versions[i].crc);
			return -1;
		}
	}

	return 0;
}

/* Find a module section: 0 means not found. */
static unsigned int find_sec(Elf_Ehdr *hdr,
			     Elf_Shdr *sechdrs,
			     const char *secstrings,
			     const char *name)
{
	unsigned int i;

	for (i = 1; i < hdr->e_shnum; i++)
		/* Alloc bit cleared means "ignore it." */
		if ((sechdrs[i].sh_flags & SHF_ALLOC)
		    && strcmp(secstrings+sechdrs[i].sh_name, name) == 0)
			return i;
	return 0;
}

static int check_modversion(struct elf_info *info, const char *modname)
{
	Elf_Ehdr *hdr = info->hdr;
	Elf_Shdr *sechdrs = info->sechdrs;
	char *secstrings;
	unsigned int modindex, versindex;

	secstrings = (void *)hdr + sechdrs[hdr->e_shstrndx].sh_offset;
	sechdrs[0].sh_addr = 0;

	modindex = find_sec(hdr, sechdrs, secstrings,
			    ".gnu.linkonce.this_module");
	if (!modindex) {
		merror("%s is not a module.\n", modname);
		return -1;
	}

	versindex = find_sec(hdr, sechdrs, secstrings, "__versions");
	if (!versindex) {
		merror("%s: missing __versions section.\n", modname);
		return -1;
	}

	return check_symbol_version(sechdrs, versindex);
}

static int parse_elf(struct elf_info *info, const char *filename)
{
	unsigned int i;
	Elf_Ehdr *hdr = info->hdr;
	Elf_Shdr *sechdrs;
	Elf_Sym  *sym;

	if (info->size < sizeof(*hdr)) {
		/* file too small, assume this is an empty .o file */
		merror("'%s' is not a valid ELF file (too small)\n", filename);
		return -1;
	}
	/* Is this a valid ELF file? */
	if ((hdr->e_ident[EI_MAG0] != ELFMAG0) ||
	    (hdr->e_ident[EI_MAG1] != ELFMAG1) ||
	    (hdr->e_ident[EI_MAG2] != ELFMAG2) ||
	    (hdr->e_ident[EI_MAG3] != ELFMAG3)) {
		merror("'%s' is not a valid ELF file\n", filename);
		return -1;
	}
	/* Fix endianness in ELF header */
	hdr->e_type      = TO_NATIVE(hdr->e_type);
	hdr->e_machine   = TO_NATIVE(hdr->e_machine);
	hdr->e_version   = TO_NATIVE(hdr->e_version);
	hdr->e_entry     = TO_NATIVE(hdr->e_entry);
	hdr->e_phoff     = TO_NATIVE(hdr->e_phoff);
	hdr->e_shoff     = TO_NATIVE(hdr->e_shoff);
	hdr->e_flags     = TO_NATIVE(hdr->e_flags);
	hdr->e_ehsize    = TO_NATIVE(hdr->e_ehsize);
	hdr->e_phentsize = TO_NATIVE(hdr->e_phentsize);
	hdr->e_phnum     = TO_NATIVE(hdr->e_phnum);
	hdr->e_shentsize = TO_NATIVE(hdr->e_shentsize);
	hdr->e_shnum     = TO_NATIVE(hdr->e_shnum);
	hdr->e_shstrndx  = TO_NATIVE(hdr->e_shstrndx);
	sechdrs = (void *)hdr + hdr->e_shoff;
	info->sechdrs = sechdrs;

	/* Check if file offset is correct */
	if (hdr->e_shoff > info->size) {
		merror("section header offset=%lu in file '%s' is bigger than "
		       "filesize=%lu\n", (unsigned long)hdr->e_shoff, filename,
		       info->size);
		return -1;
	}

	/* Fix endianness in section headers */
	for (i = 0; i < hdr->e_shnum; i++) {
		sechdrs[i].sh_name      = TO_NATIVE(sechdrs[i].sh_name);
		sechdrs[i].sh_type      = TO_NATIVE(sechdrs[i].sh_type);
		sechdrs[i].sh_flags     = TO_NATIVE(sechdrs[i].sh_flags);
		sechdrs[i].sh_addr      = TO_NATIVE(sechdrs[i].sh_addr);
		sechdrs[i].sh_offset    = TO_NATIVE(sechdrs[i].sh_offset);
		sechdrs[i].sh_size      = TO_NATIVE(sechdrs[i].sh_size);
		sechdrs[i].sh_link      = TO_NATIVE(sechdrs[i].sh_link);
		sechdrs[i].sh_info      = TO_NATIVE(sechdrs[i].sh_info);
		sechdrs[i].sh_addralign = TO_NATIVE(sechdrs[i].sh_addralign);
		sechdrs[i].sh_entsize   = TO_NATIVE(sechdrs[i].sh_entsize);
	}
	/* Find symbol table. */
	for (i = 1; i < hdr->e_shnum; i++) {
		const char *secstrings
			= (void *)hdr + sechdrs[hdr->e_shstrndx].sh_offset;
		const char *secname;
		int nobits = sechdrs[i].sh_type == SHT_NOBITS;

		if (!nobits && sechdrs[i].sh_offset > info->size) {
			merror("%s is truncated. sechdrs[i].sh_offset=%lu > "
			       "sizeof(*hrd)=%zu\n", filename,
			       (unsigned long)sechdrs[i].sh_offset,
			       sizeof(*hdr));
			return -1;
		}

		/* Mark all sections sh_addr with their address in the
		   temporary image. */
		sechdrs[i].sh_addr = (size_t)hdr + sechdrs[i].sh_offset;

		secname = secstrings + sechdrs[i].sh_name;
		if (strcmp(secname, ".modinfo") == 0) {
			if (nobits) {
				merror("%s has NOBITS .modinfo\n", filename);
				return -1;
			}
			info->modinfo = (void *)hdr + sechdrs[i].sh_offset;
			info->modinfo_len = sechdrs[i].sh_size;
		} else if (strcmp(secname, "__ksymtab") == 0)
			info->export_sec = i;
		else if (strcmp(secname, "__ksymtab_unused") == 0)
			info->export_unused_sec = i;
		else if (strcmp(secname, "__ksymtab_gpl") == 0)
			info->export_gpl_sec = i;
		else if (strcmp(secname, "__ksymtab_unused_gpl") == 0)
			info->export_unused_gpl_sec = i;
		else if (strcmp(secname, "__ksymtab_gpl_future") == 0)
			info->export_gpl_future_sec = i;
		else if (strcmp(secname, "__markers_strings") == 0)
			info->markers_strings_sec = i;

		if (sechdrs[i].sh_type != SHT_SYMTAB)
			continue;

		info->symtab_start = (void *)hdr + sechdrs[i].sh_offset;
		info->symtab_stop  = (void *)hdr + sechdrs[i].sh_offset
			                         + sechdrs[i].sh_size;
		info->strtab       = (void *)hdr +
			             sechdrs[sechdrs[i].sh_link].sh_offset;
	}
	if (!info->symtab_start) {
		merror("'%s' has no symbols (stripped ?)\n", filename);
		return -1;
	}

	/* Fix endianness in symbols */
	for (sym = info->symtab_start; sym < info->symtab_stop; sym++) {
		sym->st_shndx = TO_NATIVE(sym->st_shndx);
		sym->st_name  = TO_NATIVE(sym->st_name);
		sym->st_value = TO_NATIVE(sym->st_value);
		sym->st_size  = TO_NATIVE(sym->st_size);
	}
	return 0;
}

/*
 * Main.
 */

static void usage(void)
{
	printf("Usage: modcheck [-q] [-s <symbol dump>] ... <module.ko>\n");
	printf("\t-q : do not print any messages\n");
}

struct ext_sym_list {
	struct ext_sym_list *next;
	const char *file;
};

int main(int argc, char **argv)
{
	int opt;
	int err;
	struct elf_info info = { };
	char *modname;
	struct ext_sym_list *extsym_iter;
	struct ext_sym_list *extsym_start = NULL;

	while ((opt = getopt(argc, argv, "qs:")) != -1) {
		switch (opt) {
		case 's':
			extsym_iter = malloc(sizeof(*extsym_iter));
			if (!extsym_iter)
				return -1;
			extsym_iter->next = extsym_start;
			extsym_iter->file = optarg;
			extsym_start = extsym_iter;
			break;
		case 'q':
			quiet = 1;
			break;
		default:
			usage();
			return -1;
		}
	}
	if (optind != argc - 1) {
		usage();
		return -1;
	}

	modname = argv[optind];

	/* Build symbol list from the supplied Module.symvers files. */
	while (extsym_start) {
		err = read_dump(extsym_start->file);
		if (err)
			goto exit;
		extsym_iter = extsym_start->next;
		free(extsym_start);
		extsym_start = extsym_iter;
	}

	info.hdr = grab_file(modname, &info.size);
	if (!info.hdr) {
		perror(modname);
		err = -1;
		goto exit;
	}

	/* Parse and check elf header. */
	err = parse_elf(&info, modname);
	if (err)
		goto exit_free;

	/* Check symbol version. */
	err = check_modversion(&info, modname);

exit_free:
	release_file(info.hdr, info.size);
exit:
	return err;
}
