/*
* IP packet filter
*/
#include "u.h"
#include "../port/lib.h"
#include "mem.h"
#include "dat.h"
#include "fns.h"
#include "../port/error.h"
#include "ip.h"
#include "ipv6.h"
typedef struct Ipmuxrock Ipmuxrock;
typedef struct Ipmux Ipmux;
typedef struct Myip4hdr Myip4hdr;
/*
* Myip4hdr mirrors the IPv4 header field by field.  ipmuxkick and
* ipmuxiput overlay it on received/sent packet bytes, and the nil
* pointer ipoff below is used to turn field names into byte offsets.
*/
struct Myip4hdr
{
uchar vihl; /* Version and header length */
uchar tos; /* Type of service */
uchar length[2]; /* packet length */
uchar id[2]; /* ip->identification */
uchar frag[2]; /* Fragment information */
uchar ttl; /* Time to live */
uchar proto; /* Protocol */
uchar cksum[2]; /* Header checksum */
uchar src[4]; /* IP source */
uchar dst[4]; /* IP destination */
uchar data[1]; /* start of data */
};
/*
* nil base pointer: in this file it is only used for the addresses of
* its fields (e.g. (ulong)(ipoff->dst) in parseop), which yields the
* field's offset within Myip4hdr.
*/
Myip4hdr *ipoff = 0;
/*
* T*: which field an Ipmux node tests (Ipmux.type).  The numeric order
*     matters: ipmuxcmp sorts on it, so lower values are tested first.
* C*: how ipmuxiput performs the test (Ipmux.ctype).  parsemux picks a
*     single-value fast path when it can; Cother means the generic
*     byte-by-byte loop over every alternative value.
*/
enum
{
Tproto,
Tdata,
Tiph,
Tdst,
Tsrc,
Tifc,
Cother = 0,
Cbyte, /* single byte */
Cmbyte, /* single byte with mask */
Cshort, /* single short */
Cmshort, /* single short with mask */
Clong, /* single long */
Cmlong, /* single long with mask */
Cifc, /* single long compared with the receiving interface's IPv4 address */
Cmifc, /* as Cifc, with mask */
};
/*
* printable names of the T* field types, indexed by type.
* NOTE(review): not referenced in the visible part of this file;
* check for external users before removing.
*/
char *ftname[] =
{
[Tproto] "proto",
[Tdata] "data",
[Tiph] "iph",
[Tdst] "dst",
[Tsrc] "src",
[Tifc] "ifc",
};
/*
* a node in the decision tree
*
* Each node tests len bytes of the packet (or, for Tifc, of the
* receiving interface's address) against n alternative values under a
* mask.  ipmuxiput follows yes when the test succeeds and no when it
* fails.
*/
struct Ipmux
{
Ipmux *yes; /* next node to test when this one matches */
Ipmux *no; /* next node to test when this one does not */
uchar type; /* type of field(Txxxx) */
uchar ctype; /* type of comparison(Cxxxx) */
uchar len; /* length in bytes of item to compare */
uchar n; /* number of items val points to */
short off; /* offset of comparison (-IPv4addrlen for Tifc, which is not a packet field) */
short eoff; /* end offset of comparison (off+len), used for the packet length check */
uchar skiphdr; /* should offset start after ipheader (set for data[]) */
uchar *val; /* n values of len bytes each, back to back */
uchar *mask; /* len bytes; shared, not duplicated, by ipmuxcopy */
uchar *e; /* val+n*len*/
int ref; /* so we can garbage collect: number of merged chains using this node */
Conv *conv; /* conversation to deliver to when this node matches, or nil */
};
/*
* someplace to hold per conversation data
*/
struct Ipmuxrock
{
Ipmux *chain; /* this conversation's own filter chain (linked via yes); a copy of it is merged into the shared tree */
};
/* forward declarations */
static int ipmuxsprint(Ipmux*, int, char*, int);
static void ipmuxkick(void *x);
/*
 *  return a pointer to the first character of p that is neither a
 *  blank nor a tab.
 */
