/*
* WPA-PSK
*
* Client protocol:
* write challenge: smac[6] + amac[6] + snonce[32] + anonce[32]
* read response: ptk[64]
*
* Server protocol:
* unimplemented
*/
#include "dat.h"
enum {
PMKlen = 256/8,
PTKlen = 512/8,
Eaddrlen = 6,
Noncelen = 32,
};
enum
{
CNeedChal,
CHaveResp,
Maxphase,
};
static char *phasenames[Maxphase] = {
[CNeedChal] "CNeedChal",
[CHaveResp] "CHaveResp",
};
struct State
{
uchar resp[PTKlen];
};
static void
pbkdf2(uchar *p, ulong plen, uchar *s, ulong slen, ulong rounds, uchar *d, ulong dlen)
{
uchar block[SHA1dlen], tmp[SHA1dlen], tmp2[SHA1dlen];
ulong i, j, k, n;
DigestState *ds;
for(i = 1; dlen > 0; i++, d += n, dlen -= n){
tmp[3] = i;
tmp[2] = i >> 8;
tmp[1] = i >> 16;
tmp[0] = i >> 24;
ds = hmac_sha1(s, slen, p, plen, nil, nil);
hmac_sha1(tmp, 4, p, plen, block, ds);
memmove(tmp, block, sizeof(tmp));
for(j = 1; j < rounds; j++){
hmac_sha1(tmp, sizeof(tmp), p, plen, tmp2, nil);
memmove(tmp, tmp2, sizeof(tmp));
for(k=0; k<sizeof(tmp); k++)
block[k] ^= tmp[k];
}
n = dlen > sizeof(block) ? sizeof(block) : dlen;
memmove(d, block, n);
}
}
static int
hextob(char *s, char **sp, uchar *b, int n)
{
int r;
n <<= 1;
for(r = 0; r < n && *s; s++){
*b <<= 4;
if(*s >= '0' && *s <= '9')
*b |= (*s - '0');
else if(*s >= 'a' && *s <= 'f')
*b |= 10+(*s - 'a');
else if(*s >= 'A' && *s <= 'F')
*b |= 10+(*s - 'A');
else break;
if((++r & 1) == 0)
b++;
}
if(sp != nil)
*sp = s;
return r >> 1;
}
static void
pass2pmk(char *pass, char *ssid, uchar pmk[PMKlen])
{
if(hextob(pass, nil, pmk, PMKlen) == PMKlen)
return;
pbkdf2((uchar*)pass, strlen(pass), (uchar*)ssid, strlen(ssid), 4096, pmk, PMKlen);
}
static void
prfn(uchar *k, int klen, char *a, uchar *b, int blen, uchar *d, int dlen)
{
uchar r[SHA1dlen], i;
DigestState *ds;
int n;
i = 0;
while(dlen > 0){
ds = hmac_sha1((uchar*)a, strlen(a)+1, k, klen, nil, nil);
hmac_sha1(b, blen, k, klen, nil, ds);
hmac_sha1(&i, 1, k, klen, r, ds);
i++;
n = dlen;
if(n > sizeof(r))
n = sizeof(r);
memmove(d, r, n); d += n;
dlen -= n;
}
}
static void
calcptk(uchar pmk[PMKlen], uchar smac[Eaddrlen], uchar amac[Eaddrlen],
uchar snonce[Noncelen], uchar anonce[Noncelen],
uchar ptk[PTKlen])
{
uchar b[2*Eaddrlen + 2*Noncelen];
if(memcmp(smac, amac, Eaddrlen) > 0){
memmove(b + Eaddrlen*0, amac, Eaddrlen);
memmove(b + Eaddrlen*1, smac, Eaddrlen);
} else {
memmove(b + Eaddrlen*0, smac, Eaddrlen);
memmove(b + Eaddrlen*1, amac, Eaddrlen);
}
if(memcmp(snonce, anonce, Eaddrlen) > 0){
memmove(b + Eaddrlen*2 + Noncelen*0, anonce, Noncelen);
memmove(b + Eaddrlen*2 + Noncelen*1, snonce, Noncelen);
} else {
memmove(b + Eaddrlen*2 + Noncelen*0, snonce, Noncelen);
memmove(b + Eaddrlen*2 + Noncelen*1, anonce, Noncelen);
}
prfn(pmk, PMKlen, "Pairwise key expansion", b, sizeof(b), ptk, PTKlen);
}
static int
wpapskinit(Proto *p, Fsstate *fss)
{
int iscli;
State *s;
if((iscli = isclient(_strfindattr(fss->attr, "role"))) < 0)
return failure(fss, nil);
if(!iscli)
return failure(fss, "%s server not supported", p->name);
s = emalloc(sizeof *s);
fss->phasename = phasenames;
fss->maxphase = Maxphase;
fss->phase = CNeedChal;
fss->ps = s;
return RpcOk;
}
static int
wpapskwrite(Fsstate *fss, void *va, uint n)
{
uchar pmk[PMKlen], *smac, *amac, *snonce, *anonce;
char *pass, *essid;
State *s;
int ret;
Key *k;
Keyinfo ki;
Attr *attr;
s = fss->ps;
if(fss->phase != CNeedChal)
return phaseerror(fss, "write");
if(n != (2*Eaddrlen + 2*Noncelen))
return phaseerror(fss, "bad write size");
attr = _delattr(_copyattr(fss->attr), "role");
mkkeyinfo(&ki, fss, attr);
ret = findkey(&k, &ki, "%s", fss->proto->keyprompt);
_freeattr(attr);
if(ret != RpcOk)
return ret;
pass = _strfindattr(k->privattr, "!password");
if(pass == nil)
return failure(fss, "key has no password");
essid = _strfindattr(k->attr, "essid");
if(essid == nil)
return failure(fss, "key has no essid");
setattrs(fss->attr, k->attr);
closekey(k);
pass2pmk(pass, essid, pmk);
smac = va;
amac = smac + Eaddrlen;
snonce = amac + Eaddrlen;
anonce = snonce + Noncelen;
calcptk(pmk, smac, amac, snonce, anonce, s->resp);
fss->phase = CHaveResp;
return RpcOk;
}
static int
wpapskread(Fsstate *fss, void *va, uint *n)
{
State *s;
s = fss->ps;
if(fss->phase != CHaveResp)
return phaseerror(fss, "read");
if(*n > sizeof(s->resp))
*n = sizeof(s->resp);
memmove(va, s->resp, *n);
fss->phase = Established;
fss->haveai = 0;
return RpcOk;
}
static void
wpapskclose(Fsstate *fss)
{
State *s;
s = fss->ps;
free(s);
}
Proto wpapsk = {
.name= "wpapsk",
.init= wpapskinit,
.write= wpapskwrite,
.read= wpapskread,
.close= wpapskclose,
.addkey= replacekey,
.keyprompt= "!password? essid?"
};
|