#include <u.h>
#include <libc.h>
#include <thread.h>
#include <bio.h>
#include <ip.h>
#include <libsec.h>
#include "dat.h"
#include "fns.h"
/*
 * EAP-TTLS header, overlaid directly on packet bytes via (TTLS*) casts:
 * the type octet, the flags octet, then an optional 4-byte total message
 * length, written with hnputl and read with nhgetl.
 * Field order and sizes must match the wire format.
 */
typedef struct TTLS {
uchar tp;	// EAP type; EapTpTtls for the packets handled here
uchar flags;	// TtlsFlagL/M/S bits plus the version bits (TtlsVersion mask)
uchar tln[4]; // optional total message length, present only if the L flag is set
} TTLS;
enum {
// bits of the TTLS flags octet
TtlsFlagL = 1<<7, // header contains tln field
TtlsFlagM = 1<<6, // more fragment(s) will follow for current msg
TtlsFlagS = 1<<5, // start of tls session
TtlsVersion = (1<<2)|(1<<1)|(1<<0), // mask selecting the version bits of the flags octet
// header sizes
TtlsShortHlen = 2, // without tln field
TtlsLongHlen = TtlsShortHlen+4, // with tln field
// states of the ttls() state machine; the order must match snames[]
Init = 0,
Waiting,
Sending,
RecvAck,
Receiving,
SendAck,
Received,
// buffers
Nrbuf = 16, // number of slots in the rbuf ring filled by readproc
Nbuf = 9600, // capacity in bytes of each Buf
};
// Printable names of the state-machine states, indexed by state value;
// used only in debug/diagnostic output.
char *snames[] = {
[Init] "Init",
[Waiting] "Waiting",
[Sending] "Sending",
[RecvAck] "RecvAck",
[Receiving] "Receiving",
[SendAck] "SendAck",
[Received] "Received",
};
// A fixed-size message buffer: the first n bytes of b are valid.
typedef struct Buf {
int n; // number of valid bytes in b (read() result when filled by readproc)
uchar b[Nbuf];
} Buf;
/*
 * State of the EAP-TTLS session. It is shared between the ttls() state
 * machine and the helper procs readproc and clientproc. These communicate
 * with the state machine through the tlsfdc, readc and eofc channels.
 */
typedef struct TTLSstate {
TLSconn tlsconn; // our handle to the tls connection
int tlspipe[2]; // double pipe over which we talk with our tls
// the stuff we read from it has to be fragmented, encapsulated and sent
// the fragments we receive have to be reassembled and then written to it
// tlspipe[1] is handed to tlsClient; this side reads/writes tlspipe[0]; -1 means closed
Channel *tlsfdc; // used to send file desc we get from tlsClient
Channel *readc; // contains index in rbuf containing last msg read from tlspipe
Channel *eofc; // confirm eof on tlspipe
int tlsfd; // fd returned by tlsClient, -1 if none; closed by cleanup
int ttslTxLen; // length of frame we prepared for sending (note the "ttsl" spelling)
int ttlsDone; // done processing the frame (and, if needed, preparing the response)?
int ttlsState; // ttls state we are in (Init .. Received)
uint ttlsVersion; // version bits taken from the last start packet
Buf rbuf[Nrbuf]; // msg read from the tls pipe, to be sent (possibly in fragments)
int ridx; // index of the rbuf slot readproc reads into next
int sendT; // total length of msg to be sent
uint sendL; // length remaining to be sent
uchar*sendP; // pointer in rbuf[...] pointing to stuff remaining to be sent
int sendS; // still have to send first frame (fragment) for current msg?
Buf wbuf; // receive buffer in which we reassemble fragments, and then write to tlspipe
uint recvT; // total length we want to receive (and reassemble), from the L field; 0 if unknown
uint recvL; // length received (and reassembled) so far
uchar*recvP; // first free position (reassembly insert point) in recv buffer
int theReadProc; // not referenced in this part of the file
int firstAttempt; // not referenced in this part of the file
} TTLSstate;
// The one session state instance used by initTTLS and processTTLS.
static TTLSstate theTTLSstate;
// Formatted fatal-error message passed to threadexitsall by setupTls/clientproc.
static char errbuf[256];
/*
 * cleanup tears down the current TLS session, if one exists.
 *
 * First the fd obtained from tlsClient is closed. Per the design comment,
 * that is expected to make devtls drop its end of the pipe, so readproc
 * sees eof on tlspipe[0]. Then, if the pipe is open, we block until
 * readproc confirms eof on eofc. Any message still delivered on readc in
 * the meantime is logged and dropped. Finally tlspipe[0] is closed and
 * both pipe slots are marked closed (-1).
 */
