/*
   bcmwpa.c - shared WPA-related functions

   Copyright 2004, Broadcom Corporation
   All Rights Reserved.                
                                       
   This is UNPUBLISHED PROPRIETARY SOURCE CODE of Broadcom Corporation;   
   the contents of this file may not be disclosed to third parties, copied
   or duplicated in any form, in whole or in part, without the prior      
   written permission of Broadcom Corporation.                            

   $Id$
 */

#include <bcmendian.h>
#include <osl.h>
#include <wlioctl.h>
#include <proto/802.11.h>
#include <proto/eapol.h>
#include <bcmutils.h>
#include <bcmwpa.h>

#ifdef	WPAPSK

#include <crypto/prf.h>
#include <crypto/rc4.h>

void
wpa_calc_ptk(struct ether_addr *auth_ea, struct ether_addr *sta_ea,
	     uint8 *anonce, uint8* snonce, uint8 *pmk, uint pmk_len,
	     uint8 *ptk, uint ptk_len)
{
 	uchar data[128], prf_buff[PRF_OUTBUF_LEN];
	uchar prefix[] = "Pairwise key expansion";
	int data_len=0;

	/* Create the the data portion:
	   the lesser of the EAs, followed by the greater of the EAs,
	   followed by the lesser of the the nonces, followed by the
	   greater of the nonces. */
	bcopy(wpa_array_cmp(MIN_ARRAY, (uint8 *)auth_ea, (uint8 *)sta_ea,
			    ETHER_ADDR_LEN),
	      (char *)&data[data_len], ETHER_ADDR_LEN);
	data_len += ETHER_ADDR_LEN;
	bcopy(wpa_array_cmp(MAX_ARRAY, (uint8 *)auth_ea, (uint8 *)sta_ea,
			    ETHER_ADDR_LEN),
	      (char *)&data[data_len], ETHER_ADDR_LEN);
	data_len += ETHER_ADDR_LEN;
	bcopy(wpa_array_cmp(MIN_ARRAY, snonce, anonce,
			    EAPOL_WPA_KEY_NONCE_LEN),
	      (char *)&data[data_len], EAPOL_WPA_KEY_NONCE_LEN);
	data_len += EAPOL_WPA_KEY_NONCE_LEN;
	bcopy(wpa_array_cmp(MAX_ARRAY, snonce, anonce,
			    EAPOL_WPA_KEY_NONCE_LEN),
	      (char *)&data[data_len], EAPOL_WPA_KEY_NONCE_LEN);
	data_len += EAPOL_WPA_KEY_NONCE_LEN;

	/* generate the PTK */
	fPRF(pmk, (int)pmk_len, prefix, strlen(prefix), data, data_len,
	     prf_buff, (int)ptk_len);
	bcopy(prf_buff, (char*)ptk, ptk_len);
}

/* Decrypt the group key carried in a received WPA key message.
 *
 * body:     key descriptor of the EAPOL-Key message. For WPA_KEY_DESC_V1
 *           its key data is decrypted in place.
 * key_info: key information field; its WPA_KEY_DESC_V1/V2 bits select RC4
 *           (V1) or AES key unwrap (V2)
 * ekey:     WPA_MIC_KEY_LEN-byte key-encryption key
 * gtk:      output buffer for the decrypted key
 *
 * Returns TRUE on success. Returns FALSE for an unsupported descriptor
 * version, a V1 key_len larger than the key data field, or an AES unwrap
 * failure.
 *
 * NOTE(review): the size of gtk is not known here; callers must make sure it
 * can hold key_len (V1) or the unwrapped key data (V2) taken from the frame.
 */
bool
wpa_decr_gtk(eapol_wpa_key_header_t *body, uint16 key_info, uint8 *ekey,
	     uint8 *gtk)
{
	unsigned char data[256], encrkey[WPA_MIC_KEY_LEN*2];
	rc4_ks_t rc4key;
	uint16 len = ntoh16_ua((uint8 *)&body->key_len);
	uint16 key_data_len = ntoh16_ua((uint8 *)&body->data_len);

	switch (key_info & (WPA_KEY_DESC_V1 | WPA_KEY_DESC_V2)) {
	case WPA_KEY_DESC_V1:
		/* key_len comes from the received frame; never let the
		   decryption and copy run past the key data field */
		if (len > key_data_len)
			return FALSE;

		/* RC4 key = IV || key-encryption key */
		bcopy(body->iv, encrkey, WPA_MIC_KEY_LEN);
		bcopy(ekey, &encrkey[WPA_MIC_KEY_LEN], WPA_MIC_KEY_LEN);
		/* decrypt the gtk using RC4 */
		prepare_key(encrkey, sizeof(encrkey), &rc4key);
		rc4(data, sizeof(data), &rc4key); /* discard the first 256 keystream bytes */
		rc4(body->data, len, &rc4key);
		bcopy(body->data, gtk, len);

		/* scrub the RC4 key material from the stack */
		bzero((char *)encrkey, sizeof(encrkey));
		bzero((char *)&rc4key, sizeof(rc4key));
		break;

	case WPA_KEY_DESC_V2:
		/* aes_unwrap() returns nonzero on failure */
		if (aes_unwrap(WPA_MIC_KEY_LEN, ekey, key_data_len,
			       body->data, gtk)) {
			return FALSE;
		}
		break;

	default:
		return FALSE;
	}
	return TRUE;
}