static char*
skipwhite(char *p)
{
	for(; *p == ' ' || *p == '\t'; p++)
		;
	return p;
}
/*
 *  split p at the first c: the c is overwritten with a NUL, and the
 *  text after it (minus leading white space) is returned.  returns nil
 *  if there is no c or nothing follows it.
 */
static char*
follows(char *p, char c)
{
	char *at;

	at = strchr(p, c);
	if(at == nil)
		return nil;
	*at = '\0';
	at = skipwhite(at + 1);
	return *at != '\0' ? at : nil;
}
/*
 *  parse the operand (left-hand side) of a filter term:
 *	dst | src | ifc | proto | data[off[:end]] | iph[off[:end]]
 *  returns a new Ipmux describing which bytes to test (values and mask
 *  are filled in by parsemux) and advances *pp past the operand.
 *  returns nil on a syntax error.
 *  data[] offsets are relative to the end of the IP header (skiphdr is
 *  set); iph[] offsets are from the start of the header.
 */
static Ipmux*
parseop(char **pp)
{
	char *p = *pp;
	int type, off, end, len;
	Ipmux *f;

	p = skipwhite(p);
	if(strncmp(p, "dst", 3) == 0){
		type = Tdst;
		off = (ulong)(ipoff->dst);
		len = IPv4addrlen;
		p += 3;
	}
	else if(strncmp(p, "src", 3) == 0){
		type = Tsrc;
		off = (ulong)(ipoff->src);
		len = IPv4addrlen;
		p += 3;
	}
	else if(strncmp(p, "ifc", 3) == 0){
		/* not a packet field: ipmuxiput compares the interface address */
		type = Tifc;
		off = -IPv4addrlen;
		len = IPv4addrlen;
		p += 3;
	}
	else if(strncmp(p, "proto", 5) == 0){
		type = Tproto;
		off = (ulong)&(ipoff->proto);
		len = 1;
		p += 5;
	}
	else if(strncmp(p, "data", 4) == 0 || strncmp(p, "iph", 3) == 0){
		if(strncmp(p, "data", 4) == 0) {
			type = Tdata;
			p += 4;
		}
		else {
			type = Tiph;
			p += 3;
		}
		p = skipwhite(p);
		if(*p != '[')
			return nil;
		p++;
		off = strtoul(p, &p, 0);
		if(off < 0 || off > (64-IP4HDR))
			return nil;
		p = skipwhite(p);
		if(*p != ':')
			end = off;
		else {
			p++;
			p = skipwhite(p);
			end = strtoul(p, &p, 0);
			if(end < off)
				return nil;
			p = skipwhite(p);
		}
		if(*p != ']')
			return nil;
		p++;
		len = end - off + 1;
		/* Ipmux.len is a uchar: refuse ranges it cannot represent */
		if(len > 255)
			return nil;
	}
	else
		return nil;

	/* let the caller continue after the operand */
	*pp = p;

	f = smalloc(sizeof(*f));
	f->type = type;
	f->len = len;
	f->off = off;
	f->val = nil;
	f->mask = nil;
	f->n = 1;
	f->ref = 1;
	if(type == Tdata)
		f->skiphdr = 1;
	else
		f->skiphdr = 0;
	return f;
}
/*
 *  value of a single hexadecimal digit; anything that is not a hex
 *  digit counts as 0.
 */
static int
htoi(char x)
{
	if(x >= '0' && x <= '9')
		return x - '0';
	if(x >= 'a' && x <= 'f')
		return x - 'a' + 10;
	if(x >= 'A' && x <= 'F')
		return x - 'A' + 10;
	return 0;
}
/*
 *  value of the two hex digits at p[0] (high nibble) and p[1] (low).
 */
static int
hextoi(char *p)
{
	int hi, lo;

	hi = htoi(p[0]);
	lo = htoi(p[1]);
	return hi<<4 | lo;
}
/*
 *  convert the hex string p into at most len bytes at v.  conversion
 *  stops at the end of p; bytes of v not reached are left unchanged.
 *  an odd trailing digit becomes the high nibble of its byte
 *  (e.g. "abc" -> ab c0).
 */
