#include <u.h>
#include <libc.h>
#include <auth.h>
#include <libsec.h>
#include <fcall.h>
#include <thread.h>
#include <9p.h>
#include "mysql.h"
/*
 * Protocol framing and tuning constants.
 */
enum {
	Nhdr = 4,		/* packet header bytes: 3 length + 1 sequence (see putpkt) */
	Mtu = 64*1024,		/* packet buffer size; bigger is more efficient but uses more memory */
	Timeout = 5000,		/* network read timeout, in milliseconds */
	utf8_general_ci = 33	/* character set/collation number sent in the auth packet */
};
/*
 * mkpkt: allocate a packet with an s->mtu byte buffer. The write
 * position starts just past the header, so callers can append the
 * payload directly; putpkt fills the header in later.
 */
static Pkt *
mkpkt(Sess *s)
{
	Pkt *pkt = emalloc9p(sizeof(Pkt));

	pkt->buf = emalloc9p(s->mtu);
	pkt->end = pkt->buf + s->mtu;
	pkt->pos = pkt->buf + Nhdr;
	return pkt;
}
/*
 * rstpkt: rewind the packet's position to the start of its buffer.
 * Returns how many bytes (header included) had been used before the rewind.
 */
static int
rstpkt(Pkt *p)
{
	int used;

	used = p->pos - p->buf;
	p->pos = p->buf;
	return used;
}
/* freepkt: release a packet made by mkpkt, buffer first, then the header */
static void
freepkt(Pkt *pkt)
{
	free(pkt->buf);
	free(pkt);
}
/*******************************************************/
/*
 * putpkt: frame the payload accumulated in p and write it to the server.
 * The header is the 24-bit payload length followed by the sequence
 * number *seq. *seq is then advanced modulo 256. Returns 0 if the whole
 * packet was written, -1 with errstr set otherwise.
 */
static int
putpkt(Sess *s, Pkt *p, int *seq)
{
	int n, total;

	n = rstpkt(p) - Nhdr;		/* payload bytes, header excluded */
	total = n + Nhdr;
	p24(p, n);
	p8(p, *seq);
	if(Debug > 1){
		fprint(2, "tx len=%d seq=%d\n", n, *seq);
		xd(p->buf, total);
	}
	*seq = (*seq + 1) % 256;
	if(write(s->net, p->buf, total) == total)
		return 0;
	werrstr("write fail - %r");
	return -1;
}
/*
 * getpkt: read one packet from the server into p.
 *
 * The 4-byte header (24-bit payload length, 8-bit sequence number) is
 * read first. The sequence number must equal *rseq, which is then
 * advanced modulo 256. After that the payload is read. Each read is
 * bounded by Timeout milliseconds via alarm.
 *
 * On return p->pos points at the start of the payload and p->end just
 * past it.
 *
 * Returns the first payload byte (as peek8 reports it, e.g. Rok, Rerr,
 * Reof), or -1 with errstr set on I/O error, a sequence mismatch, or a
 * payload too large for the s->mtu-byte buffer allocated by mkpkt. After
 * a -1 the stream position is undefined and the session should not be
 * reused.
 */
static int
getpkt(Sess *s, Pkt *p, int *rseq)
{
	int got, len, seq;

	rstpkt(p);
	p->end = p->buf + s->mtu;	/* a previous reply may have shrunk end */
	alarm(Timeout);
	got = readn(s->net, p->buf, Nhdr);
	alarm(0);
	if(got != Nhdr){
		werrstr("read fail - %r");
		return -1;
	}
	len = g24(p);
	seq = g8(p);
	if(seq != *rseq){
		/* report via errstr so callers printing %r see the real cause */
		werrstr("bad sequence number (%d != %d)", seq, *rseq);
		return -1;
	}
	*rseq = (*rseq + 1) % 256;

	/*
	 * len comes straight off the wire (up to 16MB) but the buffer holds
	 * only s->mtu bytes; refuse rather than overrun the heap.
	 */
	if(len < 0 || len > s->mtu - Nhdr){
		werrstr("packet too large (%d > %d)", len, s->mtu - Nhdr);
		return -1;
	}

	alarm(Timeout);
	got = readn(s->net, p->pos, len);
	alarm(0);
	if(got != len){
		werrstr("short packet - %r");
		return -1;
	}
	p->end = p->buf+len+Nhdr;
	if(Debug > 1){
		fprint(2, "rx len=%d seq=%d\n", len, seq);
		xd(p->buf, len+Nhdr);
	}
	return peek8(p);
}
/*******************************************************/
/*
 * freeres: release a Results list built by results(), including every
 * field description, every row and every column string.
 * A nil res is ignored.
 */