/* Compute the Message Integrity Code (MIC) of an EAPOL-Key message.
 *
 * The MIC covers the frame from the EAPOL version field onward. That is the
 * 4 header bytes starting at version, the EAPOL_WPA_KEY_LEN-byte key
 * descriptor, and the descriptor's data_len bytes of key data.
 *
 * key_desc: WPA_KEY_DESC_V1 selects HMAC-MD5, WPA_KEY_DESC_V2 selects HMAC-SHA1
 * mic_key:  EAPOL_WPA_KEY_MIC_LEN-byte MIC key
 * mic:      receives the HMAC digest. NOTE(review): presumably the full
 *           digest is written (20 bytes for SHA1), so size mic accordingly.
 *
 * Returns FALSE, without computing anything, for any other key_desc.
 */
bool
wpa_make_mic(eapol_header_t *eapol, uint key_desc, uint8 *mic_key, uchar *mic)
{
	eapol_wpa_key_header_t *key_hdr;
	int covered;

	if (key_desc != WPA_KEY_DESC_V1 && key_desc != WPA_KEY_DESC_V2)
		return FALSE;

	key_hdr = (eapol_wpa_key_header_t *)eapol->body;
	covered = 4 + EAPOL_WPA_KEY_LEN + ntoh16_ua((uint8 *)&key_hdr->data_len);

	if (key_desc == WPA_KEY_DESC_V1)
		hmac_md5(&eapol->version, covered, mic_key,
			 EAPOL_WPA_KEY_MIC_LEN, mic);
	else
		hmac_sha1(&eapol->version, covered, mic_key,
			  EAPOL_WPA_KEY_MIC_LEN, mic);

	return TRUE;
}

/* Verify the MIC of a received EAPOL-Key message.
 *
 * The received MIC is saved and the message's MIC field is zeroed in place
 * (it is not restored afterwards). The MIC is then recomputed with
 * wpa_make_mic() and compared with the saved value.
 *
 * Returns TRUE if the MICs match. Returns FALSE on mismatch or when key_desc
 * is not supported by wpa_make_mic().
 */
bool
wpa_check_mic(eapol_header_t *eapol, uint key_desc, uint8 *mic_key)
{
	eapol_wpa_key_header_t *body = (eapol_wpa_key_header_t *)eapol->body;
	uchar digest[PRF_OUTBUF_LEN];
	uchar mic[EAPOL_WPA_KEY_MIC_LEN];
	uchar diff = 0;
	int i;

	/* save MIC and clear its space in message */
	bcopy((char*)&body->mic, mic, EAPOL_WPA_KEY_MIC_LEN);
	bzero((char*)&body->mic, EAPOL_WPA_KEY_MIC_LEN);

	if (!wpa_make_mic(eapol, key_desc, mic_key, digest)) {
		return FALSE;
	}

	/* Constant-time comparison: examine every byte regardless of where the
	   first mismatch is, so the run time does not reveal how many leading
	   bytes of a forged MIC are correct. */
	for (i = 0; i < EAPOL_WPA_KEY_MIC_LEN; i++)
		diff |= (uchar)(digest[i] ^ mic[i]);

	return diff == 0;
}
#endif	/* WPAPSK */

/* Map a WPA IE cipher suite to the locally used CRYPTO_ALGO_* value.
 *
 * Only suites whose OUI equals WPA_OUI are recognized. The WEP-40 and
 * WEP-104 suites are accepted only when wep_ok is TRUE.
 *
 * On success *cipher is set and TRUE is returned. Otherwise *cipher is left
 * untouched and FALSE is returned.
 */
