610 lines
16 KiB
Go
610 lines
16 KiB
Go
|
// Licensed to Elasticsearch B.V. under one or more contributor
|
||
|
// license agreements. See the NOTICE file distributed with
|
||
|
// this work for additional information regarding copyright
|
||
|
// ownership. Elasticsearch B.V. licenses this file to you under
|
||
|
// the Apache License, Version 2.0 (the "License"); you may
|
||
|
// not use this file except in compliance with the License.
|
||
|
// You may obtain a copy of the License at
|
||
|
//
|
||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||
|
//
|
||
|
// Unless required by applicable law or agreed to in writing,
|
||
|
// software distributed under the License is distributed on an
|
||
|
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||
|
// KIND, either express or implied. See the License for the
|
||
|
// specific language governing permissions and limitations
|
||
|
// under the License.
|
||
|
|
||
|
package tls
|
||
|
|
||
|
import (
|
||
|
"crypto/dsa"
|
||
|
"crypto/ecdsa"
|
||
|
"crypto/rsa"
|
||
|
"crypto/x509"
|
||
|
"crypto/x509/pkix"
|
||
|
"encoding/hex"
|
||
|
"fmt"
|
||
|
"strings"
|
||
|
|
||
|
"github.com/elastic/beats/libbeat/common"
|
||
|
"github.com/elastic/beats/libbeat/common/streambuf"
|
||
|
"github.com/elastic/beats/libbeat/common/x509util"
|
||
|
"github.com/elastic/beats/libbeat/logp"
|
||
|
)
|
||
|
|
||
|
type direction uint8
|
||
|
|
||
|
const (
|
||
|
dirUnknown direction = iota
|
||
|
dirClient
|
||
|
dirServer
|
||
|
)
|
||
|
|
||
|
const (
|
||
|
maxTLSRecordLength = (1 << 14) + 2048
|
||
|
// For safety, ignore handshake messages longer than 64k (same as stdlib)
|
||
|
maxHandshakeSize = 1 << 16
|
||
|
recordHeaderSize = 5
|
||
|
handshakeHeaderSize = 4
|
||
|
helloHeaderLength = 7
|
||
|
randomDataLength = 28
|
||
|
)
|
||
|
|
||
|
type recordType uint8
|
||
|
|
||
|
const (
|
||
|
recordTypeChangeCipherSpec recordType = 20
|
||
|
recordTypeAlert = 21
|
||
|
recordTypeHandshake = 22
|
||
|
recordTypeApplicationData = 23
|
||
|
)
|
||
|
|
||
|
type handshakeType uint8
|
||
|
|
||
|
const (
|
||
|
helloRequest handshakeType = 0
|
||
|
clientHello = 1
|
||
|
serverHello = 2
|
||
|
certificate = 11
|
||
|
serverKeyExchange = 12
|
||
|
certificateRequest = 13
|
||
|
clientKeyExchange = 16
|
||
|
)
|
||
|
|
||
|
type parserResult int8
|
||
|
|
||
|
const (
|
||
|
resultOK parserResult = iota
|
||
|
resultFailed
|
||
|
resultMore
|
||
|
resultEncrypted
|
||
|
)
|
||
|
|
||
|
type tlsTicket struct {
|
||
|
present bool
|
||
|
value string
|
||
|
}
|
||
|
|
||
|
type parser struct {
|
||
|
// Buffer to accumulate records until a full handshake message
|
||
|
// is received
|
||
|
handshakeBuf streambuf.Buffer
|
||
|
|
||
|
direction direction
|
||
|
alerts []alert
|
||
|
certificates []*x509.Certificate
|
||
|
hello *helloMessage
|
||
|
|
||
|
// If this end of the connection (server) asked the other end (client)
|
||
|
// for a certificate
|
||
|
certRequested bool
|
||
|
|
||
|
// If a key-exchange message has been sent. Used to detect session resumption
|
||
|
keyExchanged bool
|
||
|
}
|
||
|
|
||
|
type tlsVersion struct {
|
||
|
major, minor uint8
|
||
|
}
|
||
|
|
||
|
type recordHeader struct {
|
||
|
recordType recordType
|
||
|
version tlsVersion
|
||
|
length uint16
|
||
|
}
|
||
|
|
||
|
type handshakeHeader struct {
|
||
|
handshakeType handshakeType
|
||
|
length int
|
||
|
}
|
||
|
|
||
|
type helloMessage struct {
|
||
|
version tlsVersion
|
||
|
timestamp uint32
|
||
|
sessionID string
|
||
|
ticket tlsTicket
|
||
|
supported struct {
|
||
|
cipherSuites []cipherSuite
|
||
|
compression []compressionMethod
|
||
|
}
|
||
|
selected struct {
|
||
|
cipherSuite cipherSuite
|
||
|
compression compressionMethod
|
||
|
}
|
||
|
extensions Extensions
|
||
|
}
|
||
|
|
||
|
func readRecordHeader(buf *streambuf.Buffer) (*recordHeader, error) {
|
||
|
var (
|
||
|
header recordHeader
|
||
|
err error
|
||
|
record uint8
|
||
|
)
|
||
|
if record, err = buf.ReadNetUint8At(0); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
header.recordType = recordType(record)
|
||
|
if header.version.major, err = buf.ReadNetUint8At(1); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if header.version.minor, err = buf.ReadNetUint8At(2); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if header.length, err = buf.ReadNetUint16At(3); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &header, nil
|
||
|
}
|
||
|
|
||
|
func readHandshakeHeader(buf *streambuf.Buffer) (*handshakeHeader, error) {
|
||
|
var err error
|
||
|
var len8, typ uint8
|
||
|
var len16 uint16
|
||
|
if typ, err = buf.ReadNetUint8At(0); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if len8, err = buf.ReadNetUint8At(1); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
if len16, err = buf.ReadNetUint16At(2); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return &handshakeHeader{handshakeType(typ),
|
||
|
int(len16) | (int(len8) << 16)}, nil
|
||
|
}
|
||
|
|
||
|
func (header *recordHeader) String() string {
|
||
|
return fmt.Sprintf("recordHeader type[%v] version[%v] length[%d]",
|
||
|
header.recordType, header.version, header.length)
|
||
|
}
|
||
|
|
||
|
func (header *recordHeader) isValid() bool {
|
||
|
return header.version.major == 3 && header.length <= maxTLSRecordLength
|
||
|
}
|
||
|
|
||
|
func (hello helloMessage) toMap() common.MapStr {
|
||
|
m := common.MapStr{
|
||
|
"version": fmt.Sprintf("%d.%d", hello.version.major, hello.version.minor),
|
||
|
}
|
||
|
if len(hello.sessionID) != 0 {
|
||
|
m["session_id"] = hello.sessionID
|
||
|
}
|
||
|
|
||
|
if len(hello.supported.cipherSuites) > 0 || len(hello.supported.compression) > 0 {
|
||
|
ciphers := make([]string, len(hello.supported.cipherSuites))
|
||
|
for idx, code := range hello.supported.cipherSuites {
|
||
|
ciphers[idx] = code.String()
|
||
|
}
|
||
|
m["supported_ciphers"] = ciphers
|
||
|
|
||
|
comp := make([]string, len(hello.supported.compression))
|
||
|
for idx, code := range hello.supported.compression {
|
||
|
comp[idx] = code.String()
|
||
|
}
|
||
|
m["supported_compression_methods"] = comp
|
||
|
} else {
|
||
|
m["selected_cipher"] = hello.selected.cipherSuite.String()
|
||
|
m["selected_compression_method"] = hello.selected.compression.String()
|
||
|
}
|
||
|
|
||
|
if hello.extensions.Parsed != nil {
|
||
|
m["extensions"] = hello.extensions.Parsed
|
||
|
}
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
func (parser *parser) parse(buf *streambuf.Buffer) parserResult {
|
||
|
|
||
|
for buf.Avail(recordHeaderSize) {
|
||
|
|
||
|
header, err := readRecordHeader(buf)
|
||
|
if err != nil || !header.isValid() {
|
||
|
if err != nil {
|
||
|
logp.Warn("internal buffer error: %v", err)
|
||
|
}
|
||
|
return resultFailed
|
||
|
}
|
||
|
|
||
|
limit := recordHeaderSize + int(header.length)
|
||
|
if !buf.Avail(limit) {
|
||
|
// wait for complete record
|
||
|
return resultMore
|
||
|
}
|
||
|
|
||
|
switch header.recordType {
|
||
|
case recordTypeChangeCipherSpec: // single message of size 1 (byte 1)
|
||
|
if isDebug {
|
||
|
debugf("handshake completed")
|
||
|
}
|
||
|
// discard remaining data for this stream (encrypted)
|
||
|
buf.Advance(buf.Len())
|
||
|
return resultEncrypted
|
||
|
|
||
|
case recordTypeHandshake:
|
||
|
if isDebug {
|
||
|
debugf("got handshake record of size %d", header.length)
|
||
|
}
|
||
|
if err = parser.bufferHandshake(buf, int(header.length)); err != nil {
|
||
|
logp.Warn("Error parsing handshake message: %v", err)
|
||
|
return resultFailed
|
||
|
}
|
||
|
|
||
|
case recordTypeAlert:
|
||
|
if err = parser.parseAlert(newBufferView(buf, recordHeaderSize, int(header.length))); err != nil {
|
||
|
logp.Warn("Error parsing alert message: %v", err)
|
||
|
return resultFailed
|
||
|
}
|
||
|
|
||
|
case recordTypeApplicationData:
|
||
|
// TODO: Request / Response analytics
|
||
|
if isDebug {
|
||
|
debugf("ignoring application data length %d", header.length)
|
||
|
}
|
||
|
|
||
|
default:
|
||
|
if isDebug {
|
||
|
debugf("ignoring record type %d length %d", header.recordType, header.length)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
buf.Advance(limit)
|
||
|
}
|
||
|
|
||
|
if buf.Len() == 0 {
|
||
|
return resultOK
|
||
|
}
|
||
|
return resultMore
|
||
|
}
|
||
|
|
||
|
func (parser *parser) bufferHandshake(buf *streambuf.Buffer, length int) error {
|
||
|
// TODO: parse in-place if message in received buffer is complete
|
||
|
if err := parser.handshakeBuf.Append(buf.Bytes()[recordHeaderSize : recordHeaderSize+length]); err != nil {
|
||
|
logp.Warn("failed appending to buffer: %v", err)
|
||
|
// Discard buffer
|
||
|
parser.handshakeBuf.Init(nil, false)
|
||
|
return err
|
||
|
}
|
||
|
for parser.handshakeBuf.Avail(handshakeHeaderSize) {
|
||
|
// type
|
||
|
header, err := readHandshakeHeader(&parser.handshakeBuf)
|
||
|
if err != nil {
|
||
|
logp.Warn("read failed: %v", err)
|
||
|
parser.handshakeBuf.Init(nil, false)
|
||
|
return err
|
||
|
}
|
||
|
if header.length > maxHandshakeSize {
|
||
|
// Discard buffer
|
||
|
parser.handshakeBuf.Init(nil, false)
|
||
|
return fmt.Errorf("message too large (%d bytes)", header.length)
|
||
|
}
|
||
|
limit := handshakeHeaderSize + header.length
|
||
|
if limit > parser.handshakeBuf.Len() {
|
||
|
break
|
||
|
}
|
||
|
if !parser.parseHandshake(header.handshakeType,
|
||
|
bufferView{&parser.handshakeBuf, handshakeHeaderSize, limit}) {
|
||
|
parser.handshakeBuf.Advance(limit)
|
||
|
return fmt.Errorf("bad handshake %+v", header)
|
||
|
}
|
||
|
parser.handshakeBuf.Advance(limit)
|
||
|
}
|
||
|
if parser.handshakeBuf.Len() == 0 {
|
||
|
parser.handshakeBuf.Reset()
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (parser *parser) setDirection(dir direction) {
|
||
|
if parser.direction != dir && parser.direction != dirUnknown {
|
||
|
logp.Warn("client/server identification mismatch")
|
||
|
}
|
||
|
parser.direction = dir
|
||
|
}
|
||
|
|
||
|
func (parser *parser) parseHandshake(handshakeType handshakeType, buffer bufferView) bool {
|
||
|
if isDebug {
|
||
|
debugf("got handshake message %v [%d]", handshakeType, buffer.length())
|
||
|
}
|
||
|
switch handshakeType {
|
||
|
case helloRequest:
|
||
|
parser.setDirection(dirServer)
|
||
|
return parseHelloRequest(buffer)
|
||
|
|
||
|
case clientHello:
|
||
|
parser.setDirection(dirClient)
|
||
|
if parser.hello = parseClientHello(buffer); parser.hello == nil {
|
||
|
return false
|
||
|
}
|
||
|
return true
|
||
|
|
||
|
case serverHello:
|
||
|
parser.setDirection(dirServer)
|
||
|
if parser.hello = parseServerHello(buffer); parser.hello == nil {
|
||
|
return false
|
||
|
}
|
||
|
return true
|
||
|
|
||
|
case certificate:
|
||
|
certs := parseCertificates(buffer)
|
||
|
parser.certificates = append(parser.certificates, certs...)
|
||
|
|
||
|
case certificateRequest:
|
||
|
parser.setDirection(dirServer)
|
||
|
parser.certRequested = true
|
||
|
|
||
|
case clientKeyExchange:
|
||
|
parser.setDirection(dirClient)
|
||
|
parser.keyExchanged = true
|
||
|
|
||
|
case serverKeyExchange:
|
||
|
parser.setDirection(dirServer)
|
||
|
parser.keyExchanged = true
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func parseHelloRequest(buffer bufferView) bool {
|
||
|
if buffer.length() != 0 {
|
||
|
logp.Warn("non-empty hello request")
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func parseCommonHello(buffer bufferView, dest *helloMessage) (int, bool) {
|
||
|
var sessionIDLength uint8
|
||
|
|
||
|
if !buffer.read8(0, &dest.version.major) ||
|
||
|
!buffer.read8(1, &dest.version.minor) ||
|
||
|
!buffer.read32Net(2, &dest.timestamp) ||
|
||
|
// ignore 28 random bytes
|
||
|
!buffer.read8(6+randomDataLength, &sessionIDLength) {
|
||
|
logp.Warn("failed reading hello message")
|
||
|
return 0, false
|
||
|
}
|
||
|
|
||
|
if dest.version.major != 3 {
|
||
|
logp.Warn("Not a TLS hello (reported version %d.%d)",
|
||
|
dest.version.major, dest.version.minor)
|
||
|
return 0, false
|
||
|
}
|
||
|
if sessionIDLength > 32 {
|
||
|
logp.Warn("Not a TLS hello (session id length %d out of bounds)", sessionIDLength)
|
||
|
return 0, false
|
||
|
}
|
||
|
|
||
|
if bytes := buffer.readBytes(7+randomDataLength, int(sessionIDLength)); len(bytes) == int(sessionIDLength) {
|
||
|
dest.sessionID = hex.EncodeToString(bytes)
|
||
|
} else {
|
||
|
logp.Warn("Not a TLS hello (failed reading session ID)")
|
||
|
return 0, false
|
||
|
}
|
||
|
return helloHeaderLength + randomDataLength + int(sessionIDLength), true
|
||
|
}
|
||
|
|
||
|
func (hello *helloMessage) parseExtensions(buffer bufferView) {
|
||
|
hello.extensions = ParseExtensions(buffer)
|
||
|
if ticket, err := hello.extensions.Parsed.GetValue("session_ticket"); err == nil {
|
||
|
if value, ok := ticket.(string); ok {
|
||
|
hello.ticket.present = true
|
||
|
hello.ticket.value = value
|
||
|
} else {
|
||
|
logp.Err("tls ticket data type error")
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func parseClientHello(buffer bufferView) *helloMessage {
|
||
|
var result helloMessage
|
||
|
pos, ok := parseCommonHello(buffer, &result)
|
||
|
if !ok {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
var cipherSuitesLength uint16
|
||
|
if !buffer.read16Net(pos, &cipherSuitesLength) {
|
||
|
logp.Warn("failed parsing client hello cipher suite length")
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
for base := pos + 2; base < pos+2+int(cipherSuitesLength); base += 2 {
|
||
|
var cipher uint16
|
||
|
if !buffer.read16Net(base, &cipher) {
|
||
|
logp.Warn("failed parsing client hello cipher suite")
|
||
|
return nil
|
||
|
}
|
||
|
if !isGreaseValue(cipher) {
|
||
|
result.supported.cipherSuites = append(result.supported.cipherSuites, cipherSuite(cipher))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
pos += 2 + int(cipherSuitesLength)
|
||
|
var compMethodsLength uint8
|
||
|
if !buffer.read8(pos, &compMethodsLength) {
|
||
|
logp.Warn("failed parsing client hello compression methods length")
|
||
|
return nil
|
||
|
}
|
||
|
limit := pos + 1 + int(compMethodsLength)
|
||
|
for base := pos + 1; base < limit; base++ {
|
||
|
var method uint8
|
||
|
if !buffer.read8(base, &method) {
|
||
|
logp.Warn("failed parsing client hello compression methods")
|
||
|
return nil
|
||
|
}
|
||
|
result.supported.compression = append(result.supported.compression, compressionMethod(method))
|
||
|
}
|
||
|
|
||
|
result.parseExtensions(buffer.subview(limit, buffer.limit-limit))
|
||
|
return &result
|
||
|
}
|
||
|
|
||
|
func parseServerHello(buffer bufferView) *helloMessage {
|
||
|
var result helloMessage
|
||
|
pos, ok := parseCommonHello(buffer, &result)
|
||
|
if !ok {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
var cipher uint16
|
||
|
var compression uint8
|
||
|
if !buffer.read16Net(pos, &cipher) ||
|
||
|
!buffer.read8(pos+2, &compression) {
|
||
|
return nil
|
||
|
}
|
||
|
result.selected.cipherSuite = cipherSuite(cipher)
|
||
|
result.selected.compression = compressionMethod(compression)
|
||
|
result.parseExtensions(buffer.subview(pos+3, buffer.limit-pos-3))
|
||
|
return &result
|
||
|
}
|
||
|
|
||
|
func parseCertificates(buffer bufferView) []*x509.Certificate {
|
||
|
var totalLen uint32
|
||
|
if !buffer.read24Net(0, &totalLen) || int(totalLen+3) != buffer.length() {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
var certs []*x509.Certificate
|
||
|
|
||
|
for pos, limit := 3, int(totalLen)+3; pos+3 <= limit; {
|
||
|
var certLen uint32
|
||
|
if !buffer.read24Net(pos, &certLen) || pos+3+int(certLen) > limit {
|
||
|
return nil
|
||
|
}
|
||
|
cert := buffer.readBytes(pos+3, int(certLen))
|
||
|
if len(cert) != int(certLen) {
|
||
|
return nil
|
||
|
}
|
||
|
parsed, err := x509.ParseCertificate(cert)
|
||
|
if err != nil {
|
||
|
return nil
|
||
|
}
|
||
|
certs = append(certs, parsed)
|
||
|
pos += 3 + int(certLen)
|
||
|
}
|
||
|
return certs
|
||
|
}
|
||
|
|
||
|
func (version tlsVersion) String() string {
|
||
|
if version.major == 3 {
|
||
|
if version.minor > 0 {
|
||
|
return fmt.Sprintf("TLS 1.%d", version.minor-1)
|
||
|
}
|
||
|
return "SSL 3.0"
|
||
|
}
|
||
|
return fmt.Sprintf("(raw %d.%d)", version.major, version.minor)
|
||
|
}
|
||
|
|
||
|
func getKeySize(key interface{}) int {
|
||
|
if key == nil {
|
||
|
return 0
|
||
|
}
|
||
|
switch pubKey := key.(type) {
|
||
|
case *rsa.PublicKey:
|
||
|
if n := pubKey.N; n != nil {
|
||
|
return n.BitLen()
|
||
|
}
|
||
|
|
||
|
case *dsa.PublicKey:
|
||
|
if p := pubKey.Parameters.P; p != nil {
|
||
|
return p.BitLen()
|
||
|
}
|
||
|
if y := pubKey.Y; y != nil {
|
||
|
return y.BitLen()
|
||
|
}
|
||
|
|
||
|
case *ecdsa.PublicKey:
|
||
|
if params := pubKey.Params(); params != nil {
|
||
|
return params.BitSize
|
||
|
}
|
||
|
if y := pubKey.Y; y != nil {
|
||
|
return y.BitLen()
|
||
|
}
|
||
|
}
|
||
|
return 0
|
||
|
}
|
||
|
|
||
|
// certToMap takes an x509 cert and converts it into a map. If includeRaw is set
|
||
|
// to true a PEM encoded copy of the cert is encoded into the map as well.
|
||
|
func certToMap(cert *x509.Certificate, includeRaw bool) common.MapStr {
|
||
|
certMap := common.MapStr{
|
||
|
"signature_algorithm": cert.SignatureAlgorithm.String(),
|
||
|
"public_key_algorithm": toString(cert.PublicKeyAlgorithm),
|
||
|
"version": cert.Version,
|
||
|
"serial_number": cert.SerialNumber.Text(10),
|
||
|
"issuer": toMap(&cert.Issuer),
|
||
|
"subject": toMap(&cert.Subject),
|
||
|
"not_before": cert.NotBefore,
|
||
|
"not_after": cert.NotAfter,
|
||
|
}
|
||
|
if keySize := getKeySize(cert.PublicKey); keySize > 0 {
|
||
|
certMap["public_key_size"] = keySize
|
||
|
}
|
||
|
san := make([]string, 0, len(cert.DNSNames)+len(cert.IPAddresses)+len(cert.EmailAddresses))
|
||
|
san = append(append(san, cert.DNSNames...), cert.EmailAddresses...)
|
||
|
for _, ip := range cert.IPAddresses {
|
||
|
san = append(san, ip.String())
|
||
|
}
|
||
|
if len(san) > 0 {
|
||
|
certMap["alternative_names"] = san
|
||
|
}
|
||
|
if includeRaw {
|
||
|
certMap["raw"] = x509util.CertToPEMString(cert)
|
||
|
}
|
||
|
return certMap
|
||
|
}
|
||
|
|
||
|
func toMap(name *pkix.Name) common.MapStr {
|
||
|
result := common.MapStr{}
|
||
|
fields := []struct {
|
||
|
name string
|
||
|
value interface{}
|
||
|
}{
|
||
|
{"country", name.Country},
|
||
|
{"organization", name.Organization},
|
||
|
{"organizational_unit", name.OrganizationalUnit},
|
||
|
{"locality", name.Locality},
|
||
|
{"province", name.Province},
|
||
|
{"postal_code", name.PostalCode},
|
||
|
{"serial_number", name.SerialNumber},
|
||
|
{"common_name", name.CommonName},
|
||
|
{"street_address", name.StreetAddress},
|
||
|
}
|
||
|
for _, field := range fields {
|
||
|
var str string
|
||
|
switch value := field.value.(type) {
|
||
|
case string:
|
||
|
str = value
|
||
|
case []string:
|
||
|
str = strings.Join(value, " ")
|
||
|
}
|
||
|
if len(str) > 0 {
|
||
|
result[field.name] = str
|
||
|
}
|
||
|
}
|
||
|
return result
|
||
|
}
|
||
|
|
||
|
func (parser *parser) hasInfo() bool {
|
||
|
return parser.hello != nil || len(parser.alerts) != 0 || len(parser.certificates) != 0
|
||
|
}
|