static void
parseval(uchar *v, char *p, int len)
{
	while(*p && len-- > 0){
		*v++ = hextoi(p);
		/* odd number of digits: don't step over the terminating NUL */
		if(p[1] == 0)
			break;
		p += 2;
	}
}
/*
 *  parse one filter term, operand=value[|value...][&mask], into a new
 *  Ipmux node.  the string p is modified in place.  returns nil, after
 *  freeing any partially built node, on a syntax error.
 */
static Ipmux*
parsemux(char *p)
{
	int n, nomask;
	Ipmux *f;
	char *val;
	char *mask;
	char *vals[20];
	uchar *v;

	/* parse operand */
	f = parseop(&p);
	if(f == nil)
		return nil;

	/* find value; follows() NUL-terminates the operand at the '=' */
	val = follows(p, '=');
	if(val == nil)
		goto parseerror;

	/*
	 *  parse mask.  the '&' belongs to the right-hand side, so look for
	 *  it in val: a search from p would stop at the NUL that replaced
	 *  the '=' and never find a mask.
	 */
	mask = follows(val, '&');
	if(mask != nil){
		switch(f->type){
		case Tsrc:
		case Tdst:
		case Tifc:
			f->mask = smalloc(f->len);
			v4parseip(f->mask, mask);
			break;
		case Tdata:
		case Tiph:
			f->mask = smalloc(f->len);
			parseval(f->mask, mask, f->len);
			break;
		default:
			goto parseerror;
		}
		nomask = 0;
	} else {
		/* no mask given: compare every bit */
		nomask = 1;
		f->mask = smalloc(f->len);
		memset(f->mask, 0xff, f->len);
	}

	/* parse vals: one or more alternatives separated by '|' */
	f->n = getfields(val, vals, nelem(vals), 1, "|");
	if(f->n == 0)
		goto parseerror;
	f->val = smalloc(f->n*f->len);
	v = f->val;
	for(n = 0; n < f->n; n++){
		switch(f->type){
		case Tsrc:
		case Tdst:
		case Tifc:
			v4parseip(v, vals[n]);
			break;
		case Tproto:
		case Tdata:
		case Tiph:
			parseval(v, vals[n], f->len);
			break;
		}
		v += f->len;
	}
	f->eoff = f->off + f->len;
	f->e = f->val + f->n*f->len;

	/* pick a fast comparison for the common single-value cases */
	f->ctype = Cother;
	if(f->n == 1){
		switch(f->len){
		case 1:
			f->ctype = nomask ? Cbyte : Cmbyte;
			break;
		case 2:
			f->ctype = nomask ? Cshort : Cmshort;
			break;
		case 4:
			if(f->type == Tifc)
				f->ctype = nomask ? Cifc : Cmifc;
			else
				f->ctype = nomask ? Clong : Cmlong;
			break;
		}
	}
	return f;

parseerror:
	if(f->mask)
		free(f->mask);
	if(f->val)
		free(f->val);
	free(f);
	return nil;
}
/*
 *  Order two ipmuxs by the bytes they examine, not by their values.
 *
 *  returns: <0 if a is a more specific match
 *	    0 if a and b are matching on the same fields
 *	   >0 if b is a more specific match
 *
 *  Precedence: field type (lower first), then absolute offset (data[]
 *  offsets are placed after a minimal Myip4hdr), then comparison length
 *  (longer first), then mask (a mask beats none; a numerically larger
 *  mask is more specific).
 */
static int
ipmuxcmp(Ipmux *a, Ipmux *b)
{
	int d;

	if((d = a->type - b->type) != 0)
		return d;

	d = (a->off+((int)a->skiphdr)*(ulong)ipoff->data) -
		(b->off+((int)b->skiphdr)*(ulong)ipoff->data);
	if(d != 0)
		return d;

	if((d = b->len - a->len) != 0)
		return d;

	/* both examine the same bytes: decide on the masks */
	if(a->mask == nil)
		return b->mask == nil ? 0 : 1;
	if(b->mask == nil)
		return -1;
	return memcmp(b->mask, a->mask, a->len);
}
/*
 *  Compare the values of two ipmuxs whose fields ipmuxcmp considers
 *  equal.  The total size of the value arrays is compared first; equal
 *  sizes fall back to a byte comparison.  Returns 0 only when the values
 *  are identical.
 */