void
freeres(Results *res)
{
	Field *f, *nextf;
	Row *r, *nextr;
	Col *c, *nextc;

	if(res == nil)
		return;

	f = res->fields;
	while(f != nil){
		nextf = f->next;
		free(f->cat);
		free(f->db);
		free(f->tab);
		free(f->otab);
		free(f->name);
		free(f->oname);
		free(f);
		f = nextf;
	}

	r = res->rows;
	while(r != nil){
		nextr = r->next;
		c = r->cols;
		while(c != nil){
			nextc = c->next;
			free(c->str);
			free(c);
			c = nextc;
		}
		free(r);
		r = nextr;
	}

	free(res);
}
/*
 * parsefield: decode one column-definition packet (already read into p)
 * into a newly allocated Field.
 * The six strings are, as the member names suggest, catalog, database,
 * table, original table, column name and original column name.
 * Fixed-size attributes follow: language, size, type, flags and precision.
 * The order of the g* calls mirrors the wire layout and must not change.
 */
static Field *
parsefield(Pkt *p)
{
Field *f;
f = emalloc9p(sizeof(Field));
f->cat = gnstr(p);
f->db = gnstr(p);
f->tab = gnstr(p);
f->otab = gnstr(p);
f->name = gnstr(p);
f->oname = gnstr(p);
g8(p); /* filler */
f->lang = g16(p);
f->size = g32(p);
f->type = g8(p);
f->flags = g16(p);
f->prec = g8(p);
g16(p); /* filler */
if(remain(p)) /* optional trailing default value: decoded only if bytes remain */
f->def = gnum(p, &f->hasdef);
return f;
}
static Col *
parsevalue(Pkt *p, Field *f)
{
Col *c;
char *s;
if(remain(p) == 0)
return nil;
c = emalloc9p(sizeof(Col));
c->str = gnstr(p);
/*
* Questionable code - trim trailing zeros from
* decimal values, looks prettier but perhaps it
* breaks somthing, though I cannot see how.
*/
if(f->type == FLDdouble){
s = strchr(c->str, 0) -1;
while(s > c->str && *s == '0')
*s-- = 0;
}
return c;
}
static void parseerr(Pkt *);	/* defined below; reports mid-stream errors */

/*
 * results: read the remainder of a text-protocol result set.
 *
 * Inputs:
 *  - num is the column count from the header packet exch already consumed.
 *  - seq is the next expected sequence number.
 *  - p is reused as the receive buffer.
 *
 * Reads num column definitions, then an EOF packet, then row packets up
 * to a terminating EOF packet.
 *
 * Returns a Results to be freed with freeres. On any failure (I/O, bad
 * reply, an error packet in the row stream) everything read so far is
 * freed and nil is returned with errstr set.
 */
static Results *
results(Sess *s, Pkt *p, int num, int seq)
{
	int j, rc;
	Results *res;
	Field *f, **fld;
	Row *r, **row;
	Col *c, **col;

	res = emalloc9p(sizeof(Results));
	res->nf = num;

	fld = &res->fields;
	for(j = 0; j < num; j++){		/* field definitions */
		if(getpkt(s, p, &seq) == -1)
			goto fail;
		if((f = parsefield(p)) == nil)
			goto fail;
		if(Debug)
			dumpfield(f);
		*fld = f;
		fld = &f->next;
	}

	if((rc = getpkt(s, p, &seq)) == -1)
		goto fail;
	if(rc != Reof){
		werrstr("results: reply=0x%x unexpected", rc);
		goto fail;
	}

	row = &res->rows;
	for(;;){
		if((rc = getpkt(s, p, &seq)) == -1)
			goto fail;
		if(rc == Reof)
			break;
		if(rc == Rerr){		/* e.g. query aborted mid-stream */
			parseerr(p);
			goto fail;
		}
		/* allocate only once we know a row follows, so EOF leaks nothing */
		r = emalloc9p(sizeof(Row));
		col = &r->cols;
		f = res->fields;
		for(j = 0; j < num; j++){
			if((c = parsevalue(p, f)) == nil)
				break;
			*col = c;
			col = &c->next;
			f = f->next;
		}
		*row = r;
		row = &r->next;
		res->nr++;
	}
	return res;

fail:
	freeres(res);		/* lists are nil-terminated, partial data is safe */
	return nil;
}
/*
 * parseerr: decode an error packet (p positioned at its Rerr marker)
 * and set errstr.
 * Well-known error codes get a short friendly message. Any other code
 * is reported as "e<code>: <server text>".
 * The '#' + 5 character SQL-state block is skipped only when present;
 * errors sent before the handshake completes may omit it.
 */
