// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package mysql import ( "errors" "fmt" "strings" "time" "github.com/elastic/beats/libbeat/beat" "github.com/elastic/beats/libbeat/common" "github.com/elastic/beats/libbeat/logp" "github.com/elastic/beats/libbeat/monitoring" "github.com/elastic/beats/packetbeat/procs" "github.com/elastic/beats/packetbeat/protos" "github.com/elastic/beats/packetbeat/protos/tcp" ) // Packet types const ( mysqlCmdQuery = 3 ) const maxPayloadSize = 100 * 1024 var ( unmatchedRequests = monitoring.NewInt(nil, "mysql.unmatched_requests") unmatchedResponses = monitoring.NewInt(nil, "mysql.unmatched_responses") ) type mysqlMessage struct { start int end int ts time.Time isRequest bool packetLength uint32 seq uint8 typ uint8 numberOfRows int numberOfFields int size uint64 fields []string rows [][]string tables string isOK bool affectedRows uint64 insertID uint64 isError bool errorCode uint16 errorInfo string query string ignoreMessage bool direction uint8 isTruncated bool tcpTuple common.TCPTuple cmdlineTuple *common.CmdlineTuple raw []byte notes []string } type mysqlTransaction struct { tuple common.TCPTuple src common.Endpoint dst common.Endpoint responseTime int32 ts time.Time query string method string path string // for mysql, Path refers to the mysql table queried bytesOut uint64 bytesIn uint64 notes []string mysql common.MapStr requestRaw string responseRaw string } type mysqlStream struct { data []byte parseOffset int parseState parseState isClient bool message *mysqlMessage } type parseState int const ( mysqlStateStart parseState = iota mysqlStateEatMessage mysqlStateEatFields mysqlStateEatRows mysqlStateMax ) var stateStrings = []string{ "Start", "EatMessage", "EatFields", "EatRows", } func (state parseState) String() string { return stateStrings[state] } type mysqlPlugin struct { // config ports []int maxStoreRows int maxRowLength int sendRequest bool sendResponse bool transactions *common.Cache transactionTimeout time.Duration results protos.Reporter // function pointer for mocking handleMysql func(mysql *mysqlPlugin, m *mysqlMessage, tcp *common.TCPTuple, dir uint8, raw_msg []byte) } func init() { protos.Register("mysql", New) } func New( testMode bool, results protos.Reporter, cfg *common.Config, ) (protos.Plugin, error) { p := &mysqlPlugin{} config := defaultConfig if !testMode { if err := cfg.Unpack(&config); err != nil { return nil, err } } if err := p.init(results, &config); err != nil { return nil, err } return p, nil } func (mysql *mysqlPlugin) init(results protos.Reporter, config *mysqlConfig) error { mysql.setFromConfig(config) mysql.transactions = common.NewCache( mysql.transactionTimeout, protos.DefaultTransactionHashSize) mysql.transactions.StartJanitor(mysql.transactionTimeout) mysql.handleMysql = handleMysql mysql.results = results return nil } func (mysql *mysqlPlugin) setFromConfig(config *mysqlConfig) { mysql.ports = config.Ports mysql.maxRowLength = config.MaxRowLength mysql.maxStoreRows = config.MaxRows mysql.sendRequest = config.SendRequest mysql.sendResponse = config.SendResponse mysql.transactionTimeout = config.TransactionTimeout } func (mysql *mysqlPlugin) getTransaction(k common.HashableTCPTuple) *mysqlTransaction { v := mysql.transactions.Get(k) if v != nil { return v.(*mysqlTransaction) } return nil } func (mysql *mysqlPlugin) GetPorts() []int { return mysql.ports } func (stream *mysqlStream) prepareForNewMessage() { stream.data = stream.data[stream.parseOffset:] stream.parseState = mysqlStateStart stream.parseOffset = 0 stream.isClient = false stream.message = nil } func mysqlMessageParser(s *mysqlStream) (bool, bool) { logp.Debug("mysqldetailed", "MySQL parser called. parseState = %s", s.parseState) m := s.message for s.parseOffset < len(s.data) { switch s.parseState { case mysqlStateStart: m.start = s.parseOffset if len(s.data[s.parseOffset:]) < 5 { logp.Warn("MySQL Message too short. Ignore it.") return false, false } hdr := s.data[s.parseOffset : s.parseOffset+5] m.packetLength = uint32(hdr[0]) | uint32(hdr[1])<<8 | uint32(hdr[2])<<16 m.seq = hdr[3] m.typ = hdr[4] logp.Debug("mysqldetailed", "MySQL Header: Packet length %d, Seq %d, Type=%d", m.packetLength, m.seq, m.typ) if m.seq == 0 { // starts Command Phase if m.typ == mysqlCmdQuery { // parse request m.isRequest = true m.start = s.parseOffset s.parseState = mysqlStateEatMessage } else { // ignore command m.ignoreMessage = true s.parseState = mysqlStateEatMessage } if !s.isClient { s.isClient = true } } else if !s.isClient { // parse response m.isRequest = false if hdr[4] == 0x00 || hdr[4] == 0xfe { logp.Debug("mysqldetailed", "Received OK response") m.start = s.parseOffset s.parseState = mysqlStateEatMessage m.isOK = true } else if hdr[4] == 0xff { logp.Debug("mysqldetailed", "Received ERR response") m.start = s.parseOffset s.parseState = mysqlStateEatMessage m.isError = true } else if m.packetLength == 1 { logp.Debug("mysqldetailed", "Query response. Number of fields %d", hdr[4]) m.numberOfFields = int(hdr[4]) m.start = s.parseOffset s.parseOffset += 5 s.parseState = mysqlStateEatFields } else { // something else. ignore m.ignoreMessage = true s.parseState = mysqlStateEatMessage } } else { // something else, not expected logp.Debug("mysql", "Unexpected MySQL message of type %d received.", m.typ) return false, false } case mysqlStateEatMessage: if len(s.data[s.parseOffset:]) < int(m.packetLength)+4 { // wait for more data return true, false } s.parseOffset += 4 //header s.parseOffset += int(m.packetLength) m.end = s.parseOffset if m.isRequest { m.query = string(s.data[m.start+5 : m.end]) } else if m.isOK { // affected rows affectedRows, off, complete, err := readLinteger(s.data, m.start+5) if !complete { return true, false } if err != nil { logp.Debug("mysql", "Error on read_linteger: %s", err) return false, false } m.affectedRows = affectedRows // last insert id insertID, _, complete, err := readLinteger(s.data, off) if !complete { return true, false } if err != nil { logp.Debug("mysql", "Error on read_linteger: %s", err) return false, false } m.insertID = insertID } else if m.isError { // int<1>header (0xff) // int<2>error code // string[1] sql state marker // string[5] sql state // string error message m.errorCode = uint16(s.data[m.start+6])<<8 | uint16(s.data[m.start+5]) m.errorInfo = string(s.data[m.start+8:m.start+13]) + ": " + string(s.data[m.start+13:]) } m.size = uint64(m.end - m.start) logp.Debug("mysqldetailed", "Message complete. remaining=%d", len(s.data[s.parseOffset:])) return true, true case mysqlStateEatFields: if len(s.data[s.parseOffset:]) < 4 { // wait for more return true, false } hdr := s.data[s.parseOffset : s.parseOffset+4] m.packetLength = uint32(hdr[0]) | uint32(hdr[1])<<8 | uint32(hdr[2])<<16 m.seq = hdr[3] logp.Debug("mysqldetailed", "Fields: packet length %d, packet number %d", m.packetLength, m.seq) if len(s.data[s.parseOffset:]) >= int(m.packetLength)+4 { s.parseOffset += 4 // header if s.data[s.parseOffset] == 0xfe { logp.Debug("mysqldetailed", "Received EOF packet") // EOF marker s.parseOffset += int(m.packetLength) s.parseState = mysqlStateEatRows } else { _ /* catalog */, off, complete, err := readLstring(s.data, s.parseOffset) if !complete { return true, false } if err != nil { logp.Debug("mysql", "Error on read_lstring: %s", err) return false, false } db /*schema */, off, complete, err := readLstring(s.data, off) if !complete { return true, false } if err != nil { logp.Debug("mysql", "Error on read_lstring: %s", err) return false, false } table /* table */, _ /*off*/, complete, err := readLstring(s.data, off) if !complete { return true, false } if err != nil { logp.Debug("mysql", "Error on read_lstring: %s", err) return false, false } dbTable := string(db) + "." + string(table) if len(m.tables) == 0 { m.tables = dbTable } else if !strings.Contains(m.tables, dbTable) { m.tables = m.tables + ", " + dbTable } logp.Debug("mysqldetailed", "db=%s, table=%s", db, table) s.parseOffset += int(m.packetLength) // go to next field } } else { // wait for more return true, false } case mysqlStateEatRows: if len(s.data[s.parseOffset:]) < 4 { // wait for more return true, false } hdr := s.data[s.parseOffset : s.parseOffset+4] m.packetLength = uint32(hdr[0]) | uint32(hdr[1])<<8 | uint32(hdr[2])<<16 m.seq = hdr[3] logp.Debug("mysqldetailed", "Rows: packet length %d, packet number %d", m.packetLength, m.seq) if len(s.data[s.parseOffset:]) < int(m.packetLength)+4 { // wait for more return true, false } s.parseOffset += 4 //header if s.data[s.parseOffset] == 0xfe { logp.Debug("mysqldetailed", "Received EOF packet") // EOF marker s.parseOffset += int(m.packetLength) if m.end == 0 { m.end = s.parseOffset } else { m.isTruncated = true } if !m.isError { // in case the response was sent successfully m.isOK = true } m.size = uint64(m.end - m.start) return true, true } s.parseOffset += int(m.packetLength) if m.end == 0 && s.parseOffset > maxPayloadSize { // only send up to here, but read until the end m.end = s.parseOffset } m.numberOfRows++ // go to next row } } return true, false } // messageGap is called when a gap of size `nbytes` is found in the // tcp stream. Returns true if there is already enough data in the message // read so far that we can use it further in the stack. func (mysql *mysqlPlugin) messageGap(s *mysqlStream, nbytes int) (complete bool) { m := s.message switch s.parseState { case mysqlStateStart, mysqlStateEatMessage: // not enough data yet to be useful return false case mysqlStateEatFields, mysqlStateEatRows: // enough data here m.end = s.parseOffset if m.isRequest { m.notes = append(m.notes, "Packet loss while capturing the request") } else { m.notes = append(m.notes, "Packet loss while capturing the response") } return true } return true } type mysqlPrivateData struct { data [2]*mysqlStream } // Called when the parser has identified a full message. func (mysql *mysqlPlugin) messageComplete(tcptuple *common.TCPTuple, dir uint8, stream *mysqlStream) { // all ok, ship it msg := stream.data[stream.message.start:stream.message.end] if !stream.message.ignoreMessage { mysql.handleMysql(mysql, stream.message, tcptuple, dir, msg) } // and reset message stream.prepareForNewMessage() } func (mysql *mysqlPlugin) ConnectionTimeout() time.Duration { return mysql.transactionTimeout } func (mysql *mysqlPlugin) Parse(pkt *protos.Packet, tcptuple *common.TCPTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { defer logp.Recover("ParseMysql exception") priv := mysqlPrivateData{} if private != nil { var ok bool priv, ok = private.(mysqlPrivateData) if !ok { priv = mysqlPrivateData{} } } if priv.data[dir] == nil { priv.data[dir] = &mysqlStream{ data: pkt.Payload, message: &mysqlMessage{ts: pkt.Ts}, } } else { // concatenate bytes priv.data[dir].data = append(priv.data[dir].data, pkt.Payload...) if len(priv.data[dir].data) > tcp.TCPMaxDataInStream { logp.Debug("mysql", "Stream data too large, dropping TCP stream") priv.data[dir] = nil return priv } } stream := priv.data[dir] for len(stream.data) > 0 { if stream.message == nil { stream.message = &mysqlMessage{ts: pkt.Ts} } ok, complete := mysqlMessageParser(priv.data[dir]) //logp.Debug("mysqldetailed", "mysqlMessageParser returned ok=%b complete=%b", ok, complete) if !ok { // drop this tcp stream. Will retry parsing with the next // segment in it priv.data[dir] = nil logp.Debug("mysql", "Ignore MySQL message. Drop tcp stream. Try parsing with the next segment") return priv } if complete { mysql.messageComplete(tcptuple, dir, stream) } else { // wait for more data break } } return priv } func (mysql *mysqlPlugin) GapInStream(tcptuple *common.TCPTuple, dir uint8, nbytes int, private protos.ProtocolData) (priv protos.ProtocolData, drop bool) { defer logp.Recover("GapInStream(mysql) exception") if private == nil { return private, false } mysqlData, ok := private.(mysqlPrivateData) if !ok { return private, false } stream := mysqlData.data[dir] if stream == nil || stream.message == nil { // nothing to do return private, false } if mysql.messageGap(stream, nbytes) { // we need to publish from here mysql.messageComplete(tcptuple, dir, stream) } // we always drop the TCP stream. Because it's binary and len based, // there are too few cases in which we could recover the stream (maybe // for very large blobs, leaving that as TODO) return private, true } func (mysql *mysqlPlugin) ReceivedFin(tcptuple *common.TCPTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { // TODO: check if we have data pending and either drop it to free // memory or send it up the stack. return private } func handleMysql(mysql *mysqlPlugin, m *mysqlMessage, tcptuple *common.TCPTuple, dir uint8, rawMsg []byte) { m.tcpTuple = *tcptuple m.direction = dir m.cmdlineTuple = procs.ProcWatcher.FindProcessesTupleTCP(tcptuple.IPPort()) m.raw = rawMsg if m.isRequest { mysql.receivedMysqlRequest(m) } else { mysql.receivedMysqlResponse(m) } } func (mysql *mysqlPlugin) receivedMysqlRequest(msg *mysqlMessage) { tuple := msg.tcpTuple trans := mysql.getTransaction(tuple.Hashable()) if trans != nil { if trans.mysql != nil { logp.Debug("mysql", "Two requests without a Response. Dropping old request: %s", trans.mysql) unmatchedRequests.Add(1) } } else { trans = &mysqlTransaction{tuple: tuple} mysql.transactions.Put(tuple.Hashable(), trans) } trans.ts = msg.ts trans.src, trans.dst = common.MakeEndpointPair(msg.tcpTuple.BaseTuple, msg.cmdlineTuple) if msg.direction == tcp.TCPDirectionReverse { trans.src, trans.dst = trans.dst, trans.src } // Extract the method, by simply taking the first word and // making it upper case. query := strings.Trim(msg.query, " \r\n\t") index := strings.IndexAny(query, " \r\n\t") var method string if index > 0 { method = strings.ToUpper(query[:index]) } else { method = strings.ToUpper(query) } trans.query = query trans.method = method trans.mysql = common.MapStr{} trans.notes = msg.notes // save Raw message trans.requestRaw = msg.query trans.bytesIn = msg.size } func (mysql *mysqlPlugin) receivedMysqlResponse(msg *mysqlMessage) { trans := mysql.getTransaction(msg.tcpTuple.Hashable()) if trans == nil { logp.Debug("mysql", "Response from unknown transaction. Ignoring.") unmatchedResponses.Add(1) return } // check if the request was received if trans.mysql == nil { logp.Debug("mysql", "Response from unknown transaction. Ignoring.") unmatchedResponses.Add(1) return } // save json details trans.mysql.Update(common.MapStr{ "affected_rows": msg.affectedRows, "insert_id": msg.insertID, "num_rows": msg.numberOfRows, "num_fields": msg.numberOfFields, "iserror": msg.isError, "error_code": msg.errorCode, "error_message": msg.errorInfo, }) trans.bytesOut = msg.size trans.path = msg.tables trans.responseTime = int32(msg.ts.Sub(trans.ts).Nanoseconds() / 1e6) // resp_time in milliseconds // save Raw message if len(msg.raw) > 0 { fields, rows := mysql.parseMysqlResponse(msg.raw) trans.responseRaw = common.DumpInCSVFormat(fields, rows) } trans.notes = append(trans.notes, msg.notes...) mysql.publishTransaction(trans) mysql.transactions.Delete(trans.tuple.Hashable()) logp.Debug("mysql", "Mysql transaction completed: %s", trans.mysql) logp.Debug("mysql", "%s", trans.responseRaw) } func (mysql *mysqlPlugin) parseMysqlResponse(data []byte) ([]string, [][]string) { length, err := readLength(data, 0) if err != nil { logp.Warn("Invalid response: %v", err) return []string{}, [][]string{} } if length < 1 { logp.Warn("Warning: Skipping empty Response") return []string{}, [][]string{} } fields := []string{} rows := [][]string{} if len(data) < 5 { logp.Warn("Invalid response: data less than 4 bytes") return []string{}, [][]string{} } if data[4] == 0x00 { // OK response } else if data[4] == 0xff { // Error response } else { offset := 5 logp.Debug("mysql", "Data len: %d", len(data)) // Read fields for { length, err = readLength(data, offset) if err != nil { logp.Warn("Invalid response: %v", err) return []string{}, [][]string{} } if len(data[offset:]) < 5 { logp.Warn("Invalid response.") return []string{}, [][]string{} } if data[offset+4] == 0xfe { // EOF offset += length + 4 break } _ /* catalog */, off, complete, err := readLstring(data, offset+4) if err != nil || !complete { logp.Debug("mysql", "Reading field: %v %v", err, complete) return fields, rows } _ /*database*/, off, complete, err = readLstring(data, off) if err != nil || !complete { logp.Debug("mysql", "Reading field: %v %v", err, complete) return fields, rows } _ /*table*/, off, complete, err = readLstring(data, off) if err != nil || !complete { logp.Debug("mysql", "Reading field: %v %v", err, complete) return fields, rows } _ /*org table*/, off, complete, err = readLstring(data, off) if err != nil || !complete { logp.Debug("mysql", "Reading field: %v %v", err, complete) return fields, rows } name, off, complete, err := readLstring(data, off) if err != nil || !complete { logp.Debug("mysql", "Reading field: %v %v", err, complete) return fields, rows } _ /* org name */, _ /*off*/, complete, err = readLstring(data, off) if err != nil || !complete { logp.Debug("mysql", "Reading field: %v %v", err, complete) return fields, rows } fields = append(fields, string(name)) offset += length + 4 if len(data) < offset { logp.Warn("Invalid response.") return []string{}, [][]string{} } } // Read rows for offset < len(data) { var row []string var rowLen int if len(data[offset:]) < 5 { logp.Warn("Invalid response.") break } if data[offset+4] == 0xfe { // EOF offset += length + 4 break } length, err = readLength(data, offset) if err != nil { logp.Warn("Invalid response: %v", err) break } off := offset + 4 // skip length + packet number start := off for off < start+length { var text []byte if data[off] == 0xfb { text = []byte("NULL") off++ } else { var err error var complete bool text, off, complete, err = readLstring(data, off) if err != nil || !complete { logp.Debug("mysql", "Error parsing rows: %s %b", err, complete) // nevertheless, return what we have so far return fields, rows } } if rowLen < mysql.maxRowLength { if rowLen+len(text) > mysql.maxRowLength { text = text[:mysql.maxRowLength-rowLen] } row = append(row, string(text)) rowLen += len(text) } } logp.Debug("mysqldetailed", "Append row: %v", row) rows = append(rows, row) if len(rows) >= mysql.maxStoreRows { break } offset += length + 4 } } return fields, rows } func (mysql *mysqlPlugin) publishTransaction(t *mysqlTransaction) { if mysql.results == nil { return } logp.Debug("mysql", "mysql.results exists") fields := common.MapStr{} fields["type"] = "mysql" if t.mysql["iserror"].(bool) { fields["status"] = common.ERROR_STATUS } else { fields["status"] = common.OK_STATUS } fields["responsetime"] = t.responseTime if mysql.sendRequest { fields["request"] = t.requestRaw } if mysql.sendResponse { fields["response"] = t.responseRaw } fields["method"] = t.method fields["query"] = t.query fields["mysql"] = t.mysql fields["path"] = t.path fields["bytes_out"] = t.bytesOut fields["bytes_in"] = t.bytesIn if len(t.notes) > 0 { fields["notes"] = t.notes } fields["src"] = &t.src fields["dst"] = &t.dst mysql.results(beat.Event{ Timestamp: t.ts, Fields: fields, }) } func readLstring(data []byte, offset int) ([]byte, int, bool, error) { length, off, complete, err := readLinteger(data, offset) if err != nil { return nil, 0, false, err } if !complete || len(data[off:]) < int(length) { return nil, 0, false, nil } return data[off : off+int(length)], off + int(length), true, nil } func readLinteger(data []byte, offset int) (uint64, int, bool, error) { if len(data) < offset+1 { return 0, 0, false, nil } switch data[offset] { case 0xfe: if len(data[offset+1:]) < 8 { return 0, 0, false, nil } return uint64(data[offset+1]) | uint64(data[offset+2])<<8 | uint64(data[offset+2])<<16 | uint64(data[offset+3])<<24 | uint64(data[offset+4])<<32 | uint64(data[offset+5])<<40 | uint64(data[offset+6])<<48 | uint64(data[offset+7])<<56, offset + 9, true, nil case 0xfd: if len(data[offset+1:]) < 3 { return 0, 0, false, nil } return uint64(data[offset+1]) | uint64(data[offset+2])<<8 | uint64(data[offset+3])<<16, offset + 4, true, nil case 0xfc: if len(data[offset+1:]) < 2 { return 0, 0, false, nil } return uint64(data[offset+1]) | uint64(data[offset+2])<<8, offset + 3, true, nil } if uint64(data[offset]) >= 0xfb { return 0, 0, false, fmt.Errorf("Unexpected value in read_linteger") } return uint64(data[offset]), offset + 1, true, nil } // Read a mysql length field (3 bytes LE) func readLength(data []byte, offset int) (int, error) { if len(data[offset:]) < 3 { return 0, errors.New("Data too small to contain a valid length") } length := uint32(data[offset]) | uint32(data[offset+1])<<8 | uint32(data[offset+2])<<16 return int(length), nil }