static void
cleanup(TTLSstate* s)
{
	int which, stale;
	Alt alts[] = {
		/* c		v		op */
		{s->eofc,	nil,		CHANRCV},
		{s->readc,	&stale,		CHANRCV},
		{nil,		nil,		CHANEND},
	};

	syslog(0, logname, "cleanup pre tlsfd=%d tlspipe[0]=%d tlspipe[1]=%d", s->tlsfd, s->tlspipe[0], s->tlspipe[1]);

	if (s->tlsfd >= 0) {
		syslog(0, logname, "cleanup: closing tlsfd: %d", s->tlsfd);
		close(s->tlsfd);
		s->tlsfd = -1;
	}

	if (s->tlspipe[0] >= 0 || s->tlspipe[1] >= 0) {
		// wait for readproc's eof confirmation, draining readc meanwhile
		for(;;){
			which = alt(alts);
			if (which == 0)
				break;
			if (which == 1)
				syslog(0, logname, "oops... cleanup recv from readc: %d", stale);
		}
		syslog(0, logname, "cleanup: confirmed eof from readproc");
		close(s->tlspipe[0]);
		s->tlspipe[0] = -1;
		s->tlspipe[1] = -1;
	}

	syslog(0, logname, "cleanup post tlsfd=%d tlspipe[0]=%d tlspipe[1]=%d", s->tlsfd, s->tlspipe[0], s->tlspipe[1]);
}
/*
 * readproc runs as its own proc. It drains tlspipe[0], which carries
 * whatever tlsClient writes to its end of the pipe (tlspipe[1]), i.e. the
 * data that must be encapsulated and sent.
 * Each read (at most Nbuf bytes) lands in slot s->ridx of the rbuf ring.
 * The slot index is then sent on readc and ridx advances. On eof or a read
 * error it signals eofc and exits.
 */
static void
readproc(void *arg)
{
	TTLSstate *st;
	Buf *slot;

	st = arg;
	syslog(0, logname, "readproc starts: %d", threadid());
	syslog(0, logname, "readproc monitoring pipe: %d", st->tlspipe[0]);

	for(slot = &st->rbuf[st->ridx];
	    (slot->n = read(st->tlspipe[0], slot->b, Nbuf)) > 0;
	    slot = &st->rbuf[st->ridx]){
		sendul(st->readc, st->ridx);
		st->ridx = (st->ridx+1) % Nrbuf;
	}

	syslog(0, logname, "readproc eof on pipe: %d", st->tlspipe[0]);
	sendul(st->eofc, 0);
	syslog(0, logname, "readproc exits: %d", threadid());
	threadexits(nil);
}
/*
 * clientproc runs as its own proc. It calls tlsClient on tlspipe[1] and
 * sends the resulting fd to the ttls state machine on tlsfdc.
 * Running tlsClient separately presumably lets the state machine keep
 * shuttling handshake data through tlspipe[0] while tlsClient waits.
 * A missing pipe or a tlsClient failure terminates the whole program
 * (threadexitsall).
 */
