--- /sys/src/libsec/port/tlshand.c
+++ /sys/src/libsec/port/tlshand.c
@@ -51,6 +51,12 @@ typedef struct Finished{
int n;
} Finished;
+typedef struct HandHash{
+ MD5state md5;
+ SHAstate sha1;
+ SHA2_256state sha2_256;
+} HandHash;
+
typedef struct TlsConnection{
TlsSec *sec; // security management goo
int hand, ctl; // record layer file descriptors
@@ -78,8 +84,7 @@ typedef struct TlsConnection{
int nsecret; // amount of secret data to init keys
// for finished messages
- MD5state hsmd5; // handshake hash
- SHAstate hssha1; // handshake hash
+ HandHash hs; // handshake hash
Finished finished;
} TlsConnection;
@@ -128,15 +133,17 @@ typedef struct TlsSec{
int vers; // final version
// byte generation and handshake checksum
void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
- void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
+ void (*setFinished)(TlsSec*, HandHash, uchar*, int);
int nfin;
} TlsSec;
enum {
- TLSVersion = 0x0301,
- SSL3Version = 0x0300,
- ProtocolVersion = 0x0301, // maximum version we speak
+ SSL3Version = 0x0300,
+ TLS10Version = 0x0301,
+ TLS11Version = 0x0302,
+ TLS12Version = 0x0303,
+ ProtocolVersion = TLS12Version, // maximum version we speak
MinProtoVersion = 0x0300, // limits on version we accept
MaxProtoVersion = 0x03ff,
};
@@ -273,7 +280,7 @@ static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uc
static int tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
static TlsSec* tlsSecInitc(int cvers, uchar *crandom);
static int tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
-static int tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
+static int tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient);
static void tlsSecOk(TlsSec *sec);
static void tlsSecKill(TlsSec *sec);
static void tlsSecClose(TlsSec *sec);
@@ -283,8 +290,9 @@ static void setSecrets(TlsSec *sec, uchar *kd, int nkd);
static int clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
-static void tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
-static void sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
+static void tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
+static void tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
+static void sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
static void sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
uchar *seed0, int nseed0, uchar *seed1, int nseed1);
static int setVers(TlsSec *sec, int version);
@@ -556,7 +564,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...
msgClear(&m);
/* no CertificateVerify; skip to Finished */
- if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+ if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
tlsError(c, EInternalError, "can't set finished: %r");
goto Err;
}
@@ -578,7 +586,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...
goto Err;
}
- if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+ if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
tlsError(c, EInternalError, "can't set finished: %r");
goto Err;
}
@@ -747,7 +755,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...
// Cipherchange must occur immediately before Finished to avoid
// potential hole; see section 4.3 of Wagner Schneier 1996.
- if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+ if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
tlsError(c, EInternalError, "can't set finished 1: %r");
goto Err;
}
@@ -761,7 +769,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...
}
msgClear(&m);
- if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+ if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
fprint(2, "tlsClient nepm=%d\n", nepm);
tlsError(c, EInternalError, "can't set finished 0: %r");
goto Err;
@@ -803,6 +811,17 @@ Err:
static uchar sendbuf[9000], *sendp;
+static void
+msgHash(TlsConnection *c, uchar *p, int n)
+{
+ md5(p, n, 0, &c->hs.md5);
+ sha1(p, n, 0, &c->hs.sha1);
+ if(c->version >= TLS12Version)
+ sha2_256(p, n, 0, &c->hs.sha2_256);
+ else
+ memset(&c->hs.sha2_256, 0, sizeof c->hs.sha2_256);
+}
+
static int
msgSend(TlsConnection *c, Msg *m, int act)
{
@@ -914,8 +933,7 @@ msgSend(TlsConnection *c, Msg *m, int act)
// remember hash of Handshake messages
if(m->tag != HHelloRequest) {
- md5(sendp, n, 0, &c->hsmd5);
- sha1(sendp, n, 0, &c->hssha1);
+ msgHash(c, sendp, n);
}
sendp = p;
@@ -991,8 +1009,7 @@ msgRecv(TlsConnection *c, Msg *m)
p = tlsReadN(c, n);
if(p == nil)
return 0;
- md5(p, n, 0, &c->hsmd5);
- sha1(p, n, 0, &c->hssha1);
+ msgHash(c, p, n);
m->tag = HClientHello;
if(n < 22)
goto Short;
@@ -1030,15 +1047,13 @@ msgRecv(TlsConnection *c, Msg *m)
goto Ok;
}
- md5(p, 4, 0, &c->hsmd5);
- sha1(p, 4, 0, &c->hssha1);
+ msgHash(c, p, 4);
p = tlsReadN(c, n);
if(p == nil)
return 0;
- md5(p, n, 0, &c->hsmd5);
- sha1(p, n, 0, &c->hssha1);
+ msgHash(c, p, n);
m->tag = type;
@@ -1388,14 +1403,19 @@ setVersion(TlsConnection *c, int version)
return -1;
if(version > c->version)
version = c->version;
- if(version == SSL3Version) {
- c->version = version;
+ switch(version) {
+ case SSL3Version:
c->finished.n = SSL3FinishedLen;
- }else if(version == TLSVersion){
- c->version = version;
+ break;
+ case TLS10Version:
+ case TLS11Version:
+ case TLS12Version:
c->finished.n = TLSFinishedLen;
- }else
+ break;
+ default:
return -1;
+ }
+ c->version = version;
c->verset = 1;
return fprint(c->ctl, "version 0x%x", version);
}
@@ -1721,6 +1741,32 @@ tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, u
}
}
+static void
+tlsPsha2_256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
+{
+ uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
+ int n;
+ SHAstate *s;
+
+ // generate a1
+ s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
+ hmac_sha2_256(seed, nseed, key, nkey, ai, s);
+
+ while(nbuf > 0) {
+ s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
+ s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
+ hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
+ n = SHA2_256dlen;
+ if(n > nbuf)
+ n = nbuf;
+ memmove(buf, tmp, n);
+ buf += n;
+ nbuf -= n;
+ hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
+ memmove(ai, tmp, SHA2_256dlen);
+ }
+}
+
// fill buf with md5(args)^sha1(args)
static void
tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
@@ -1735,6 +1781,17 @@ tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, in
tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
}
+void
+tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
+{
+ uchar seed[2*RandomSize];
+ int nlabel = strlen(label);
+
+ memmove(seed, seed0, nseed0);
+ memmove(seed+nseed0, seed1, nseed1);
+ tlsPsha2_256(buf, nbuf, key, nkey, (uchar*)label, nlabel, seed, nseed0+nseed1);
+}
+
/*
* for setting server session id's
*/
@@ -1833,16 +1890,16 @@ Err:
}
static int
-tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
+tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient)
{
if(sec->nfin != nfin){
sec->ok = -1;
werrstr("invalid finished exchange");
return -1;
}
- md5.malloced = 0;
- sha1.malloced = 0;
- (*sec->setFinished)(sec, md5, sha1, fin, isclient);
+ hs.md5.malloced = 0;
+ hs.sha1.malloced = 0;
+ (*sec->setFinished)(sec, hs, fin, isclient);
return 1;
}
@@ -1875,15 +1932,24 @@ tlsSecClose(TlsSec *sec)
static int
setVers(TlsSec *sec, int v)
{
- if(v == SSL3Version){
+ switch(v){
+ case SSL3Version:
sec->setFinished = sslSetFinished;
sec->nfin = SSL3FinishedLen;
sec->prf = sslPRF;
- }else if(v == TLSVersion){
+ break;
+ case TLS10Version:
+ case TLS11Version:
sec->setFinished = tlsSetFinished;
sec->nfin = TLSFinishedLen;
sec->prf = tlsPRF;
- }else{
+ break;
+ case TLS12Version:
+ sec->setFinished = tls12SetFinished;
+ sec->nfin = TLSFinishedLen;
+ sec->prf = tls12PRF;
+ break;
+ default:
werrstr("invalid version");
return -1;
}
@@ -1973,7 +2039,7 @@ clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
}
static void
-sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
{
DigestState *s;
uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
@@ -1984,21 +2050,21 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in
else
label = "SRVR";
- md5((uchar*)label, 4, nil, &hsmd5);
- md5(sec->sec, MasterSecretSize, nil, &hsmd5);
+ md5((uchar*)label, 4, nil, &hs.md5);
+ md5(sec->sec, MasterSecretSize, nil, &hs.md5);
memset(pad, 0x36, 48);
- md5(pad, 48, nil, &hsmd5);
- md5(nil, 0, h0, &hsmd5);
+ md5(pad, 48, nil, &hs.md5);
+ md5(nil, 0, h0, &hs.md5);
memset(pad, 0x5C, 48);
s = md5(sec->sec, MasterSecretSize, nil, nil);
s = md5(pad, 48, nil, s);
md5(h0, MD5dlen, finished, s);
- sha1((uchar*)label, 4, nil, &hssha1);
- sha1(sec->sec, MasterSecretSize, nil, &hssha1);
+ sha1((uchar*)label, 4, nil, &hs.sha1);
+ sha1(sec->sec, MasterSecretSize, nil, &hs.sha1);
memset(pad, 0x36, 40);
- sha1(pad, 40, nil, &hssha1);
- sha1(nil, 0, h1, &hssha1);
+ sha1(pad, 40, nil, &hs.sha1);
+ sha1(nil, 0, h1, &hs.sha1);
memset(pad, 0x5C, 40);
s = sha1(sec->sec, MasterSecretSize, nil, nil);
s = sha1(pad, 40, nil, s);
@@ -2007,20 +2073,37 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in
// fill "finished" arg with md5(args)^sha1(args)
static void
-tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
{
uchar h0[MD5dlen], h1[SHA1dlen];
char *label;
// get current hash value, but allow further messages to be hashed in
- md5(nil, 0, h0, &hsmd5);
- sha1(nil, 0, h1, &hssha1);
+ md5(nil, 0, h0, &hs.md5);
+ sha1(nil, 0, h1, &hs.sha1);
+
+ if(isClient)
+ label = "client finished";
+ else
+ label = "server finished";
+ (*sec->prf)(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
+}
+
+// fill "finished" arg with sha256(args)
+static void
+tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
+{
+ uchar h[SHA2_256dlen];
+ char *label;
+
+ // get current hash value, but allow further messages to be hashed in
+ sha2_256(nil, 0, h, &hs.sha2_256);
if(isClient)
label = "client finished";
else
label = "server finished";
- tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
+ tlsPsha2_256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), h, SHA2_256dlen);
}
static void
|