static int
ipmuxvalcmp(Ipmux *a, Ipmux *b)
{
	int asize, bsize;

	asize = a->len*a->n;
	bsize = b->len*b->n;
	if(asize != bsize)
		return bsize - asize;
	return memcmp(a->val, b->val, asize);
}
/*
 *  insert f into the chain at *l (linked through yes), keeping the
 *  chain in ipmuxcmp order; f goes after any entries that compare
 *  equal to it.
 */
static void
ipmuxchain(Ipmux **l, Ipmux *f)
{
	Ipmux **pp;

	pp = l;
	while(*pp != nil && ipmuxcmp(f, *pp) >= 0)
		pp = &(*pp)->yes;
	f->yes = *pp;
	*pp = f;
}
/*
 *  deep-copy a tree.  each node and its value array are duplicated;
 *  the mask pointer (and conv) are shared with the original.
 */
static Ipmux*
ipmuxcopy(Ipmux *f)
{
	Ipmux *dup;
	int nbytes;

	if(f == nil)
		return nil;
	dup = smalloc(sizeof *dup);
	*dup = *f;
	dup->no = ipmuxcopy(f->no);
	dup->yes = ipmuxcopy(f->yes);
	nbytes = f->n*f->len;
	dup->val = smalloc(nbytes);
	memmove(dup->val, f->val, nbytes);
	dup->e = dup->val + nbytes;
	return dup;
}
/*
 *  free a single node and its value array.  the mask is left alone:
 *  ipmuxcopy shares it between a node and its copies.
 */
static void
ipmuxfree(Ipmux *f)
{
	uchar *vals;

	vals = f->val;
	if(vals != nil)
		free(vals);
	free(f);
}
/*
 *  free a whole (sub)tree: f and every node reachable from it through
 *  its yes and no links.  nil is allowed.
 *
 *  the previous version freed only f and its two immediate children,
 *  leaking the rest of any longer chain (e.g. on the ipmuxconnect error
 *  path and the dead yes side in ipmuxremove).
 */
static void
ipmuxtreefree(Ipmux *f)
{
	if(f == nil)
		return;
	ipmuxtreefree(f->no);
	ipmuxtreefree(f->yes);
	ipmuxfree(f);
}
/*
* merge two trees
*
* Returns the merged tree.  Nodes of a and b are reused in place.
* When one root tests an earlier field than the other, the other tree
* has to appear under both the yes and no branches; ipmuxcopy makes the
* second instance.
*
* NOTE(review): when two roots test the same fields with the same
* values, b is freed and only a->conv survives; if b carried a
* conversation and a did not, that conversation is lost.  Confirm
* whether this is intended.
*/
static Ipmux*
ipmuxmerge(Ipmux *a, Ipmux *b)
{
int n;
Ipmux *f;
if(a == nil)
return b;
if(b == nil)
return a;
n = ipmuxcmp(a, b);
/* a tests first: b must be tried whether or not a matches */
if(n < 0){
f = ipmuxcopy(b);
a->yes = ipmuxmerge(a->yes, b);
a->no = ipmuxmerge(a->no, f);
return a;
}
/* b tests first: symmetric case */
if(n > 0){
f = ipmuxcopy(a);
b->yes = ipmuxmerge(b->yes, a);
b->no = ipmuxmerge(b->no, f);
return b;
}
/* identical test: fold b into a and count the extra user */
if(ipmuxvalcmp(a, b) == 0){
a->yes = ipmuxmerge(a->yes, b->yes);
a->no = ipmuxmerge(a->no, b->no);
a->ref++;
ipmuxfree(b);
return a;
}
/* same fields, different values: mutually exclusive, so b goes on a's no side */
a->no = ipmuxmerge(a->no, b);
return a;
}
/*
* remove a chain from a demux tree. This is like merging except that
* we remove instead of insert.
*
* l is the link to the (sub)tree root and f the chain (linked via yes)
* being removed.  Returns 0 when every part of f was found and
* removed; a negative value if some part could not be found.
*/
static int
ipmuxremove(Ipmux **l, Ipmux *f)
{
int n, rv;
Ipmux *ft;
if(f == nil)
return 0; /* we've removed it all */
if(*l == nil)
return -1;
ft = *l;
n = ipmuxcmp(ft, f);
if(n < 0){
/* *l is matching an earlier field, descend both paths */
rv = ipmuxremove(&ft->yes, f);
rv += ipmuxremove(&ft->no, f);
return rv;
}
if(n > 0){
/* f represents an earlier field than *l, this should be impossible */
return -1;
}
/* if we get here f and *l are comparing the same fields */
if(ipmuxvalcmp(ft, f) != 0){
/* different values mean mutually exclusive */
return ipmuxremove(&ft->no, f);
}
/* we found a match */
if(--(ft->ref) == 0){
/*
* a dead node implies the whole yes side is also dead.
* since our chain is constrained to be on that side,
* we're done.
*/
ipmuxtreefree(ft->yes);
*l = ft->no;
ipmuxfree(ft);
return 0;
}
/*
* free the rest of the chain. it is constrained to match the
* yes side.
*/
return ipmuxremove(&ft->yes, f->yes);
}
/*
 *  connection request is a semicolon-separated list of filters
 *  e.g. proto=17;data[0:4]=11aa22bb;ifc=135.104.9.2&255.255.255.0
 *
 *  all filters must match for a packet to be delivered to c.  the
 *  parsed chain is kept in the conversation (for state and for removal
 *  on close) and a copy is merged into the protocol's shared tree.
 *
 *  there's no protection against overlapping specs.
 */