static void
clientproc(void *arg)
{
	TTLSstate *s;
	int fd;

	s = arg;
	syslog(0, logname, "clientproc starts: %d", threadid());
	// fd 0 is a valid descriptor; only a negative value means "not open",
	// matching the >= 0 "is open" tests used by setupTls and cleanup
	if (s->tlspipe[1] < 0) {
		snprint(errbuf, sizeof(errbuf), "clientproc: no fd for tlsClient:%d", s->tlspipe[1]);
		syslog(0, logname, "%s", errbuf);
		fprint(2, "%s\n", errbuf);
		threadexitsall(errbuf);
	}
	syslog(0, logname, "calling tlsClient");
	fd = tlsClient(s->tlspipe[1], &s->tlsconn);
	if (debug) print("clientproc: fd %d\n", fd);
	if (fd < 0) {
		syslog(0, logname, "tlsClient failed: %r");
		fprint(2, "tlsClient failed: %r\n");
		threadexitsall("tlsClient failed");
	} else {
		syslog(0, logname, "tlsClient ok fd=%d", fd);
	}
	sendul(s->tlsfdc, fd);
	syslog(0, logname, "clientproc exits: %d", threadid());
	threadexits(nil);
}
/*
 * setupTls starts a new TLS session:
 *   - creates tlspipe;
 *   - spawns clientproc, which runs tlsClient on tlspipe[1] and later
 *     reports its fd on tlsfdc;
 *   - spawns readproc, which forwards what arrives on tlspipe[0] via readc.
 * The pipe must not already be open. That condition, and a pipe()
 * failure, terminate the program.
 */
static void
setupTls(TTLSstate *s)
{
	int alreadyOpen;

	syslog(0, logname, "setupTls pre tlspipe[0]=%d tlspipe[1]=%d", s->tlspipe[0], s->tlspipe[1]);

	alreadyOpen = !(s->tlspipe[0] < 0 && s->tlspipe[1] < 0);
	if (alreadyOpen) {
		snprint(errbuf, sizeof(errbuf), "setupTls: pipe already open? %d %d", s->tlspipe[0], s->tlspipe[1]);
		fprint(2, "%s\n", errbuf);
		syslog(0, logname, "%s", errbuf);
		threadexitsall(errbuf);
	}

	if (pipe(s->tlspipe) < 0) {
		fprint(2, "pipe failed: %r\n");
		syslog(0, logname, "pipe failed: %r");
		threadexitsall("pipe failed");
	}

	proccreate(clientproc, s, STACK);	// runs tlsClient, result arrives on tlsfdc
	proccreate(readproc, s, STACK);	// forwards tlspipe[0] data on readc

	syslog(0, logname, "setupTls post tlspipe[0]=%d tlspipe[1]=%d", s->tlspipe[0], s->tlspipe[1]);
}
/*
 * buildFrameStart writes into b the first fragment of a message too big
 * for one frame. The layout is:
 *   - a long header: type EapTpTtls, flags L|M, tln = bytes still to send;
 *   - mtu-TtlsLongHlen payload bytes taken from s->sendP.
 * It advances sendP/sendL, stores the frame length in ttslTxLen and
 * returns mtu. The size sanity checks only print a diagnostic; they do
 * not stop the frame from being built.
 */