static void
parseerr(Pkt *p)
{
	int i, err;
	char *msg;
	static struct {
		int	err;
		char	*msg;
	} sane[] = {
		{ 1130, "may not connect from this IP address" },
		{ 1043, "authentication protocol botch" },
		{ 1044, "access denied" },
		{ 1045, "authentication failure" },
		{ 1046, "no database selected" },
		{ 1064, "bad query syntax" },
	};

	g8(p);				/* Rerr marker byte */
	err = g16(p);			/* error code */
	if(remain(p) && peek8(p) == '#'){
		g8(p);			/* '#' state marker */
		gskip(p, 5);		/* server state as a 5 char string */
	}
	msg = gsall(p);			/* error text; allocated, ours to free */
	for(i = 0; i < nelem(sane); i++)
		if(sane[i].err == err){
			werrstr("%s", sane[i].msg);
			free(msg);
			return;
		}
	werrstr("e%d: %s", err, msg);
	free(msg);
}
/*******************************************************/
/*
 * mysql_open: dial host (tcp, default service "mysql") and read the
 * server's initial handshake packet into a new Sess.
 * The Sess records protocol, server version string, thread id, both
 * salts, capabilities, language and status.
 * Only handshake version 0x0a is accepted, and only from servers that
 * advertise both CAPprotocol_41 and CAPauthentication_41.
 *
 * Returns the session, or nil with errstr set. On failure the
 * connection is closed and everything allocated here is freed.
 *
 * No locking here as we have not started serving the filesystem yet.
 */
Sess *
mysql_open(char *host)
{
	Pkt *p;
	Sess *s;
	int seq, fd, rc;

	if((fd = dial(netmkaddr(host, "tcp", "mysql"), 0, 0, 0)) == -1)
		return nil;
	s = emalloc9p(sizeof(Sess));
	s->net = fd;
	s->mtu = Mtu;		/* FIXME: guess at MTU */
	seq = 0;
	p = mkpkt(s);
	rc = getpkt(s, p, &seq);
	switch(rc){
	case 0x0a:		/* our only supported protocol version at present */
		break;
	case -1:		/* getpkt has set errstr */
		goto fail;
	case Rerr:
		parseerr(p);
		goto fail;
	default:
		werrstr("version=0x%x unsupported protocol", rc);
		goto fail;
	}
	s->proto = g8(p);
	s->server = gstr(p);
	s->tid = g32(p);
	s->salt1 = gstr(p);
	s->caps = g16(p);
	s->lang = g8(p);
	s->stat = g16(p);
	gskip(p, 13);		/* skipped: not used by this client */
	s->salt2 = gstr(p);
	freepkt(p);

	if((s->caps & CAPprotocol_41) == 0 || (s->caps & CAPauthentication_41) == 0){
		werrstr("v%s - old server not supported", s->server);
		free(s->server);
		free(s->salt1);
		free(s->salt2);
		close(fd);	/* previously leaked on this path */
		free(s);
		return nil;
	}
	return s;

fail:
	freepkt(p);
	close(fd);
	free(s);
	return nil;
}
/*
 * mysql_auth: send the client authentication packet on a session opened
 * by mysql_open and wait for the server's verdict.
 * The packet carries capabilities, max packet size, character set, user
 * name and password response.
 *
 * resp4x, if non-nil, is the Nauth-byte 4.1-style password response.
 * resp3x is the 3.x-style response, sent only if the server replies
 * with Reof asking for it.
 *
 * Returns 0 on success, -1 with errstr set on failure.
 *
 * No locking here as we have not started serving the filesystem yet.
 */