static char*
ipmuxconnect(Conv *c, char **argv, int argc)
{
	int i, n;
	char *field[10];
	Ipmux *mux, *chain;
	Ipmuxrock *r;
	Fs *f;

	f = c->p->f;
	if(argc != 2)
		return Ebadarg;
	n = getfields(argv[1], field, nelem(field), 1, ";");
	if(n <= 0)
		return Ebadarg;

	chain = nil;
	for(i = 0; i < n; i++){
		mux = parsemux(field[i]);
		if(mux == nil){
			ipmuxtreefree(chain);
			return Ebadarg;
		}
		ipmuxchain(&chain, mux);
	}
	if(chain == nil)
		return Ebadarg;

	/*
	 *  ipmuxchain sorts the tests, so the last one parsed is not
	 *  necessarily the last one in the chain.  the conversation has to
	 *  hang off the final node, otherwise ipmuxiput would deliver
	 *  packets that matched only a prefix of the filters.
	 */
	for(mux = chain; mux->yes != nil; mux = mux->yes)
		;
	mux->conv = c;

	/* save a copy of the chain so we can later remove it */
	mux = ipmuxcopy(chain);
	r = (Ipmuxrock*)(c->ptcl);
	r->chain = chain;

	/* add the chain to the protocol demultiplexor tree */
	wlock(f);
	f->ipmux->priv = ipmuxmerge(f->ipmux->priv, mux);
	wunlock(f);

	Fsconnected(c, nil);
	return nil;
}
/*
 *  print this conversation's filter chain into state (n bytes).
 */
static int
ipmuxstate(Conv *c, char *state, int n)
{
	return ipmuxsprint(((Ipmuxrock*)c->ptcl)->chain, 0, state, n);
}
/*
 *  set up a new conversation: a message read queue, a write queue whose
 *  kick routine transmits each block, and no filter chain yet.
 */
static void
ipmuxcreate(Conv *c)
{
	Ipmuxrock *r = (Ipmuxrock*)c->ptcl;

	c->rq = qopen(64*1024, Qmsg, 0, c);
	c->wq = qopen(64*1024, Qkick, ipmuxkick, c);
	r->chain = nil;
}
/*
 *  announce is always rejected; ipmux filters are installed through
 *  ipmuxconnect.
 */
static char*
ipmuxannounce(Conv*, char**, int)
{
	char *err;

	err = "ipmux does not support announce";
	return err;
}
/*
 *  tear down a conversation: close its queues, clear its addresses,
 *  remove its filters from the shared tree and free its private chain.
 */
