From 354bb40e75d94466e91fe6960523612c9d17ccfb Mon Sep 17 00:00:00 2001 From: Karen Arutyunov Date: Thu, 2 Nov 2017 23:11:29 +0300 Subject: Add implementation --- mysql/extra/yassl/src/handshake.cpp | 1190 +++++++++++++++++++++++++++++++++++ 1 file changed, 1190 insertions(+) create mode 100644 mysql/extra/yassl/src/handshake.cpp (limited to 'mysql/extra/yassl/src/handshake.cpp') diff --git a/mysql/extra/yassl/src/handshake.cpp b/mysql/extra/yassl/src/handshake.cpp new file mode 100644 index 0000000..91cc407 --- /dev/null +++ b/mysql/extra/yassl/src/handshake.cpp @@ -0,0 +1,1190 @@ +/* + Copyright (c) 2005, 2014, Oracle and/or its affiliates. All rights reserved. + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; version 2 of the License. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program; see the file COPYING. If not, write to the + Free Software Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, + MA 02110-1301 USA. +*/ + + +/* The handshake source implements functions for creating and reading + * the various handshake messages. + */ + + + +#include "runtime.hpp" +#include "handshake.hpp" +#include "yassl_int.hpp" + + +namespace yaSSL { + + + +// Build a client hello message from cipher suites and compression method +void buildClientHello(SSL& ssl, ClientHello& hello) +{ + // store for pre master secret + ssl.useSecurity().use_connection().chVersion_ = hello.client_version_; + + ssl.getCrypto().get_random().Fill(hello.random_, RAN_LEN); + if (ssl.getSecurity().get_resuming()) { + hello.id_len_ = ID_LEN; + memcpy(hello.session_id_, ssl.getSecurity().get_resume().GetID(), + ID_LEN); + } + else + hello.id_len_ = 0; + hello.suite_len_ = ssl.getSecurity().get_parms().suites_size_; + memcpy(hello.cipher_suites_, ssl.getSecurity().get_parms().suites_, + hello.suite_len_); + hello.comp_len_ = 1; + + hello.set_length(sizeof(ProtocolVersion) + + RAN_LEN + + hello.id_len_ + sizeof(hello.id_len_) + + hello.suite_len_ + sizeof(hello.suite_len_) + + hello.comp_len_ + sizeof(hello.comp_len_)); +} + + +// Build a server hello message +void buildServerHello(SSL& ssl, ServerHello& hello) +{ + if (ssl.getSecurity().get_resuming()) { + memcpy(hello.random_,ssl.getSecurity().get_connection().server_random_, + RAN_LEN); + memcpy(hello.session_id_, ssl.getSecurity().get_resume().GetID(), + ID_LEN); + } + else { + ssl.getCrypto().get_random().Fill(hello.random_, RAN_LEN); + ssl.getCrypto().get_random().Fill(hello.session_id_, ID_LEN); + } + hello.id_len_ = ID_LEN; + ssl.set_sessionID(hello.session_id_); + + hello.cipher_suite_[0] = ssl.getSecurity().get_parms().suite_[0]; + hello.cipher_suite_[1] = ssl.getSecurity().get_parms().suite_[1]; + hello.compression_method_ = hello.compression_method_; + + hello.set_length(sizeof(ProtocolVersion) + RAN_LEN + ID_LEN + + sizeof(hello.id_len_) + SUITE_LEN + SIZEOF_ENUM); +} + + +// add handshake from buffer into md5 and sha hashes, use handshake header +void hashHandShake(SSL& ssl, const input_buffer& input, uint sz) +{ + const opaque* buffer = input.get_buffer() + input.get_current() - + HANDSHAKE_HEADER; + sz += HANDSHAKE_HEADER; + ssl.useHashes().use_MD5().update(buffer, sz); + ssl.useHashes().use_SHA().update(buffer, sz); +} + + +// locals +namespace { + +// Write a plaintext record to buffer +void buildOutput(output_buffer& buffer, const RecordLayerHeader& rlHdr, + const Message& msg) +{ + buffer.allocate(RECORD_HEADER + rlHdr.length_); + buffer << rlHdr << msg; +} + + +// Write a plaintext record to buffer +void buildOutput(output_buffer& buffer, const RecordLayerHeader& rlHdr, + const HandShakeHeader& hsHdr, const HandShakeBase& shake) +{ + buffer.allocate(RECORD_HEADER + rlHdr.length_); + buffer << rlHdr << hsHdr << shake; +} + + +// Build Record Layer header for Message without handshake header +void buildHeader(SSL& ssl, RecordLayerHeader& rlHeader, const Message& msg) +{ + ProtocolVersion pv = ssl.getSecurity().get_connection().version_; + rlHeader.type_ = msg.get_type(); + rlHeader.version_.major_ = pv.major_; + rlHeader.version_.minor_ = pv.minor_; + rlHeader.length_ = msg.get_length(); +} + + +// Build HandShake and RecordLayer Headers for handshake output +void buildHeaders(SSL& ssl, HandShakeHeader& hsHeader, + RecordLayerHeader& rlHeader, const HandShakeBase& shake) +{ + int sz = shake.get_length(); + + hsHeader.set_type(shake.get_type()); + hsHeader.set_length(sz); + + ProtocolVersion pv = ssl.getSecurity().get_connection().version_; + rlHeader.type_ = handshake; + rlHeader.version_.major_ = pv.major_; + rlHeader.version_.minor_ = pv.minor_; + rlHeader.length_ = sz + HANDSHAKE_HEADER; +} + + +// add handshake from buffer into md5 and sha hashes, exclude record header +void hashHandShake(SSL& ssl, const output_buffer& output, bool removeIV = false) +{ + uint sz = output.get_size() - RECORD_HEADER; + + const opaque* buffer = output.get_buffer() + RECORD_HEADER; + + if (removeIV) { // TLSv1_1 IV + uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); + sz -= blockSz; + buffer += blockSz; + } + + ssl.useHashes().use_MD5().update(buffer, sz); + ssl.useHashes().use_SHA().update(buffer, sz); +} + + +// calculate MD5 hash for finished +void buildMD5(SSL& ssl, Finished& fin, const opaque* sender) +{ + + opaque md5_result[MD5_LEN]; + opaque md5_inner[SIZEOF_SENDER + SECRET_LEN + PAD_MD5]; + opaque md5_outer[SECRET_LEN + PAD_MD5 + MD5_LEN]; + + const opaque* master_secret = + ssl.getSecurity().get_connection().master_secret_; + + // make md5 inner + memcpy(md5_inner, sender, SIZEOF_SENDER); + memcpy(&md5_inner[SIZEOF_SENDER], master_secret, SECRET_LEN); + memcpy(&md5_inner[SIZEOF_SENDER + SECRET_LEN], PAD1, PAD_MD5); + + ssl.useHashes().use_MD5().get_digest(md5_result, md5_inner, + sizeof(md5_inner)); + + // make md5 outer + memcpy(md5_outer, master_secret, SECRET_LEN); + memcpy(&md5_outer[SECRET_LEN], PAD2, PAD_MD5); + memcpy(&md5_outer[SECRET_LEN + PAD_MD5], md5_result, MD5_LEN); + + ssl.useHashes().use_MD5().get_digest(fin.set_md5(), md5_outer, + sizeof(md5_outer)); +} + + +// calculate SHA hash for finished +void buildSHA(SSL& ssl, Finished& fin, const opaque* sender) +{ + + opaque sha_result[SHA_LEN]; + opaque sha_inner[SIZEOF_SENDER + SECRET_LEN + PAD_SHA]; + opaque sha_outer[SECRET_LEN + PAD_SHA + SHA_LEN]; + + const opaque* master_secret = + ssl.getSecurity().get_connection().master_secret_; + + // make sha inner + memcpy(sha_inner, sender, SIZEOF_SENDER); + memcpy(&sha_inner[SIZEOF_SENDER], master_secret, SECRET_LEN); + memcpy(&sha_inner[SIZEOF_SENDER + SECRET_LEN], PAD1, PAD_SHA); + + ssl.useHashes().use_SHA().get_digest(sha_result, sha_inner, + sizeof(sha_inner)); + + // make sha outer + memcpy(sha_outer, master_secret, SECRET_LEN); + memcpy(&sha_outer[SECRET_LEN], PAD2, PAD_SHA); + memcpy(&sha_outer[SECRET_LEN + PAD_SHA], sha_result, SHA_LEN); + + ssl.useHashes().use_SHA().get_digest(fin.set_sha(), sha_outer, + sizeof(sha_outer)); +} + + +// sanity checks on encrypted message size +static int sanity_check_message(SSL& ssl, uint msgSz) +{ + uint minSz = 0; + + if (ssl.getSecurity().get_parms().cipher_type_ == block) { + uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); + if (msgSz % blockSz) + return -1; + + minSz = ssl.getSecurity().get_parms().hash_size_ + 1; // pad byte too + if (blockSz > minSz) + minSz = blockSz; + + if (ssl.isTLSv1_1()) + minSz += blockSz; // explicit IV + } + else { // stream + minSz = ssl.getSecurity().get_parms().hash_size_; + } + + if (msgSz < minSz) + return -1; + + return 0; +} + + +// decrypt input message in place, store size in case needed later +void decrypt_message(SSL& ssl, input_buffer& input, uint sz) +{ + input_buffer plain(sz); + opaque* cipher = input.get_buffer() + input.get_current(); + + if (sanity_check_message(ssl, sz) != 0) { + ssl.SetError(sanityCipher_error); + return; + } + + ssl.useCrypto().use_cipher().decrypt(plain.get_buffer(), cipher, sz); + memcpy(cipher, plain.get_buffer(), sz); + ssl.useSecurity().use_parms().encrypt_size_ = sz; + + if (ssl.isTLSv1_1()) // IV + input.set_current(input.get_current() + + ssl.getCrypto().get_cipher().get_blockSize()); +} + + +// output operator for input_buffer +output_buffer& operator<<(output_buffer& output, const input_buffer& input) +{ + output.write(input.get_buffer(), input.get_size()); + return output; +} + + +// write headers, handshake hash, mac, pad, and encrypt +void cipherFinished(SSL& ssl, Finished& fin, output_buffer& output) +{ + uint digestSz = ssl.getCrypto().get_digest().get_digestSize(); + uint finishedSz = ssl.isTLS() ? TLS_FINISHED_SZ : FINISHED_SZ; + uint sz = RECORD_HEADER + HANDSHAKE_HEADER + finishedSz + digestSz; + uint pad = 0; + uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); + + if (ssl.getSecurity().get_parms().cipher_type_ == block) { + if (ssl.isTLSv1_1()) + sz += blockSz; // IV + sz += 1; // pad byte + pad = (sz - RECORD_HEADER) % blockSz; + pad = blockSz - pad; + sz += pad; + } + + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + buildHeaders(ssl, hsHeader, rlHeader, fin); + rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac + // and pad, hanshake doesn't + input_buffer iv; + if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){ + iv.allocate(blockSz); + ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz); + iv.add_size(blockSz); + } + uint ivSz = iv.get_size(); + output.allocate(sz); + output << rlHeader << iv << hsHeader << fin; + + hashHandShake(ssl, output, ssl.isTLSv1_1() ? true : false); + opaque digest[SHA_LEN]; // max size + if (ssl.isTLS()) + TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz, + output.get_size() - RECORD_HEADER - ivSz, handshake); + else + hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, + output.get_size() - RECORD_HEADER, handshake); + output.write(digest, digestSz); + + if (ssl.getSecurity().get_parms().cipher_type_ == block) + for (uint i = 0; i <= pad; i++) output[AUTO] = pad; // pad byte gets + // pad value too + input_buffer cipher(rlHeader.length_); + ssl.useCrypto().use_cipher().encrypt(cipher.get_buffer(), + output.get_buffer() + RECORD_HEADER, output.get_size() - RECORD_HEADER); + output.set_current(RECORD_HEADER); + output.write(cipher.get_buffer(), cipher.get_capacity()); +} + + +// build an encrypted data or alert message for output +void buildMessage(SSL& ssl, output_buffer& output, const Message& msg) +{ + uint digestSz = ssl.getCrypto().get_digest().get_digestSize(); + uint sz = RECORD_HEADER + msg.get_length() + digestSz; + uint pad = 0; + uint blockSz = ssl.getCrypto().get_cipher().get_blockSize(); + + if (ssl.getSecurity().get_parms().cipher_type_ == block) { + if (ssl.isTLSv1_1()) // IV + sz += blockSz; + sz += 1; // pad byte + pad = (sz - RECORD_HEADER) % blockSz; + pad = blockSz - pad; + sz += pad; + } + + RecordLayerHeader rlHeader; + buildHeader(ssl, rlHeader, msg); + rlHeader.length_ = sz - RECORD_HEADER; // record header includes mac + // and pad, hanshake doesn't + input_buffer iv; + if (ssl.isTLSv1_1() && ssl.getSecurity().get_parms().cipher_type_== block){ + iv.allocate(blockSz); + ssl.getCrypto().get_random().Fill(iv.get_buffer(), blockSz); + iv.add_size(blockSz); + } + + uint ivSz = iv.get_size(); + output.allocate(sz); + output << rlHeader << iv << msg; + + opaque digest[SHA_LEN]; // max size + if (ssl.isTLS()) + TLS_hmac(ssl, digest, output.get_buffer() + RECORD_HEADER + ivSz, + output.get_size() - RECORD_HEADER - ivSz, msg.get_type()); + else + hmac(ssl, digest, output.get_buffer() + RECORD_HEADER, + output.get_size() - RECORD_HEADER, msg.get_type()); + output.write(digest, digestSz); + + if (ssl.getSecurity().get_parms().cipher_type_ == block) + for (uint i = 0; i <= pad; i++) output[AUTO] = pad; // pad byte gets + // pad value too + input_buffer cipher(rlHeader.length_); + ssl.useCrypto().use_cipher().encrypt(cipher.get_buffer(), + output.get_buffer() + RECORD_HEADER, output.get_size() - RECORD_HEADER); + output.set_current(RECORD_HEADER); + output.write(cipher.get_buffer(), cipher.get_capacity()); +} + + +// build alert message +void buildAlert(SSL& ssl, output_buffer& output, const Alert& alert) +{ + if (ssl.getSecurity().get_parms().pending_ == false) // encrypted + buildMessage(ssl, output, alert); + else { + RecordLayerHeader rlHeader; + buildHeader(ssl, rlHeader, alert); + buildOutput(output, rlHeader, alert); + } +} + + +// build TLS finished message +void buildFinishedTLS(SSL& ssl, Finished& fin, const opaque* sender) +{ + opaque handshake_hash[FINISHED_SZ]; + + ssl.useHashes().use_MD5().get_digest(handshake_hash); + ssl.useHashes().use_SHA().get_digest(&handshake_hash[MD5_LEN]); + + const opaque* side; + if ( strncmp((const char*)sender, (const char*)client, SIZEOF_SENDER) == 0) + side = tls_client; + else + side = tls_server; + + PRF(fin.set_md5(), TLS_FINISHED_SZ, + ssl.getSecurity().get_connection().master_secret_, SECRET_LEN, + side, FINISHED_LABEL_SZ, + handshake_hash, FINISHED_SZ); + + fin.set_length(TLS_FINISHED_SZ); // shorter length for TLS +} + + +// compute p_hash for MD5 or SHA-1 for TLSv1 PRF +void p_hash(output_buffer& result, const output_buffer& secret, + const output_buffer& seed, MACAlgorithm hash) +{ + uint len = hash == md5 ? MD5_LEN : SHA_LEN; + uint times = result.get_capacity() / len; + uint lastLen = result.get_capacity() % len; + opaque previous[SHA_LEN]; // max size + opaque current[SHA_LEN]; // max size + mySTL::auto_ptr hmac; + + if (lastLen) times += 1; + + if (hash == md5) + hmac.reset(NEW_YS HMAC_MD5(secret.get_buffer(), secret.get_size())); + else + hmac.reset(NEW_YS HMAC_SHA(secret.get_buffer(), secret.get_size())); + // A0 = seed + hmac->get_digest(previous, seed.get_buffer(), seed.get_size());// A1 + uint lastTime = times - 1; + + for (uint i = 0; i < times; i++) { + hmac->update(previous, len); + hmac->get_digest(current, seed.get_buffer(), seed.get_size()); + + if (lastLen && (i == lastTime)) + result.write(current, lastLen); + else { + result.write(current, len); + //memcpy(previous, current, len); + hmac->get_digest(previous, previous, len); + } + } +} + + +// calculate XOR for TLSv1 PRF +void get_xor(byte *digest, uint digLen, output_buffer& md5, + output_buffer& sha) +{ + for (uint i = 0; i < digLen; i++) + digest[i] = md5[AUTO] ^ sha[AUTO]; +} + + +// build MD5 part of certificate verify +void buildMD5_CertVerify(SSL& ssl, byte* digest) +{ + opaque md5_result[MD5_LEN]; + opaque md5_inner[SECRET_LEN + PAD_MD5]; + opaque md5_outer[SECRET_LEN + PAD_MD5 + MD5_LEN]; + + const opaque* master_secret = + ssl.getSecurity().get_connection().master_secret_; + + // make md5 inner + memcpy(md5_inner, master_secret, SECRET_LEN); + memcpy(&md5_inner[SECRET_LEN], PAD1, PAD_MD5); + + ssl.useHashes().use_MD5().get_digest(md5_result, md5_inner, + sizeof(md5_inner)); + + // make md5 outer + memcpy(md5_outer, master_secret, SECRET_LEN); + memcpy(&md5_outer[SECRET_LEN], PAD2, PAD_MD5); + memcpy(&md5_outer[SECRET_LEN + PAD_MD5], md5_result, MD5_LEN); + + ssl.useHashes().use_MD5().get_digest(digest, md5_outer, sizeof(md5_outer)); +} + + +// build SHA part of certificate verify +void buildSHA_CertVerify(SSL& ssl, byte* digest) +{ + opaque sha_result[SHA_LEN]; + opaque sha_inner[SECRET_LEN + PAD_SHA]; + opaque sha_outer[SECRET_LEN + PAD_SHA + SHA_LEN]; + + const opaque* master_secret = + ssl.getSecurity().get_connection().master_secret_; + + // make sha inner + memcpy(sha_inner, master_secret, SECRET_LEN); + memcpy(&sha_inner[SECRET_LEN], PAD1, PAD_SHA); + + ssl.useHashes().use_SHA().get_digest(sha_result, sha_inner, + sizeof(sha_inner)); + + // make sha outer + memcpy(sha_outer, master_secret, SECRET_LEN); + memcpy(&sha_outer[SECRET_LEN], PAD2, PAD_SHA); + memcpy(&sha_outer[SECRET_LEN + PAD_SHA], sha_result, SHA_LEN); + + ssl.useHashes().use_SHA().get_digest(digest, sha_outer, sizeof(sha_outer)); +} + + +} // namespace for locals + + +// some clients still send sslv2 client hello +void ProcessOldClientHello(input_buffer& input, SSL& ssl) +{ + if (input.get_error() || input.get_remaining() < 2) { + ssl.SetError(bad_input); + return; + } + byte b0 = input[AUTO]; + byte b1 = input[AUTO]; + + uint16 sz = ((b0 & 0x7f) << 8) | b1; + + if (sz > input.get_remaining()) { + ssl.SetError(bad_input); + return; + } + + // hashHandShake manually + const opaque* buffer = input.get_buffer() + input.get_current(); + ssl.useHashes().use_MD5().update(buffer, sz); + ssl.useHashes().use_SHA().update(buffer, sz); + + b1 = input[AUTO]; // does this value mean client_hello? + + ClientHello ch; + ch.client_version_.major_ = input[AUTO]; + ch.client_version_.minor_ = input[AUTO]; + + byte len[2]; + + len[0] = input[AUTO]; + len[1] = input[AUTO]; + ato16(len, ch.suite_len_); + + len[0] = input[AUTO]; + len[1] = input[AUTO]; + uint16 sessionLen; + ato16(len, sessionLen); + ch.id_len_ = sessionLen; + + len[0] = input[AUTO]; + len[1] = input[AUTO]; + uint16 randomLen; + ato16(len, randomLen); + + if (input.get_error() || ch.suite_len_ > MAX_SUITE_SZ || + ch.suite_len_ > input.get_remaining() || + sessionLen > ID_LEN || randomLen > RAN_LEN) { + ssl.SetError(bad_input); + return; + } + + int j = 0; + for (uint16 i = 0; i < ch.suite_len_; i += 3) { + byte first = input[AUTO]; + if (first) // sslv2 type + input.read(len, SUITE_LEN); // skip + else { + input.read(&ch.cipher_suites_[j], SUITE_LEN); + j += SUITE_LEN; + } + } + ch.suite_len_ = j; + + if (ch.id_len_) + input.read(ch.session_id_, ch.id_len_); // id_len_ from sessionLen + + if (randomLen < RAN_LEN) + memset(ch.random_, 0, RAN_LEN - randomLen); + input.read(&ch.random_[RAN_LEN - randomLen], randomLen); + + ch.Process(input, ssl); +} + + +// Build a finished message, see 7.6.9 +void buildFinished(SSL& ssl, Finished& fin, const opaque* sender) +{ + // store current states, building requires get_digest which resets state + MD5 md5(ssl.getHashes().get_MD5()); + SHA sha(ssl.getHashes().get_SHA()); + + if (ssl.isTLS()) + buildFinishedTLS(ssl, fin, sender); + else { + buildMD5(ssl, fin, sender); + buildSHA(ssl, fin, sender); + } + + // restore + ssl.useHashes().use_MD5() = md5; + ssl.useHashes().use_SHA() = sha; +} + + +/* compute SSLv3 HMAC into digest see + * buffer is of sz size and includes HandShake Header but not a Record Header + * verify means to check peers hmac +*/ +void hmac(SSL& ssl, byte* digest, const byte* buffer, uint sz, + ContentType content, bool verify) +{ + Digest& mac = ssl.useCrypto().use_digest(); + opaque inner[SHA_LEN + PAD_MD5 + SEQ_SZ + SIZEOF_ENUM + LENGTH_SZ]; + opaque outer[SHA_LEN + PAD_MD5 + SHA_LEN]; + opaque result[SHA_LEN]; // max possible sizes + uint digestSz = mac.get_digestSize(); // actual sizes + uint padSz = mac.get_padSize(); + uint innerSz = digestSz + padSz + SEQ_SZ + SIZEOF_ENUM + LENGTH_SZ; + uint outerSz = digestSz + padSz + digestSz; + + // data + const opaque* mac_secret = ssl.get_macSecret(verify); + opaque seq[SEQ_SZ] = { 0x00, 0x00, 0x00, 0x00 }; + opaque length[LENGTH_SZ]; + c16toa(sz, length); + c32toa(ssl.get_SEQIncrement(verify), &seq[sizeof(uint32)]); + + // make inner + memcpy(inner, mac_secret, digestSz); + memcpy(&inner[digestSz], PAD1, padSz); + memcpy(&inner[digestSz + padSz], seq, SEQ_SZ); + inner[digestSz + padSz + SEQ_SZ] = content; + memcpy(&inner[digestSz + padSz + SEQ_SZ + SIZEOF_ENUM], length, LENGTH_SZ); + + mac.update(inner, innerSz); + mac.get_digest(result, buffer, sz); // append content buffer + + // make outer + memcpy(outer, mac_secret, digestSz); + memcpy(&outer[digestSz], PAD2, padSz); + memcpy(&outer[digestSz + padSz], result, digestSz); + + mac.get_digest(digest, outer, outerSz); +} + + +// TLS type HAMC +void TLS_hmac(SSL& ssl, byte* digest, const byte* buffer, uint sz, + ContentType content, bool verify) +{ + mySTL::auto_ptr hmac; + opaque seq[SEQ_SZ] = { 0x00, 0x00, 0x00, 0x00 }; + opaque length[LENGTH_SZ]; + opaque inner[SIZEOF_ENUM + VERSION_SZ + LENGTH_SZ]; // type + version + len + + c16toa(sz, length); + c32toa(ssl.get_SEQIncrement(verify), &seq[sizeof(uint32)]); + + MACAlgorithm algo = ssl.getSecurity().get_parms().mac_algorithm_; + + if (algo == sha) + hmac.reset(NEW_YS HMAC_SHA(ssl.get_macSecret(verify), SHA_LEN)); + else if (algo == rmd) + hmac.reset(NEW_YS HMAC_RMD(ssl.get_macSecret(verify), RMD_LEN)); + else + hmac.reset(NEW_YS HMAC_MD5(ssl.get_macSecret(verify), MD5_LEN)); + + hmac->update(seq, SEQ_SZ); // seq_num + inner[0] = content; // type + inner[SIZEOF_ENUM] = ssl.getSecurity().get_connection().version_.major_; + inner[SIZEOF_ENUM + SIZEOF_ENUM] = + ssl.getSecurity().get_connection().version_.minor_; // version + memcpy(&inner[SIZEOF_ENUM + VERSION_SZ], length, LENGTH_SZ); // length + hmac->update(inner, sizeof(inner)); + hmac->get_digest(digest, buffer, sz); // content +} + + +// compute TLSv1 PRF (pseudo random function using HMAC) +void PRF(byte* digest, uint digLen, const byte* secret, uint secLen, + const byte* label, uint labLen, const byte* seed, uint seedLen) +{ + uint half = (secLen + 1) / 2; + + output_buffer md5_half(half); + output_buffer sha_half(half); + output_buffer labelSeed(labLen + seedLen); + + md5_half.write(secret, half); + sha_half.write(secret + half - secLen % 2, half); + labelSeed.write(label, labLen); + labelSeed.write(seed, seedLen); + + output_buffer md5_result(digLen); + output_buffer sha_result(digLen); + + p_hash(md5_result, md5_half, labelSeed, md5); + p_hash(sha_result, sha_half, labelSeed, sha); + + md5_result.set_current(0); + sha_result.set_current(0); + get_xor(digest, digLen, md5_result, sha_result); +} + + +// build certificate hashes +void build_certHashes(SSL& ssl, Hashes& hashes) +{ + // store current states, building requires get_digest which resets state + MD5 md5(ssl.getHashes().get_MD5()); + SHA sha(ssl.getHashes().get_SHA()); + + if (ssl.isTLS()) { + ssl.useHashes().use_MD5().get_digest(hashes.md5_); + ssl.useHashes().use_SHA().get_digest(hashes.sha_); + } + else { + buildMD5_CertVerify(ssl, hashes.md5_); + buildSHA_CertVerify(ssl, hashes.sha_); + } + + // restore + ssl.useHashes().use_MD5() = md5; + ssl.useHashes().use_SHA() = sha; +} + + + +// do process input requests, return 0 is done, 1 is call again to complete +int DoProcessReply(SSL& ssl) +{ + uint ready = ssl.getSocket().get_ready(); + if (!ready) + ready= 64; + + // add buffered data if its there + input_buffer* buffered = ssl.useBuffers().TakeRawInput(); + uint buffSz = buffered ? buffered->get_size() : 0; + input_buffer buffer(buffSz + ready); + if (buffSz) { + buffer.assign(buffered->get_buffer(), buffSz); + ysDelete(buffered); + buffered = 0; + } + + // add new data + uint read = ssl.useSocket().receive(buffer.get_buffer() + buffSz, ready); + if (read == static_cast(-1)) { + ssl.SetError(receive_error); + return 0; + } else if (read == 0) + return 1; + + buffer.add_size(read); + uint offset = 0; + const MessageFactory& mf = ssl.getFactory().getMessage(); + + // old style sslv2 client hello? + if (ssl.getSecurity().get_parms().entity_ == server_end && + ssl.getStates().getServer() == clientNull) + if (buffer.peek() != handshake) { + ProcessOldClientHello(buffer, ssl); + if (ssl.GetError()) + return 0; + } + + while(!buffer.eof()) { + // each record + RecordLayerHeader hdr; + bool needHdr = false; + + if (static_cast(RECORD_HEADER) > buffer.get_remaining()) + needHdr = true; + else { + buffer >> hdr; + ssl.verifyState(hdr); + } + + if (ssl.GetError()) + return 0; + + // make sure we have enough input in buffer to process this record + if (needHdr || hdr.length_ > buffer.get_remaining()) { + // put header in front for next time processing + uint extra = needHdr ? 0 : RECORD_HEADER; + uint sz = buffer.get_remaining() + extra; + ssl.useBuffers().SetRawInput(NEW_YS input_buffer(sz, + buffer.get_buffer() + buffer.get_current() - extra, sz)); + return 1; + } + + while (buffer.get_current() < hdr.length_ + RECORD_HEADER + offset) { + // each message in record, can be more than 1 if not encrypted + if (ssl.GetError()) + return 0; + + if (ssl.getSecurity().get_parms().pending_ == false) { // cipher on + // sanity check for malicious/corrupted/illegal input + if (buffer.get_remaining() < hdr.length_) { + ssl.SetError(bad_input); + return 0; + } + decrypt_message(ssl, buffer, hdr.length_); + if (ssl.GetError()) + return 0; + } + + mySTL::auto_ptr msg(mf.CreateObject(hdr.type_)); + if (!msg.get()) { + ssl.SetError(factory_error); + return 0; + } + buffer >> *msg; + msg->Process(buffer, ssl); + if (ssl.GetError()) + return 0; + } + offset += hdr.length_ + RECORD_HEADER; + } + return 0; +} + + +// process input requests +void processReply(SSL& ssl) +{ + if (ssl.GetError()) return; + + if (DoProcessReply(ssl)) { + // didn't complete process + if (!ssl.getSocket().IsNonBlocking()) { + // keep trying now, blocking ok + while (!ssl.GetError()) + if (DoProcessReply(ssl) == 0) break; + } + else + // user will have try again later, non blocking + ssl.SetError(YasslError(SSL_ERROR_WANT_READ)); + } +} + + +// send client_hello, no buffering +void sendClientHello(SSL& ssl) +{ + ssl.verifyState(serverNull); + if (ssl.GetError()) return; + + ClientHello ch(ssl.getSecurity().get_connection().version_, + ssl.getSecurity().get_connection().compression_); + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + output_buffer out; + + buildClientHello(ssl, ch); + ssl.set_random(ch.get_random(), client_end); + buildHeaders(ssl, hsHeader, rlHeader, ch); + buildOutput(out, rlHeader, hsHeader, ch); + hashHandShake(ssl, out); + + ssl.Send(out.get_buffer(), out.get_size()); +} + + +// send client key exchange +void sendClientKeyExchange(SSL& ssl, BufferOutput buffer) +{ + ssl.verifyState(serverHelloDoneComplete); + if (ssl.GetError()) return; + + ClientKeyExchange ck(ssl); + ck.build(ssl); + ssl.makeMasterSecret(); + + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + mySTL::auto_ptr out(NEW_YS output_buffer); + buildHeaders(ssl, hsHeader, rlHeader, ck); + buildOutput(*out.get(), rlHeader, hsHeader, ck); + hashHandShake(ssl, *out.get()); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send server key exchange +void sendServerKeyExchange(SSL& ssl, BufferOutput buffer) +{ + if (ssl.GetError()) return; + ServerKeyExchange sk(ssl); + sk.build(ssl); + if (ssl.GetError()) return; + + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + mySTL::auto_ptr out(NEW_YS output_buffer); + buildHeaders(ssl, hsHeader, rlHeader, sk); + buildOutput(*out.get(), rlHeader, hsHeader, sk); + hashHandShake(ssl, *out.get()); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send change cipher +void sendChangeCipher(SSL& ssl, BufferOutput buffer) +{ + if (ssl.getSecurity().get_parms().entity_ == server_end) { + if (ssl.getSecurity().get_resuming()) + ssl.verifyState(clientKeyExchangeComplete); + else + ssl.verifyState(clientFinishedComplete); + } + if (ssl.GetError()) return; + + ChangeCipherSpec ccs; + RecordLayerHeader rlHeader; + buildHeader(ssl, rlHeader, ccs); + mySTL::auto_ptr out(NEW_YS output_buffer); + buildOutput(*out.get(), rlHeader, ccs); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send finished +void sendFinished(SSL& ssl, ConnectionEnd side, BufferOutput buffer) +{ + if (ssl.GetError()) return; + + Finished fin; + buildFinished(ssl, fin, side == client_end ? client : server); + mySTL::auto_ptr out(NEW_YS output_buffer); + cipherFinished(ssl, fin, *out.get()); // hashes handshake + + if (ssl.getSecurity().get_resuming()) { + if (side == server_end) + buildFinished(ssl, ssl.useHashes().use_verify(), client); // client + } + else { + if (!ssl.getSecurity().GetContext()->GetSessionCacheOff()) + GetSessions().add(ssl); // store session + if (side == client_end) + buildFinished(ssl, ssl.useHashes().use_verify(), server); // server + } + ssl.useSecurity().use_connection().CleanMaster(); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send data +int sendData(SSL& ssl, const void* buffer, int sz) +{ + int sent = 0; + + if (ssl.GetError() == YasslError(SSL_ERROR_WANT_READ)) + ssl.SetError(no_error); + + if (ssl.GetError() == YasslError(SSL_ERROR_WANT_WRITE)) { + ssl.SetError(no_error); + ssl.SendWriteBuffered(); + if (!ssl.GetError()) { + // advance sent to prvevious sent + plain size just sent + sent = ssl.useBuffers().prevSent + ssl.useBuffers().plainSz; + } + } + + ssl.verfiyHandShakeComplete(); + if (ssl.GetError()) return -1; + + for (;;) { + int len = min(sz - sent, MAX_RECORD_SIZE); + output_buffer out; + input_buffer tmp; + + Data data; + + if (sent == sz) break; + + if (ssl.CompressionOn()) { + if (Compress(static_cast(buffer) + sent, len, + tmp) == -1) { + ssl.SetError(compress_error); + return -1; + } + data.SetData(tmp.get_size(), tmp.get_buffer()); + } + else + data.SetData(len, static_cast(buffer) + sent); + + buildMessage(ssl, out, data); + ssl.Send(out.get_buffer(), out.get_size()); + + if (ssl.GetError()) { + if (ssl.GetError() == YasslError(SSL_ERROR_WANT_WRITE)) { + ssl.useBuffers().plainSz = len; + ssl.useBuffers().prevSent = sent; + } + return -1; + } + sent += len; + } + ssl.useLog().ShowData(sent, true); + return sent; +} + + +// send alert +int sendAlert(SSL& ssl, const Alert& alert) +{ + output_buffer out; + buildAlert(ssl, out, alert); + ssl.Send(out.get_buffer(), out.get_size()); + + return alert.get_length(); +} + + +// process input data +int receiveData(SSL& ssl, Data& data, bool peek) +{ + if (ssl.GetError() == YasslError(SSL_ERROR_WANT_READ)) + ssl.SetError(no_error); + + ssl.verfiyHandShakeComplete(); + if (ssl.GetError()) return -1; + + if (!ssl.HasData()) + processReply(ssl); + + if (peek) + ssl.PeekData(data); + else + ssl.fillData(data); + + ssl.useLog().ShowData(data.get_length()); + if (ssl.GetError()) return -1; + + if (data.get_length() == 0 && ssl.getSocket().WouldBlock()) { + ssl.SetError(YasslError(SSL_ERROR_WANT_READ)); + return SSL_WOULD_BLOCK; + } + return data.get_length(); +} + + +// send server hello +void sendServerHello(SSL& ssl, BufferOutput buffer) +{ + if (ssl.getSecurity().get_resuming()) + ssl.verifyState(clientKeyExchangeComplete); + else + ssl.verifyState(clientHelloComplete); + if (ssl.GetError()) return; + + ServerHello sh(ssl.getSecurity().get_connection().version_, + ssl.getSecurity().get_connection().compression_); + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + mySTL::auto_ptr out(NEW_YS output_buffer); + + buildServerHello(ssl, sh); + ssl.set_random(sh.get_random(), server_end); + buildHeaders(ssl, hsHeader, rlHeader, sh); + buildOutput(*out.get(), rlHeader, hsHeader, sh); + hashHandShake(ssl, *out.get()); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send server hello done +void sendServerHelloDone(SSL& ssl, BufferOutput buffer) +{ + if (ssl.GetError()) return; + + ServerHelloDone shd; + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + mySTL::auto_ptr out(NEW_YS output_buffer); + + buildHeaders(ssl, hsHeader, rlHeader, shd); + buildOutput(*out.get(), rlHeader, hsHeader, shd); + hashHandShake(ssl, *out.get()); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send certificate +void sendCertificate(SSL& ssl, BufferOutput buffer) +{ + if (ssl.GetError()) return; + + Certificate cert(ssl.getCrypto().get_certManager().get_cert()); + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + mySTL::auto_ptr out(NEW_YS output_buffer); + + buildHeaders(ssl, hsHeader, rlHeader, cert); + buildOutput(*out.get(), rlHeader, hsHeader, cert); + hashHandShake(ssl, *out.get()); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send certificate request +void sendCertificateRequest(SSL& ssl, BufferOutput buffer) +{ + if (ssl.GetError()) return; + + CertificateRequest request; + request.Build(); + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + mySTL::auto_ptr out(NEW_YS output_buffer); + + buildHeaders(ssl, hsHeader, rlHeader, request); + buildOutput(*out.get(), rlHeader, hsHeader, request); + hashHandShake(ssl, *out.get()); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +// send certificate verify +void sendCertificateVerify(SSL& ssl, BufferOutput buffer) +{ + if (ssl.GetError()) return; + + if(ssl.getCrypto().get_certManager().sendBlankCert()) return; + + CertificateVerify verify; + verify.Build(ssl); + if (ssl.GetError()) return; + + RecordLayerHeader rlHeader; + HandShakeHeader hsHeader; + mySTL::auto_ptr out(NEW_YS output_buffer); + + buildHeaders(ssl, hsHeader, rlHeader, verify); + buildOutput(*out.get(), rlHeader, hsHeader, verify); + hashHandShake(ssl, *out.get()); + + if (buffer == buffered) + ssl.addBuffer(out.release()); + else + ssl.Send(out->get_buffer(), out->get_size()); +} + + +} // namespace -- cgit v1.1