int
mysql_auth(Sess *s, char *user, char *resp3x, uchar *resp4x)
{
Pkt *p;
int seq, rc, caps;
/* echo the server's capabilities minus compression and SSL (presumably because this client implements neither) */
caps = s->caps & ~(CAPcompression|CAPssl);
p = mkpkt(s);
p32(p, caps);
p32(p, s->mtu); /* largest packet we can buffer (see mkpkt) */
p8(p, utf8_general_ci);
pskip(p, 23); /* padding */
pstr(p, user);
if(resp4x){
p8(p, Nauth);
pmem(p, resp4x, Nauth);
}
else{
p8(p, 0);
}
p8(p, 0); /* padding */
seq = 1; /* this replies to the greeting, which mysql_open read as seq 0 */
if(putpkt(s, p, &seq) == -1){
freepkt(p);
return -1;
}
rc = getpkt(s, p, &seq);
/*
* Very special case, server supports V4.1 passwords,
* but its password database is still in V 3.x format
* so it has asked us to resubmit our password
* in this format. The sequence number carries on from the exchange above.
*/
if(rc == Reof && s->caps & CAPauthentication_41){
freepkt(p);
p = mkpkt(s);
pstr(p, resp3x);
if(putpkt(s, p, &seq) == -1){
freepkt(p);
return -1;
}
rc = getpkt(s, p, &seq);
}
/* map the final reply onto 0 / -1 */
switch(rc){
case Rok:
rc = 0;
break;
case -1:
break;
case Rerr:
parseerr(p);
rc = -1;
break;
default:
werrstr("reply=0x%x unexpected\n", rc);
rc = -1;
}
freepkt(p);
return rc;
}
/*
 * exch: perform one command/reply round trip, holding s->lock so that
 * requests from different callers do not interleave on the connection.
 * p holds the command payload and is reused to receive the reply.
 * *res (if res is non-nil) is reset to nil first.
 *
 * Reply handling:
 *  - OK gives 0.
 *  - Error packets set errstr via parseerr.
 *  - EOF is unexpected.
 *  - Any other first byte starts a result set. It is read into *res
 *    when res is non-nil and is unexpected otherwise.
 *
 * Returns 0 on success, -1 on failure.
 */
static int
exch(Sess *s, Pkt *p, Results **res)
{
	int seq, reply, nfield, ret;

	if(res != nil)
		*res = nil;
	seq = 0;
	lock(&s->lock);
	if(putpkt(s, p, &seq) != -1){
		reply = getpkt(s, p, &seq);
		if(reply == Rok)
			ret = 0;
		else if(reply == -1)
			ret = -1;
		else if(reply == Rerr){
			parseerr(p);
			ret = -1;
		}
		else if(reply == Reof || res == nil){
			werrstr("reply=0x%x unexpected\n", reply);
			ret = -1;
		}
		else{
			nfield = gnum(p, nil);		/* # fields */
			if(remain(p))
				gnum(p, nil);		/* extra */
			*res = results(s, p, nfield, seq);
			ret = (*res == nil)? -1: 0;
		}
	}
	else
		ret = -1;
	unlock(&s->lock);
	return ret;
}
/* mysql_ping: send CMDping and expect OK; returns 0 on success, -1 on failure */
int
mysql_ping(Sess *s)
{
	int status;
	Pkt *req;

	req = mkpkt(s);
	p8(req, CMDping);
	status = exch(s, req, nil);
	freepkt(req);
	return status;
}
/*
 * mysql_query: run the SQL text cmd.
 * If the server answers with a result set and res is non-nil, the set
 * is returned through *res (release it with freeres). *res is nil
 * otherwise.
 * Returns 0 on success, -1 on failure.
 */
