Plan 9 from Bell Labs’s /usr/web/sources/contrib/yk/dist/9legacy/applied/tls-tlshand12.diff

Copyright © 2021 Plan 9 Foundation.
Distributed under the MIT License.
Download the Plan 9 distribution.


--- /sys/src/libsec/port/tlshand.c
+++ /sys/src/libsec/port/tlshand.c
@@ -51,6 +51,12 @@ typedef struct Finished{
 	int n;
 } Finished;
 
+typedef struct HandHash{
+	MD5state	md5;
+	SHAstate	sha1;
+	SHA2_256state	sha2_256;
+} HandHash;
+
 typedef struct TlsConnection{
 	TlsSec *sec;	// security management goo
 	int hand, ctl;	// record layer file descriptors
@@ -78,8 +84,7 @@ typedef struct TlsConnection{
 	int nsecret;	// amount of secret data to init keys
 
 	// for finished messages
-	MD5state	hsmd5;	// handshake hash
-	SHAstate	hssha1;	// handshake hash
+	HandHash	hs;	// handshake hash
 	Finished	finished;
 } TlsConnection;
 
@@ -128,15 +133,17 @@ typedef struct TlsSec{
 	int vers;			// final version
 	// byte generation and handshake checksum
 	void (*prf)(uchar*, int, uchar*, int, char*, uchar*, int, uchar*, int);
-	void (*setFinished)(TlsSec*, MD5state, SHAstate, uchar*, int);
+	void (*setFinished)(TlsSec*, HandHash, uchar*, int);
 	int nfin;
 } TlsSec;
 
 
 enum {
-	TLSVersion = 0x0301,
-	SSL3Version = 0x0300,
-	ProtocolVersion = 0x0301,	// maximum version we speak
+	SSL3Version  = 0x0300,
+	TLS10Version = 0x0301,
+	TLS11Version = 0x0302,
+	TLS12Version = 0x0303,
+	ProtocolVersion = TLS12Version,	// maximum version we speak
 	MinProtoVersion = 0x0300,	// limits on version we accept
 	MaxProtoVersion	= 0x03ff,
 };
@@ -273,7 +280,7 @@ static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uc
 static int	tlsSecSecrets(TlsSec *sec, int vers, uchar *epm, int nepm, uchar *kd, int nkd);
 static TlsSec*	tlsSecInitc(int cvers, uchar *crandom);
 static int	tlsSecSecretc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers, uchar **epm, int *nepm, uchar *kd, int nkd);
-static int	tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient);
+static int	tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient);
 static void	tlsSecOk(TlsSec *sec);
 static void	tlsSecKill(TlsSec *sec);
 static void	tlsSecClose(TlsSec *sec);
@@ -283,8 +290,9 @@ static void	setSecrets(TlsSec *sec, uchar *kd, int nkd);
 static int	clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm);
 static Bytes *pkcs1_encrypt(Bytes* data, RSApub* key, int blocktype);
 static Bytes *pkcs1_decrypt(TlsSec *sec, uchar *epm, int nepm);
-static void	tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
-static void	sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient);
+static void	tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
+static void	tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
+static void	sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient);
 static void	sslPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label,
 			uchar *seed0, int nseed0, uchar *seed1, int nseed1);
 static int setVers(TlsSec *sec, int version);
@@ -556,7 +564,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...
 	msgClear(&m);
 
 	/* no CertificateVerify; skip to Finished */
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
 		tlsError(c, EInternalError, "can't set finished: %r");
 		goto Err;
 	}
@@ -578,7 +586,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int ncert, int (*trace)(char*fmt, ...
 		goto Err;
 	}
 
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
 		tlsError(c, EInternalError, "can't set finished: %r");
 		goto Err;
 	}
@@ -747,7 +755,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...
 
 	// Cipherchange must occur immediately before Finished to avoid
 	// potential hole;  see section 4.3 of Wagner Schneier 1996.
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 1) < 0){
+	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 1) < 0){
 		tlsError(c, EInternalError, "can't set finished 1: %r");
 		goto Err;
 	}
@@ -761,7 +769,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, int (*trace)(char*fmt, ...
 	}
 	msgClear(&m);
 