static int
buildFrameStart(TTLSstate *s, uchar*b, int mtu)
{
	TTLS *hdr;
	int room;

	room = mtu - TtlsLongHlen;	// payload bytes that fit after the long header
	if (mtu <= TtlsLongHlen)
		print("buildFrameStart error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsLongHlen);
	if (s->sendL <= room)
		print("buildFrameStart error: small enough, no framing needed: sz=%d, space=%d\n", s->sendL, room);

	hdr = (TTLS*)b;
	memset(hdr, 0, TtlsLongHlen);
	hdr->tp = EapTpTtls;
	hdr->flags = TtlsFlagL | TtlsFlagM;
	hnputl(hdr->tln, s->sendL);
	memcpy(b+TtlsLongHlen, s->sendP, room);

	s->sendP += room;
	s->sendL -= room;
	s->ttslTxLen = mtu;
	return mtu;
}
/*
 * buildFrameMiddle writes into b a non-first, non-last fragment. The
 * layout is:
 *   - a short header: type EapTpTtls, flags M;
 *   - mtu-TtlsShortHlen payload bytes taken from s->sendP.
 * It advances sendP/sendL, stores the frame length in ttslTxLen and
 * returns mtu. The sanity checks only print a diagnostic.
 */
static int
buildFrameMiddle(TTLSstate *s, uchar*b, int mtu)
{
	TTLS *hdr;
	int room;

	room = mtu - TtlsShortHlen;	// payload bytes that fit after the short header
	if (mtu <= TtlsShortHlen)
		print("buildFrameMiddle error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsShortHlen);
	if (s->sendL <= room)
		print("buildFrameMiddle error: small enough, no framing needed: sz=%d, space=%d\n", s->sendL, room);

	hdr = (TTLS*)b;
	memset(hdr, 0, TtlsShortHlen);
	hdr->tp = EapTpTtls;
	hdr->flags = TtlsFlagM;
	memcpy(b+TtlsShortHlen, s->sendP, room);

	s->sendP += room;
	s->sendL -= room;
	s->ttslTxLen = mtu;
	return mtu;
}
/*
 * buildMsg writes into b the last (or only) frame of a message. The
 * layout is:
 *   - a short header: type EapTpTtls, no flags;
 *   - all remaining s->sendL bytes from s->sendP.
 * It sets ttslTxLen to the frame length, clears sendP/sendL and returns
 * the number of payload bytes copied. The sanity checks only print.
 */
static int
buildMsg(TTLSstate *s, uchar*b, int mtu)
{
	TTLS *hdr;
	int copied;

	if (mtu <= TtlsShortHlen)
		print("buildMsg error: mtu much too small: mtu=%d, longhdr=%d\n", mtu, TtlsShortHlen);
	if (s->sendL > mtu-TtlsShortHlen)
		print("buildMsg error: too big, framing needed: sz=%d, space=%d\n", s->sendL, mtu-TtlsShortHlen);

	hdr = (TTLS*)b;
	memset(hdr, 0, TtlsShortHlen);
	hdr->tp = EapTpTtls;
	memcpy(b+TtlsShortHlen, s->sendP, s->sendL);
	s->ttslTxLen = TtlsShortHlen + s->sendL;

	copied = s->sendL;
	s->sendP = nil;	// nothing left to send
	s->sendL = 0;
	return copied;
}
/*
 * buildAck writes into b an empty acknowledgement frame: a short header
 * with type EapTpTtls, no flags and no payload. It sets ttslTxLen
 * accordingly. mtu is unused.
 */
static void
buildAck(TTLSstate *s, uchar*b, int mtu)
{
	TTLS *hdr = (TTLS*)b;

	USED(mtu);
	memset(hdr, 0, TtlsShortHlen);
	hdr->tp = EapTpTtls;
	s->ttslTxLen = TtlsShortHlen;
}
/*
 * trans moves the state machine to state new (logging it when debug is set).
 * Entering RecvAck or Receiving means a reply frame has just been
 * prepared, so ttlsDone is set to end the current run() loop.
 */
static void
trans(TTLSstate *s, int new)
{
	char *from;

	if (debug) {
		from = s->ttlsState >= 0 ? snames[s->ttlsState] : "-";
		print("ttls trans: %s -> %s\n", from, snames[new]);
	}
	if (new == RecvAck || new == Receiving)
		s->ttlsDone = 1;
	s->ttlsState = new;
}
/*
 * ttls advances the EAP-TTLS state machine by one step.
 *
 * rcvp/rcvl is the EAP-TTLS part of the packet just received (type octet
 * first). A frame to send back is built in txp (at most mtu bytes), and
 * its length is left in s->ttslTxLen. States:
 *	Init      - start a new TLS session (setupTls), go to Waiting.
 *	Waiting   - block until tlsClient delivers its fd (start phase 2)
 *	            or readproc delivers a message from tlspipe to send.
 *	Sending   - build the next fragment (-> RecvAck) or the last/only
 *	            frame (-> Receiving).
 *	RecvAck   - the peer acknowledged a fragment; send the next one.
 *	Receiving - append the received fragment to wbuf; ack it (SendAck)
 *	            if more follow, else pass the message on (Received).
 *	SendAck   - build an empty acknowledgement frame.
 *	Received  - write the reassembled message to tlspipe[0].
 * *ttlsFail is set if tlsClient reports failure; ttlsSuccess is not
 * touched here.
 *
 * Received data comes from the network. Reassembly is bounded by the
 * size of wbuf, and packets shorter than their own header are treated
 * as carrying no payload.
 */
static void
ttls(TTLSstate *s, uchar*rcvp, uint rcvl, uchar*txp, uint mtu, int*ttlsSuccess, int*ttlsFail)
{
	int fd;
	int i;
	Alt a[] = {
		/* c		v	op */
		{s->tlsfdc,	&fd,	CHANRCV},
		{s->readc,	&i,	CHANRCV},
		{nil,		nil,	CHANEND},
	};
	TTLS *t;
	uchar *p;
	uint l;
	int n, hlen;

	switch(s->ttlsState){
	case Init:
		setupTls(s);	// new session
		trans(s, Waiting);
		break;
	case Waiting:
		while(s->ttlsState == Waiting) {
			switch(alt(a)){
			case 0:
				// tlsClient returned: on success start phase 2, on failure
				// flag it and clean up. In practice clientproc already calls
				// threadexitsall on failure, so fd < 0 is not expected here.
				if (debug) print("ttls tlsfdc: fd=%d\n", fd);
				if (fd < 0) {
					*ttlsFail = 1;
					cleanup(s);
				} else {
					s->tlsfd = fd;
					doTTLSphase2(fd);
				}
				break;
			case 1:	// something read from tlspipe: encapsulate and send
				s->sendP = s->rbuf[i].b;
				s->sendL = s->rbuf[i].n;
				s->sendT = s->sendL;
				if (debug) print("ttls readc: i=%d sendP=%p sendL=%d\n", i, s->sendP, s->sendL);
				s->sendS = 1;
				trans(s, Sending);
				break;
			}
		}
		break;
	case Sending: {
		int olen, flen;

		if (s->sendS && s->sendL > mtu-TtlsShortHlen){
			olen = s->sendL;
			flen = buildFrameStart(s, txp, mtu);
			if (debug) print("ttls sendS and framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL);
			s->sendS = 0;
			trans(s, RecvAck);
		}else if (s->sendL > mtu-TtlsShortHlen){
			olen = s->sendL;
			flen = buildFrameMiddle(s, txp, mtu);
			if (debug) print("ttls framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL);
			trans(s, RecvAck);
		}else{
			olen = s->sendL;
			flen = buildMsg(s, txp, mtu);
			if (debug) print("ttls framed %d of %d, total %d, remains %d\n", flen, olen, s->sendT, s->sendL);
			// whole message sent: get ready to reassemble the reply
			s->recvP = s->wbuf.b;
			s->recvL = 0;
			s->recvT = 0;
			trans(s, Receiving);
		}
		break;
	}
	case RecvAck:
		t = (TTLS*)rcvp;
		if (t->flags&TtlsFlagS)
			print("tls: unexpected TtlsFlagS in %s\n", snames[s->ttlsState]);
		if (t->flags&TtlsFlagM)
			print("tls: unexpected TtlsFlagM in %s\n", snames[s->ttlsState]);
		if (t->flags&TtlsFlagL)
			print("tls: unexpected TtlsFlagL in %s\n", snames[s->ttlsState]);
		trans(s, Sending);
		break;
	case Receiving:
		t = (TTLS*)rcvp;
		if (t->flags&TtlsFlagS)
			print("tls: unexpected TtlsFlagS in %s\n", snames[s->ttlsState]);
		if (t->flags&TtlsFlagL && s->recvT > 0)
			print("tls: TtlsFlagL when recvT=%d\n", s->recvT);
		hlen = (t->flags&TtlsFlagL) ? TtlsLongHlen : TtlsShortHlen;
		if (rcvl < hlen) {
			// shorter than its own header: rcvl-hlen would underflow, so
			// treat the packet as carrying no payload
			print("ttls %s: packet shorter than header: rcvl=%ud hlen=%d\n", snames[s->ttlsState], rcvl, hlen);
			p = rcvp;
			l = 0;
		} else if (t->flags&TtlsFlagL) {
			s->recvT = nhgetl(t->tln);
			if (debug) print("ttls: TtlsFlagL len=%d\n", s->recvT);
			if (s->recvT > Nbuf)
				print("ttls %s: announced length %ud exceeds buffer size %d\n", snames[s->ttlsState], s->recvT, Nbuf);
			p = rcvp+TtlsLongHlen;
			l = rcvl-TtlsLongHlen;
			if (s->recvP != s->wbuf.b)
				print("ttls %s: recvP != wbuf.b recvP=%p wbuf.b=%p \n", snames[s->ttlsState], s->recvP, s->wbuf.b);
			if (s->recvL != 0)
				print("ttls %s: recvL != 0 recvL=%d\n", snames[s->ttlsState], s->recvL);
		} else {
			p = rcvp+TtlsShortHlen;
			l = rcvl-TtlsShortHlen;
		}
		// never let the peer write past the end of the reassembly buffer
		if (l > Nbuf - s->recvL) {
			print("ttls %s: reassembly buffer overflow: recvL=%ud l=%ud Nbuf=%d\n", snames[s->ttlsState], s->recvL, l, Nbuf);
			l = Nbuf - s->recvL;
		}
		memcpy(s->recvP, p, l);
		s->recvP += l;
		s->recvL += l;
		if (debug) print("ttls %s: received %d; recvL=%d; recvT=%d\n", snames[s->ttlsState], l, s->recvL, s->recvT);
		if (t->flags&TtlsFlagM)
			trans(s, SendAck);
		else {
			if (s->recvT > 0 && s->recvT != s->recvL)
				print("ttls : recvT=%d != recvL=%d\n", s->recvT, s->recvL);
			if (s->recvL > 0)
				trans(s, Received);
			else
				trans(s, Waiting);
		}
		break;
	case SendAck:
		buildAck(s, txp, mtu);
		trans(s, Receiving);
		break;
	case Received:
		if (debug) print("ttls %s: writing tlspipe[0]: %s\n", snames[s->ttlsState], hexprefix(s->wbuf.b, s->recvL, 5));
		n = write(s->tlspipe[0], s->wbuf.b, s->recvL);
		if (n<0)
			print("ttls %s: error writing tlspipe[0]: %r\n", snames[s->ttlsState]);
		if (n != s->recvL)
			print("ttls %s: writing tlspipe[0]: n != recvL n=%d recvL=%d\n", snames[s->ttlsState], n, s->recvL);
		if (debug) print("ttls %s: written to tlspipe[0] : %d\n", snames[s->ttlsState], s->recvL);
		trans(s, Waiting);
		break;
	}
}
/*
 * initTTLS resets the session state to Init. It:
 *   - creates the unbuffered channels used by the helper procs;
 *   - marks all descriptors closed (-1);
 *   - fills in the TLSconn fields used to derive the session key into
 *     theSessionKey.
 * Handshake tracing is enabled when debugTLS is set.
 */
void
initTTLS(void)
{
	TTLSstate *s = &theTTLSstate;

	syslog(0, logname, "initTTLS");
	memset(s, 0, sizeof *s);
	s->ttlsState = Init;

	s->tlsfdc = chancreate(sizeof(int), 0);
	s->readc = chancreate(sizeof(int), 0);
	s->eofc = chancreate(sizeof(int), 0);

	s->tlsfd = s->tlspipe[0] = s->tlspipe[1] = -1;

	s->tlsconn.sessionType = "ttls";
	s->tlsconn.sessionConst = "ttls keying material";
	s->tlsconn.sessionKey = theSessionKey;
	s->tlsconn.sessionKeylen = sizeof(theSessionKey);
	if (debugTLS)
		s->tlsconn.trace = print;
}
/*
 * run feeds the current packet to the state machine, one ttls() step at a
 * time. It stops once a step sets ttlsDone, which trans() does on entering
 * RecvAck or Receiving, i.e. once a reply frame is ready in txp.
 */
static void
run(TTLSstate *s, uchar*rcvp, uint rcvl, uchar*txp, uint mtu, int*success, int*failed)
{
	for (s->ttlsDone = 0; !s->ttlsDone; )
		ttls(s, rcvp, rcvl, txp, mtu, success, failed);
}
/*
 * processTTLS handles one received EAP-TTLS request.
 *
 * rcvp/rcvl is the EAP-TTLS part of the request (type octet first).
 * If expectStart is set, the packet must carry the S (start) flag,
 * otherwise the program exits. A start packet does two things:
 *   - it tears down any previous session (cleanup);
 *   - it resets the state machine to Init.
 * The state machine is then run, and any reply frame is built in txp
 * (at most mtu bytes). success/failed are passed through to it.
 *
 * Returns the length of the frame prepared in txp (s->ttslTxLen). It
 * returns 0 if the packet is too short to hold a header or is not
 * EAP-TTLS.
 */
int
processTTLS(uchar*rcvp, uint rcvl, int expectStart, uchar*txp, uint mtu, int*success, int*failed)
{
	TTLS *hr;
	uchar flags, version;
	TTLSstate *s;

	s = &theTTLSstate;
	// need at least the type and flags octets before looking at the header
	if (rcvl < TtlsShortHlen)
		return 0;
	hr = (TTLS*)rcvp;
	if (hr->tp != EapTpTtls)
		return 0; // flag error??
	flags = rcvp[1];
	version = flags & TtlsVersion;
	if (debug) print("processTTLS flags=%s%s%s ver=%d mtu=%d bl=%d\n",
		(flags&TtlsFlagS ? "S":""),
		(flags&TtlsFlagM ? "M":""),
		(flags&TtlsFlagL ? "L":""),
		version, mtu, rcvl);
	// first thing should be EAP-TTLS start packet.
	// The parentheses matter: ! binds tighter than &.
	if (expectStart && !(flags&TtlsFlagS)) {
		fprint(2, "expected EAP-TTLS start packet\n");
		syslog(0, logname, "expected EAP-TTLS start packet");
		threadexitsall("expected EAP-TTLS start packet");
	}
	if (flags & TtlsFlagS) {
		cleanup(s); // previous session
		// ack??
		// look for piggy-backed stuff?
		s->ttlsVersion = version;
		s->ttlsState = Init;
		s->ttlsDone = 0;
		s->sendP = 0;
		s->sendL = 0;
		s->sendS = 0;
		s->sendT = 0;
		s->recvP = 0;
		s->recvL = 0;
		s->recvT = 0;
		// we don't have a client certificate
		s->tlsconn.cert = nil;
		s->tlsconn.certlen = 0;
		// tlsClient does not support session resumption
		s->tlsconn.sessionID = 0;
		s->tlsconn.sessionIDlen = 0;
	}
	run(s, rcvp, rcvl, txp, mtu, success, failed);
	return s->ttslTxLen;
}
|