bool
wpa_cipher(wpa_suite_t *suite, ushort *cipher, bool wep_ok)
{
	ushort algo;

	if (bcmp((char *)suite->oui, WPA_OUI, WPA_OUI_LEN) != 0)
		return FALSE;

	switch (suite->type) {
	case WPA_CIPHER_TKIP:
		algo = CRYPTO_ALGO_TKIP;
		break;
	case WPA_CIPHER_AES_CCM:
		algo = CRYPTO_ALGO_AES_CCM;
		break;
	case WPA_CIPHER_WEP_40:
		if (!wep_ok)
			return FALSE;
		algo = CRYPTO_ALGO_WEP1;
		break;
	case WPA_CIPHER_WEP_104:
		if (!wep_ok)
			return FALSE;
		algo = CRYPTO_ALGO_WEP128;
		break;
	default:
		return FALSE;
	}

	*cipher = algo;
	return TRUE;
}

/* Test whether the TLV element at ie is a WPA IE, i.e. its body starts with
 * WPA_OUI followed by type 1. The element ID itself is not checked here.
 *
 * ie:       element to test
 * tlvs:     in/out cursor into the TLV buffer being searched
 * tlvs_len: in/out number of bytes remaining from *tlvs
 *
 * Returns TRUE if ie is a WPA IE, leaving *tlvs and *tlvs_len unchanged.
 * Otherwise it advances *tlvs to the element following ie, reduces
 * *tlvs_len by the distance moved, and returns FALSE so the caller can
 * resume its search.
 */
bool
bcm_is_wpa_ie(uint8 *ie, uint8 **tlvs, int *tlvs_len)
{
	/* If the contents match the WPA_OUI and type=1 */
	if ((ie[TLV_LEN_OFF] > (TLV_HDR_LEN + WPA_OUI_LEN)) &&
	    !bcmp(&ie[TLV_BODY_OFF], WPA_OUI "\x01", WPA_OUI_LEN + 1)) {
		return TRUE;
	}

	/* point to the next ie */
	ie += ie[TLV_LEN_OFF] + TLV_HDR_LEN;
	/* Use pointer subtraction, not (int) casts of the pointers, because
	   converting a pointer to int truncates on 64-bit targets. */
	*tlvs_len -= (int)(ie - *tlvs);
	/* update the pointer to the start of the buffer */
	*tlvs = ie;

	return FALSE;
}

/* Find the first WPA IE in a buffer of information elements.
 *
 * parse: start of the TLV buffer
 * len:   buffer length in bytes
 *
 * Elements with ID DOT11_MNG_WPA_ID whose body is not a WPA IE are skipped.
 * Returns a pointer to the WPA IE inside the buffer, or NULL if none is found.
 */
wpa_ie_fixed_t *
bcm_find_wpaie(uint8 *parse, uint len)
{
	bcm_tlv_t *ie;
	/* bcm_is_wpa_ie() updates the remaining length through an int *.
	   Passing &len (a uint *) would be an incompatible pointer type. */
	int remaining = (int)len;

	while ((ie = bcm_parse_tlvs(parse, remaining, DOT11_MNG_WPA_ID)) != NULL) {
		if (bcm_is_wpa_ie((uint8 *)ie, &parse, &remaining)) {
			return (wpa_ie_fixed_t *)ie;
		}
	}
	return NULL;
}

/* Lexicographically compare two len-byte arrays, byte by byte.
 *
 * max: nonzero to select the greater array, zero to select the lesser
 *      (callers pass MAX_ARRAY / MIN_ARRAY)
 *
 * Returns x or y as selected, or NULL if the arrays are identical (or len is
 * not positive). Callers must handle the NULL case.
 */
char *
wpa_array_cmp(int max, uint8 *x, uint8 *y, int len)
{
	int i;
	uint8 *ret = x;

	for (i = 0; i < len; i++) {
		if (x[i] != y[i])
			break;
	}

	/* No differing byte: neither array is lesser or greater. Using >=
	   also avoids reading x[0]/y[0] when len is not positive. */
	if (i >= len)
		return NULL;

	if (max ? (y[i] > x[i]) : (y[i] < x[i]))
		ret = y;

	/* Explicit cast, because uint8 * does not implicitly convert to char *. */
	return (char *)ret;
}

/* Increment a len-byte big-endian counter in place.
 *
 * array[len - 1] is the least significant byte. A byte that wraps from 0xff
 * to 0 carries into the byte before it, so an all-0xff array becomes all
 * zeros. Nothing happens when len is not positive.
 */
void
wpa_incr_array(uint8 *array, int len)
{
	int pos = len;

	while (pos > 0) {
		pos--;
		if (++array[pos] != 0)
			break;	/* no carry into the more significant byte */
	}
}