-	if(tlsSecFinished(c->sec, c->hsmd5, c->hssha1, c->finished.verify, c->finished.n, 0) < 0){
+	if(tlsSecFinished(c->sec, c->hs, c->finished.verify, c->finished.n, 0) < 0){
 		fprint(2, "tlsClient nepm=%d\n", nepm);
 		tlsError(c, EInternalError, "can't set finished 0: %r");
 		goto Err;
@@ -803,6 +811,17 @@ Err:
 
 static uchar sendbuf[9000], *sendp;
 
+static void
+msgHash(TlsConnection *c, uchar *p, int n)
+{
+	md5(p, n, 0, &c->hs.md5);
+	sha1(p, n, 0, &c->hs.sha1);
+	if(c->version >= TLS12Version)
+		sha2_256(p, n, 0, &c->hs.sha2_256);
+	else
+		memset(&c->hs.sha2_256, 0, sizeof c->hs.sha2_256);
+}
+
 static int
 msgSend(TlsConnection *c, Msg *m, int act)
 {
@@ -914,8 +933,7 @@ msgSend(TlsConnection *c, Msg *m, int act)
 
 	// remember hash of Handshake messages
 	if(m->tag != HHelloRequest) {
-		md5(sendp, n, 0, &c->hsmd5);
-		sha1(sendp, n, 0, &c->hssha1);
+		msgHash(c, sendp, n);
 	}
 
 	sendp = p;
@@ -991,8 +1009,7 @@ msgRecv(TlsConnection *c, Msg *m)
 		p = tlsReadN(c, n);
 		if(p == nil)
 			return 0;
-		md5(p, n, 0, &c->hsmd5);
-		sha1(p, n, 0, &c->hssha1);
+		msgHash(c, p, n);
 		m->tag = HClientHello;
 		if(n < 22)
 			goto Short;
@@ -1030,15 +1047,13 @@ msgRecv(TlsConnection *c, Msg *m)
 		goto Ok;
 	}
 
-	md5(p, 4, 0, &c->hsmd5);
-	sha1(p, 4, 0, &c->hssha1);
+	msgHash(c, p, 4);
 
 	p = tlsReadN(c, n);
 	if(p == nil)
 		return 0;
 
-	md5(p, n, 0, &c->hsmd5);
-	sha1(p, n, 0, &c->hssha1);
+	msgHash(c, p, n);
 
 	m->tag = type;
 
@@ -1388,14 +1403,19 @@ setVersion(TlsConnection *c, int version)
 		return -1;
 	if(version > c->version)
 		version = c->version;
-	if(version == SSL3Version) {
-		c->version = version;
+	switch(version) {
+	case SSL3Version:
 		c->finished.n = SSL3FinishedLen;
-	}else if(version == TLSVersion){
-		c->version = version;
+		break;
+	case TLS10Version:
+	case TLS11Version:
+	case TLS12Version:
 		c->finished.n = TLSFinishedLen;
-	}else
+		break;
+	default:
 		return -1;
+	}
+	c->version = version;
 	c->verset = 1;
 	return fprint(c->ctl, "version 0x%x", version);
 }
@@ -1721,6 +1741,32 @@ tlsPsha1(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, u
 	}
 }
 
+static void
+tlsPsha2_256(uchar *buf, int nbuf, uchar *key, int nkey, uchar *label, int nlabel, uchar *seed, int nseed)
+{
+	uchar ai[SHA2_256dlen], tmp[SHA2_256dlen];
+	int n;
+	SHAstate *s;
+
+	// generate a1
+	s = hmac_sha2_256(label, nlabel, key, nkey, nil, nil);
+	hmac_sha2_256(seed, nseed, key, nkey, ai, s);
+
+	while(nbuf > 0) {
+		s = hmac_sha2_256(ai, SHA2_256dlen, key, nkey, nil, nil);
+		s = hmac_sha2_256(label, nlabel, key, nkey, nil, s);
+		hmac_sha2_256(seed, nseed, key, nkey, tmp, s);
+		n = SHA2_256dlen;
+		if(n > nbuf)
+			n = nbuf;
+		memmove(buf, tmp, n);
+		buf += n;
+		nbuf -= n;
+		hmac_sha2_256(ai, SHA2_256dlen, key, nkey, tmp, nil);
+		memmove(ai, tmp, SHA2_256dlen);
+	}
+}
+
 // fill buf with md5(args)^sha1(args)
 static void
 tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
@@ -1735,6 +1781,17 @@ tlsPRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, in
 	tlsPsha1(buf, nbuf, key+nkey-n, n, (uchar*)label, nlabel, seed0, nseed0, seed1, nseed1);
 }
 
+void
+tls12PRF(uchar *buf, int nbuf, uchar *key, int nkey, char *label, uchar *seed0, int nseed0, uchar *seed1, int nseed1)
+{
+	uchar seed[2*RandomSize];
+	int nlabel = strlen(label);
+
+	memmove(seed, seed0, nseed0);
+	memmove(seed+nseed0, seed1, nseed1);
+	tlsPsha2_256(buf, nbuf, key, nkey, (uchar*)label, nlabel, seed, nseed0+nseed1);
+}
+
 /*
  * for setting server session id's
  */
@@ -1833,16 +1890,16 @@ Err:
 }
 
 static int
-tlsSecFinished(TlsSec *sec, MD5state md5, SHAstate sha1, uchar *fin, int nfin, int isclient)
+tlsSecFinished(TlsSec *sec, HandHash hs, uchar *fin, int nfin, int isclient)
 {
 	if(sec->nfin != nfin){
 		sec->ok = -1;
 		werrstr("invalid finished exchange");
 		return -1;
 	}
-	md5.malloced = 0;
-	sha1.malloced = 0;
-	(*sec->setFinished)(sec, md5, sha1, fin, isclient);
+	hs.md5.malloced = 0;
+	hs.sha1.malloced = 0;
+	(*sec->setFinished)(sec, hs, fin, isclient);
 	return 1;
 }
 