int
mysql_query(Sess *s, char *cmd, Results **res)
{
	int status;
	Pkt *req;

	req = mkpkt(s);
	p8(req, CMDquery);
	pstr(req, cmd);
	status = exch(s, req, res);
	freepkt(req);
	return status;
}
/* mysql_use: make db the session's default database (CMDinit_db); 0 or -1 */
int
mysql_use(Sess *s, char *db)
{
	int status;
	Pkt *req;

	req = mkpkt(s);
	p8(req, CMDinit_db);
	pstr(req, db);
	status = exch(s, req, nil);
	freepkt(req);
	return status;
}
/*
 * mysql_ps: send CMDprocess_info and return the resulting result set
 * (presumably the server's process list) through *res.
 * Returns 0 on success, -1 on failure.
 */
int
mysql_ps(Sess *s, Results **res)
{
	int status;
	Pkt *req;

	req = mkpkt(s);
	p8(req, CMDprocess_info);
	status = exch(s, req, res);
	freepkt(req);
	return status;
}
/* mysql_kill: send CMDprocess_kill with a 32-bit id; 0 or -1. Unused at present. */
int
mysql_kill(Sess *s, int id)
{
	int status;
	Pkt *req;

	req = mkpkt(s);
	p8(req, CMDprocess_kill);
	p32(req, id);
	status = exch(s, req, nil);
	freepkt(req);
	return status;
}
/* mysql_create_db: send CMDcreate_db for db; 0 or -1. Unused and deprecated. */
int
mysql_create_db(Sess *s, char *db)
{
	int status;
	Pkt *req;

	req = mkpkt(s);
	p8(req, CMDcreate_db);
	pstr(req, db);
	status = exch(s, req, nil);
	freepkt(req);
	return status;
}
/* mysql_drop_db: send CMDdrop_db for db; 0 or -1. Unused and deprecated. */
int
mysql_drop_db(Sess *s, char *db)
{
	int status;
	Pkt *req;

	req = mkpkt(s);
	p8(req, CMDdrop_db);
	pstr(req, db);
	status = exch(s, req, nil);
	freepkt(req);
	return status;
}
int
mysql_set_option(Sess *s, int opt) /* unused at present */
{
Pkt *p;
int rc;
p = mkpkt(s);
p8(p, CMDset_option);
p16(p, opt);
rc = exch(s, p, nil);
freepkt(p);
return rc;
}
int
mysql_refresh(Sess *s, int opts) /* unused at present */
{
Pkt *p;
int rc;
p = mkpkt(s);
p8(p, CMDrefresh);
p8(p, opts);
rc = exch(s, p, nil);
freepkt(p);
return rc;
}
int
mysql_shutdown(Sess *s, int level) /* unused at present */
{
Pkt *p;
int rc;
p = mkpkt(s);
p8(p, CMDprocess_kill);
p8(p, level);
rc = exch(s, p, nil);
freepkt(p);
return rc;
}
/*
 * mysql_stats: send CMDstatistics and return the server's free-form
 * status text (allocated by gsall; the caller frees it).
 * Returns nil with errstr set on I/O failure or if the server replies
 * with an error packet. Previously an error packet's raw bytes were
 * returned as if they were statistics text.
 */
char *
mysql_stats(Sess *s)
{
	Pkt *p;
	int seq, rc;
	char *str;

	p = mkpkt(s);
	p8(p, CMDstatistics);
	seq = 0;
	str = nil;
	lock(&s->lock);
	if(putpkt(s, p, &seq) == -1)
		goto out;
	if((rc = getpkt(s, p, &seq)) == -1)
		goto out;
	if(rc == Rerr){
		parseerr(p);
		goto out;
	}
	/* no usual reply reader, just freeform text */
	str = gsall(p);
out:
	unlock(&s->lock);
	freepkt(p);
	return str;
}
/*
 * mysql_close: send CMDquit and hang up the network connection.
 * The server sends no reply, and a failed write is ignored because
 * we are disconnecting anyway.
 * s->net is set to -1 so later use fails cleanly. The Sess itself is not
 * freed; the caller still owns it.
 * Always returns 0. Unused at present.
 */
int
mysql_close(Sess *s)
{
	Pkt *p;
	int seq;

	seq = 0;
	p = mkpkt(s);
	p8(p, CMDquit);
	lock(&s->lock);
	putpkt(s, p, &seq);	/* best effort; no reply */
	close(s->net);		/* previously the connection was leaked */
	s->net = -1;
	unlock(&s->lock);
	freepkt(p);
	return 0;
}
|