static void
ipmuxclose(Conv *c)
{
	Fs *f;
	Ipmuxrock *r;

	f = c->p->f;
	r = (Ipmuxrock*)c->ptcl;

	qclose(c->rq);
	qclose(c->wq);
	qclose(c->eq);
	ipmove(c->laddr, IPnoaddr);
	ipmove(c->raddr, IPnoaddr);
	c->lport = c->rport = 0;

	/* the tree is shared with ipmuxiput, so edit it under the write lock */
	wlock(f);
	ipmuxremove(&c->p->priv, r->chain);
	wunlock(f);

	ipmuxtreefree(r->chain);
	r->chain = nil;
}
/*
* takes a fully formed ip packet and just passes it down
* the stack
*/
static void
ipmuxkick(void *x)
{
Conv *c = x;
Block *bp;
bp = qget(c->wq);
if(bp != nil) {
Myip4hdr *ih4 = (Myip4hdr*)(bp->rp);
if((ih4->vihl & 0xF0) != IP_VER6)
ipoput4(c->p->f, bp, 0, ih4->ttl, ih4->tos, nil);
else
ipoput6(c->p->f, bp, 0, ((Ip6hdr*)ih4)->ttl, 0, nil);
}
}
/*
 *  receive an IP packet.  run it through the shared filter tree; if it
 *  matches a conversation, queue it there prefixed with the receiving
 *  interface's address, otherwise hand it to the handler registered for
 *  its protocol number (or free it if there is none).
 *
 *  NOTE(review): the header length hl is taken from the IPv4 vihl
 *  nibble even for IPv6 packets; the bounds check below keeps that from
 *  reading past the block, but data[] tests on IPv6 are not meaningful.
 */
static void
ipmuxiput(Proto *p, Ipifc *ifc, Block *bp)
{
	int len, hl, skip;
	Fs *f = p->f;
	uchar *m, *h, *v, *e, *ve, *hp, *base;
	Conv *c;
	Ipmux *mux;
	Myip4hdr *ip;
	Ip6hdr *ip6;

	ip = (Myip4hdr*)bp->rp;
	hl = (ip->vihl&0x0F)<<2;

	if(p->priv == nil)
		goto nomatch;

	h = bp->rp;
	len = BLEN(bp);

	/* run the v4 filter */
	rlock(f);
	c = nil;
	mux = f->ipmux->priv;
	while(mux != nil){
		/*
		 *  data[] offsets are relative to the end of this packet's IP
		 *  header, so that length must be part of the bounds check as
		 *  well as of the address of the bytes being tested.
		 */
		skip = ((int)mux->skiphdr)*hl;
		if(mux->eoff + skip > len){
			mux = mux->no;
			continue;
		}

		/* ifc tests look at the interface address, not at packet bytes */
		if(mux->type == Tifc)
			base = ifc->lifc->local + IPv4off;
		else
			base = h + mux->off + skip;
		hp = base;

		switch(mux->ctype){
		case Cbyte:
			if(*mux->val == *hp)
				goto yes;
			break;
		case Cmbyte:
			if((*hp & *mux->mask) == *mux->val)
				goto yes;
			break;
		case Cshort:
			if(*((ushort*)mux->val) == *(ushort*)hp)
				goto yes;
			break;
		case Cmshort:
			if((*(ushort*)hp & (*((ushort*)mux->mask))) == *((ushort*)mux->val))
				goto yes;
			break;
		case Clong:
			if(*((ulong*)mux->val) == *(ulong*)hp)
				goto yes;
			break;
		case Cmlong:
			if((*(ulong*)hp & (*((ulong*)mux->mask))) == *((ulong*)mux->val))
				goto yes;
			break;
		case Cifc:
			if(*((ulong*)mux->val) == *(ulong*)(ifc->lifc->local + IPv4off))
				goto yes;
			break;
		case Cmifc:
			if((*(ulong*)(ifc->lifc->local + IPv4off) & (*((ulong*)mux->mask))) == *((ulong*)mux->val))
				goto yes;
			break;
		default:
			/* general case: does any of the n values match under the mask? */
			v = mux->val;
			for(e = mux->e; v < e; v = ve){
				m = mux->mask;
				hp = base;
				for(ve = v + mux->len; v < ve; v++){
					if((*hp++ & *m++) != *v)
						break;
				}
				if(v == ve)
					goto yes;
			}
		}
		mux = mux->no;
		continue;
yes:
		/* remember the most specific conversation matched so far */
		if(mux->conv != nil)
			c = mux->conv;
		mux = mux->yes;
	}
	runlock(f);

	if(c != nil){
		/* tack on interface address */
		bp = padblock(bp, IPaddrlen);
		ipmove(bp->rp, ifc->lifc->local);
		bp = concatblock(bp);
		if(bp != nil)
			if(qpass(c->rq, bp) < 0)
				print("ipmuxiput: qpass failed\n");
		return;
	}

nomatch:
	/* doesn't match any filter, hand it to the specific protocol handler */
	ip = (Myip4hdr*)bp->rp;
	if((ip->vihl & 0xF0) == IP_VER4) {
		p = f->t2p[ip->proto];
	} else {
		ip6 = (Ip6hdr*)bp->rp;
		p = f->t2p[ip6->proto];
	}
	if(p && p->rcv)
		(*p->rcv)(p, ifc, bp);
	else
		freeblist(bp);
	return;
}
/*
 *  print the tree rooted at mux into buf (len bytes), one node per line,
 *  indented one space per level, children in no-then-yes order.  a nil
 *  branch prints as an indented empty line.  returns the bytes written.
 */