@@ -1875,15 +1932,24 @@ tlsSecClose(TlsSec *sec)
 static int
 setVers(TlsSec *sec, int v)
 {
-	if(v == SSL3Version){
+	switch(v){
+	case SSL3Version:
 		sec->setFinished = sslSetFinished;
 		sec->nfin = SSL3FinishedLen;
 		sec->prf = sslPRF;
-	}else if(v == TLSVersion){
+		break;
+	case TLS10Version:
+	case TLS11Version:
 		sec->setFinished = tlsSetFinished;
 		sec->nfin = TLSFinishedLen;
 		sec->prf = tlsPRF;
-	}else{
+		break;
+	case TLS12Version:
+		sec->setFinished = tls12SetFinished;
+		sec->nfin = TLSFinishedLen;
+		sec->prf = tls12PRF;
+		break;
+	default:
 		werrstr("invalid version");
 		return -1;
 	}
@@ -1973,7 +2039,7 @@ clientMasterSecret(TlsSec *sec, RSApub *pub, uchar **epm, int *nepm)
 }
 
 static void
-sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+sslSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
 {
 	DigestState *s;
 	uchar h0[MD5dlen], h1[SHA1dlen], pad[48];
@@ -1984,21 +2050,21 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in
 	else
 		label = "SRVR";
 
-	md5((uchar*)label, 4, nil, &hsmd5);
-	md5(sec->sec, MasterSecretSize, nil, &hsmd5);
+	md5((uchar*)label, 4, nil, &hs.md5);
+	md5(sec->sec, MasterSecretSize, nil, &hs.md5);
 	memset(pad, 0x36, 48);
-	md5(pad, 48, nil, &hsmd5);
-	md5(nil, 0, h0, &hsmd5);
+	md5(pad, 48, nil, &hs.md5);
+	md5(nil, 0, h0, &hs.md5);
 	memset(pad, 0x5C, 48);
 	s = md5(sec->sec, MasterSecretSize, nil, nil);
 	s = md5(pad, 48, nil, s);
 	md5(h0, MD5dlen, finished, s);
 
-	sha1((uchar*)label, 4, nil, &hssha1);
-	sha1(sec->sec, MasterSecretSize, nil, &hssha1);
+	sha1((uchar*)label, 4, nil, &hs.sha1);
+	sha1(sec->sec, MasterSecretSize, nil, &hs.sha1);
 	memset(pad, 0x36, 40);
-	sha1(pad, 40, nil, &hssha1);
-	sha1(nil, 0, h1, &hssha1);
+	sha1(pad, 40, nil, &hs.sha1);
+	sha1(nil, 0, h1, &hs.sha1);
 	memset(pad, 0x5C, 40);
 	s = sha1(sec->sec, MasterSecretSize, nil, nil);
 	s = sha1(pad, 40, nil, s);
@@ -2007,20 +2073,37 @@ sslSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, in
 
 // fill "finished" arg with md5(args)^sha1(args)
 static void
-tlsSetFinished(TlsSec *sec, MD5state hsmd5, SHAstate hssha1, uchar *finished, int isClient)
+tlsSetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
 {
 	uchar h0[MD5dlen], h1[SHA1dlen];
 	char *label;
 
 	// get current hash value, but allow further messages to be hashed in
-	md5(nil, 0, h0, &hsmd5);
-	sha1(nil, 0, h1, &hssha1);
+	md5(nil, 0, h0, &hs.md5);
+	sha1(nil, 0, h1, &hs.sha1);
+
+	if(isClient)
+		label = "client finished";
+	else
+		label = "server finished";
+	(*sec->prf)(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
+}
+
+// fill "finished" arg with sha256(args)
+static void
+tls12SetFinished(TlsSec *sec, HandHash hs, uchar *finished, int isClient)
+{
+	uchar h[SHA2_256dlen];
+	char *label;
+
+	// get current hash value, but allow further messages to be hashed in
+	sha2_256(nil, 0, h, &hs.sha2_256);
 
 	if(isClient)
 		label = "client finished";
 	else
 		label = "server finished";
-	tlsPRF(finished, TLSFinishedLen, sec->sec, MasterSecretSize, label, h0, MD5dlen, h1, SHA1dlen);
+	tlsPsha2_256(finished, TLSFinishedLen, sec->sec, MasterSecretSize, (uchar*)label, strlen(label), h, SHA2_256dlen);
 }
 
 static void

Bell Labs OSI certified Powered by Plan 9

(Return to Plan 9 Home Page)

Copyright © 2021 Plan 9 Foundation. All Rights Reserved.
Comments to [email protected].