// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package tcp import ( "fmt" "sync" "time" "github.com/elastic/beats/libbeat/common" "github.com/elastic/beats/libbeat/logp" "github.com/elastic/beats/libbeat/monitoring" "github.com/elastic/beats/packetbeat/flows" "github.com/elastic/beats/packetbeat/protos" "github.com/tsg/gopacket/layers" ) const TCPMaxDataInStream = 10 * (1 << 20) const ( TCPDirectionReverse = 0 TCPDirectionOriginal = 1 ) type TCP struct { id uint32 streams *common.Cache portMap map[uint16]protos.Protocol protocols protos.Protocols expiredConns expirationQueue } type expiredConnection struct { mod protos.ExpirationAwareTCPPlugin conn *TCPConnection } type expirationQueue struct { mutex sync.Mutex conns []expiredConnection } type Processor interface { Process(flow *flows.FlowID, hdr *layers.TCP, pkt *protos.Packet) } var ( droppedBecauseOfGaps = monitoring.NewInt(nil, "tcp.dropped_because_of_gaps") ) type seqCompare int const ( seqLT seqCompare = -1 seqEq seqCompare = 0 seqGT seqCompare = 1 ) var ( debugf = logp.MakeDebug("tcp") isDebug = false ) func (tcp *TCP) getID() uint32 { tcp.id++ return tcp.id } func (tcp *TCP) decideProtocol(tuple *common.IPPortTuple) protos.Protocol { protocol, exists := tcp.portMap[tuple.SrcPort] if exists { return protocol } protocol, exists = tcp.portMap[tuple.DstPort] if exists { return protocol } return protos.UnknownProtocol } func (tcp *TCP) findStream(k common.HashableIPPortTuple) *TCPConnection { v := tcp.streams.Get(k) if v != nil { return v.(*TCPConnection) } return nil } type TCPConnection struct { id uint32 tuple *common.IPPortTuple protocol protos.Protocol tcptuple common.TCPTuple tcp *TCP lastSeq [2]uint32 // protocols private data data protos.ProtocolData } type TCPStream struct { conn *TCPConnection dir uint8 } func (conn *TCPConnection) String() string { return fmt.Sprintf("TcpStream id[%d] tuple[%s] protocol[%s] lastSeq[%d %d]", conn.id, conn.tuple, conn.protocol, conn.lastSeq[0], conn.lastSeq[1]) } func (stream *TCPStream) addPacket(pkt *protos.Packet, tcphdr *layers.TCP) { conn := stream.conn mod := conn.tcp.protocols.GetTCP(conn.protocol) if mod == nil { if isDebug { protocol := conn.protocol debugf("Ignoring protocol for which we have no module loaded: %s", protocol) } return } if len(pkt.Payload) > 0 { conn.data = mod.Parse(pkt, &conn.tcptuple, stream.dir, conn.data) } if tcphdr.FIN { conn.data = mod.ReceivedFin(&conn.tcptuple, stream.dir, conn.data) } } func (stream *TCPStream) gapInStream(nbytes int) (drop bool) { conn := stream.conn mod := conn.tcp.protocols.GetTCP(conn.protocol) conn.data, drop = mod.GapInStream(&conn.tcptuple, stream.dir, nbytes, conn.data) return drop } func (tcp *TCP) Process(id *flows.FlowID, tcphdr *layers.TCP, pkt *protos.Packet) { // This Recover should catch all exceptions in // protocol modules. defer logp.Recover("Process tcp exception") tcp.expiredConns.notifyAll() stream, created := tcp.getStream(pkt) if stream.conn == nil { return } conn := stream.conn if id != nil { id.AddConnectionID(uint64(conn.id)) } if isDebug { debugf("tcp flow id: %p", id) } if len(pkt.Payload) == 0 && !tcphdr.FIN { // return early if packet is not interesting. Still need to find/create // stream first in order to update the TCP stream timer return } tcpStartSeq := tcphdr.Seq tcpSeq := tcpStartSeq + uint32(len(pkt.Payload)) lastSeq := conn.lastSeq[stream.dir] if isDebug { debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)", tcpStartSeq, tcpSeq, lastSeq, len(pkt.Payload)) } if len(pkt.Payload) > 0 && lastSeq != 0 { if tcpSeqBeforeEq(tcpSeq, lastSeq) { if isDebug { debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v", tcphdr.Seq, len(pkt.Payload), lastSeq) } return } switch tcpSeqCompare(lastSeq, tcpStartSeq) { case seqLT: // lastSeq < tcpStartSeq => Gap in tcp stream detected if created { break } gap := int(tcpStartSeq - lastSeq) debugf("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcpStartSeq, gap) drop := stream.gapInStream(gap) if drop { if isDebug { debugf("Dropping connection state because of gap") } droppedBecauseOfGaps.Add(1) // drop application layer connection state and // update stream_id for app layer analysers using stream_id for lookups conn.id = tcp.getID() conn.data = nil } case seqGT: // lastSeq > tcpStartSeq => overlapping TCP segment detected. shrink packet delta := lastSeq - tcpStartSeq if isDebug { debugf("Overlapping tcp segment. last_seq %d, seq: %d, delta: %d", lastSeq, tcpStartSeq, delta) } pkt.Payload = pkt.Payload[delta:] tcphdr.Seq += delta } } conn.lastSeq[stream.dir] = tcpSeq stream.addPacket(pkt, tcphdr) } func (tcp *TCP) getStream(pkt *protos.Packet) (stream TCPStream, created bool) { if conn := tcp.findStream(pkt.Tuple.Hashable()); conn != nil { return TCPStream{conn: conn, dir: TCPDirectionOriginal}, false } if conn := tcp.findStream(pkt.Tuple.RevHashable()); conn != nil { return TCPStream{conn: conn, dir: TCPDirectionReverse}, false } protocol := tcp.decideProtocol(&pkt.Tuple) if protocol == protos.UnknownProtocol { // don't follow return TCPStream{}, false } var timeout time.Duration mod := tcp.protocols.GetTCP(protocol) if mod != nil { timeout = mod.ConnectionTimeout() } if isDebug { t := pkt.Tuple debugf("Connection src[%s:%d] dst[%s:%d] doesn't exist, creating new", t.SrcIP.String(), t.SrcPort, t.DstIP.String(), t.DstPort) } conn := &TCPConnection{ id: tcp.getID(), tuple: &pkt.Tuple, protocol: protocol, tcp: tcp} conn.tcptuple = common.TCPTupleFromIPPort(conn.tuple, conn.id) tcp.streams.PutWithTimeout(pkt.Tuple.Hashable(), conn, timeout) return TCPStream{conn: conn, dir: TCPDirectionOriginal}, true } func tcpSeqCompare(seq1, seq2 uint32) seqCompare { i := int32(seq1 - seq2) switch { case i == 0: return seqEq case i < 0: return seqLT default: return seqGT } } func tcpSeqBefore(seq1 uint32, seq2 uint32) bool { return int32(seq1-seq2) < 0 } func tcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool { return int32(seq1-seq2) <= 0 } func buildPortsMap(plugins map[protos.Protocol]protos.TCPPlugin) (map[uint16]protos.Protocol, error) { var res = map[uint16]protos.Protocol{} for proto, protoPlugin := range plugins { for _, port := range protoPlugin.GetPorts() { oldProto, exists := res[uint16(port)] if exists { if oldProto == proto { continue } return nil, fmt.Errorf("Duplicate port (%d) exists in %s and %s protocols", port, oldProto, proto) } res[uint16(port)] = proto } } return res, nil } // Creates and returns a new Tcp. func NewTCP(p protos.Protocols) (*TCP, error) { isDebug = logp.IsDebug("tcp") portMap, err := buildPortsMap(p.GetAllTCP()) if err != nil { return nil, err } tcp := &TCP{ protocols: p, portMap: portMap, } tcp.streams = common.NewCacheWithRemovalListener( protos.DefaultTransactionExpiration, protos.DefaultTransactionHashSize, tcp.removalListener) tcp.streams.StartJanitor(protos.DefaultTransactionExpiration) if isDebug { debugf("tcp", "Port map: %v", portMap) } return tcp, nil } func (tcp *TCP) removalListener(_ common.Key, value common.Value) { conn := value.(*TCPConnection) mod := conn.tcp.protocols.GetTCP(conn.protocol) if mod != nil { awareMod, ok := mod.(protos.ExpirationAwareTCPPlugin) if ok { tcp.expiredConns.add(awareMod, conn) } } } func (ec *expiredConnection) notify() { ec.mod.Expired(&ec.conn.tcptuple, ec.conn.data) } func (eq *expirationQueue) add(mod protos.ExpirationAwareTCPPlugin, conn *TCPConnection) { eq.mutex.Lock() eq.conns = append(eq.conns, expiredConnection{ mod: mod, conn: conn, }) eq.mutex.Unlock() } func (eq *expirationQueue) getExpired() (conns []expiredConnection) { eq.mutex.Lock() conns, eq.conns = eq.conns, nil eq.mutex.Unlock() return conns } func (eq *expirationQueue) notifyAll() { for _, expiration := range eq.getExpired() { expiration.notify() } }