static int
ipmuxsprint(Ipmux *mux, int level, char *buf, int len)
{
	int i, j, n, first, last;
	uchar *v;

	n = 0;
	for(i = 0; i < level; i++)
		n += snprint(buf+n, len-n, " ");

	if(mux == nil)
		return n + snprint(buf+n, len-n, "\n");

	/* byte range tested; data[] offsets shown after a minimal Myip4hdr */
	first = mux->off + ((int)mux->skiphdr)*((int)ipoff->data);
	last = first + mux->len - 1;
	n += snprint(buf+n, len-n, "h[%d:%d]&", first, last);
	for(i = 0; i < mux->len; i++)
		n += snprint(buf+n, len-n, "%2.2ux", mux->mask[i]);
	n += snprint(buf+n, len-n, "=");

	/* every alternative value, each followed by '|' */
	v = mux->val;
	for(j = 0; j < mux->n; j++){
		for(i = 0; i < mux->len; i++, v++)
			n += snprint(buf+n, len-n, "%2.2ux", *v);
		n += snprint(buf+n, len-n, "|");
	}
	n += snprint(buf+n, len-n, "\n");

	n += ipmuxsprint(mux->no, level+1, buf+n, len-n);
	n += ipmuxsprint(mux->yes, level+1, buf+n, len-n);
	return n;
}
/*
 *  print the protocol's shared filter tree, holding the read lock so it
 *  cannot change underneath us.
 */
static int
ipmuxstats(Proto *p, char *buf, int len)
{
	Fs *f;
	int written;

	f = p->f;
	rlock(f);
	written = ipmuxsprint(p->priv, 0, buf, len);
	runlock(f);
	return written;
}
/*
 *  create the ipmux protocol and register it with f.
 */
void
ipmuxinit(Fs *f)
{
	Proto *pr;

	pr = smalloc(sizeof(Proto));

	/* identity and sizing */
	pr->name = "ipmux";
	pr->ipproto = -1;
	pr->nc = 64;
	pr->ptclsize = sizeof(Ipmuxrock);
	pr->priv = nil;

	/* per-conversation operations */
	pr->create = ipmuxcreate;
	pr->connect = ipmuxconnect;
	pr->announce = ipmuxannounce;
	pr->state = ipmuxstate;
	pr->close = ipmuxclose;
	pr->ctl = nil;
	pr->advise = nil;

	/* protocol-wide operations */
	pr->rcv = ipmuxiput;
	pr->stats = ipmuxstats;

	f->ipmux = pr;	/* hack for Fsrcvpcol */
	Fsproto(f, pr);
}
|