381 lines
8.8 KiB
Go
381 lines
8.8 KiB
Go
// Licensed to Elasticsearch B.V. under one or more contributor
|
|
// license agreements. See the NOTICE file distributed with
|
|
// this work for additional information regarding copyright
|
|
// ownership. Elasticsearch B.V. licenses this file to you under
|
|
// the Apache License, Version 2.0 (the "License"); you may
|
|
// not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing,
|
|
// software distributed under the License is distributed on an
|
|
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
// KIND, either express or implied. See the License for the
|
|
// specific language governing permissions and limitations
|
|
// under the License.
|
|
|
|
package tcp
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/elastic/beats/libbeat/common"
|
|
"github.com/elastic/beats/libbeat/logp"
|
|
"github.com/elastic/beats/libbeat/monitoring"
|
|
|
|
"github.com/elastic/beats/packetbeat/flows"
|
|
"github.com/elastic/beats/packetbeat/protos"
|
|
|
|
"github.com/tsg/gopacket/layers"
|
|
)
|
|
|
|
// TCPMaxDataInStream evaluates to 10 * 2^20 (10 MiB if read as a byte count).
// NOTE(review): not referenced in this file; presumably used by protocol
// plugins as a cap on buffered stream data — confirm against callers.
const TCPMaxDataInStream = 10 * (1 << 20)
|
|
|
|
// Direction values stored in TCPStream.dir. They are also used as the index
// into TCPConnection.lastSeq, so they must stay within [0, 1].
const (
	// TCPDirectionReverse marks a packet whose reversed tuple matched the
	// tuple the connection was registered under.
	TCPDirectionReverse = 0
	// TCPDirectionOriginal marks a packet whose tuple matched the registered
	// tuple as-is; it is also the direction assigned to the packet that
	// creates a connection.
	TCPDirectionOriginal = 1
)
|
|
|
|
// TCP tracks TCP connections and dispatches their packets to the protocol
// plugin selected by port.
type TCP struct {
	// id is the last connection id handed out by getID.
	id uint32
	// streams caches *TCPConnection values keyed by hashable IP/port tuple.
	streams *common.Cache
	// portMap maps a TCP port to the protocol registered for it.
	portMap map[uint16]protos.Protocol
	// protocols resolves a protocol to its TCP plugin.
	protocols protos.Protocols
	// expiredConns collects connections removed from streams until the next
	// Process call notifies their plugins.
	expiredConns expirationQueue
}
|
|
|
|
// expiredConnection pairs a connection that was removed from the stream cache
// with the expiration-aware plugin that has to be told about it.
type expiredConnection struct {
	// mod receives the Expired callback in notify.
	mod protos.ExpirationAwareTCPPlugin
	// conn is the removed connection whose tuple and data are passed to mod.
	conn *TCPConnection
}
|
|
|
|
// expirationQueue is a mutex-guarded list of expired connections. Entries are
// appended by add and taken out in bulk by getExpired.
type expirationQueue struct {
	// mutex guards conns.
	mutex sync.Mutex
	// conns holds the expirations that have not been notified yet.
	conns []expiredConnection
}
|
|
|
|
// Processor is the interface for handling a single decoded TCP packet.
// *TCP in this file provides a Process method with this signature.
type Processor interface {
	// Process handles one TCP packet given its flow id, TCP header and
	// packet metadata/payload.
	Process(flow *flows.FlowID, hdr *layers.TCP, pkt *protos.Packet)
}
|
|
|
|
var (
	// droppedBecauseOfGaps is the monitoring counter registered as
	// "tcp.dropped_because_of_gaps". Process increments it each time a plugin
	// asks to drop connection state after a sequence gap.
	droppedBecauseOfGaps = monitoring.NewInt(nil, "tcp.dropped_because_of_gaps")
)
|
|
|
|
// seqCompare is the three-way result returned by tcpSeqCompare.
type seqCompare int

// Possible seqCompare values, describing seq1 relative to seq2.
const (
	// seqLT: the first sequence number is before the second.
	seqLT seqCompare = -1
	// seqEq: both sequence numbers are equal.
	seqEq seqCompare = 0
	// seqGT: the first sequence number is after the second.
	seqGT seqCompare = 1
)
|
|
|
|
var (
	// debugf logs a formatted debug message under the "tcp" selector.
	debugf = logp.MakeDebug("tcp")
	// isDebug caches logp.IsDebug("tcp"). It is assigned in NewTCP and stays
	// false until NewTCP has run.
	isDebug = false
)
|
|
|
|
func (tcp *TCP) getID() uint32 {
|
|
tcp.id++
|
|
return tcp.id
|
|
}
|
|
|
|
func (tcp *TCP) decideProtocol(tuple *common.IPPortTuple) protos.Protocol {
|
|
protocol, exists := tcp.portMap[tuple.SrcPort]
|
|
if exists {
|
|
return protocol
|
|
}
|
|
|
|
protocol, exists = tcp.portMap[tuple.DstPort]
|
|
if exists {
|
|
return protocol
|
|
}
|
|
|
|
return protos.UnknownProtocol
|
|
}
|
|
|
|
func (tcp *TCP) findStream(k common.HashableIPPortTuple) *TCPConnection {
|
|
v := tcp.streams.Get(k)
|
|
if v != nil {
|
|
return v.(*TCPConnection)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// TCPConnection holds the per-connection state kept in the stream cache.
type TCPConnection struct {
	// id is the connection id obtained from TCP.getID; Process replaces it
	// when connection state is dropped after a gap.
	id uint32
	// tuple points at the IP/port tuple of the packet that created the
	// connection.
	tuple *common.IPPortTuple
	// protocol is the application protocol chosen by decideProtocol.
	protocol protos.Protocol
	// tcptuple is built from tuple and id when the connection is created and
	// is handed to the plugin callbacks.
	tcptuple common.TCPTuple
	// tcp is the owning TCP instance.
	tcp *TCP

	// lastSeq is indexed by stream direction and holds the sequence number
	// following the last accepted segment (start seq + payload length).
	// Process treats 0 as "nothing recorded yet".
	lastSeq [2]uint32

	// protocols private data
	data protos.ProtocolData
}
|
|
|
|
// TCPStream is one direction of a TCPConnection. A zero value (conn == nil)
// is returned by getStream for packets that are not followed.
type TCPStream struct {
	// conn is the connection this stream belongs to.
	conn *TCPConnection
	// dir is TCPDirectionOriginal or TCPDirectionReverse.
	dir uint8
}
|
|
|
|
func (conn *TCPConnection) String() string {
|
|
return fmt.Sprintf("TcpStream id[%d] tuple[%s] protocol[%s] lastSeq[%d %d]",
|
|
conn.id, conn.tuple, conn.protocol, conn.lastSeq[0], conn.lastSeq[1])
|
|
}
|
|
|
|
func (stream *TCPStream) addPacket(pkt *protos.Packet, tcphdr *layers.TCP) {
|
|
conn := stream.conn
|
|
mod := conn.tcp.protocols.GetTCP(conn.protocol)
|
|
if mod == nil {
|
|
if isDebug {
|
|
protocol := conn.protocol
|
|
debugf("Ignoring protocol for which we have no module loaded: %s",
|
|
protocol)
|
|
}
|
|
return
|
|
}
|
|
|
|
if len(pkt.Payload) > 0 {
|
|
conn.data = mod.Parse(pkt, &conn.tcptuple, stream.dir, conn.data)
|
|
}
|
|
|
|
if tcphdr.FIN {
|
|
conn.data = mod.ReceivedFin(&conn.tcptuple, stream.dir, conn.data)
|
|
}
|
|
}
|
|
|
|
func (stream *TCPStream) gapInStream(nbytes int) (drop bool) {
|
|
conn := stream.conn
|
|
mod := conn.tcp.protocols.GetTCP(conn.protocol)
|
|
conn.data, drop = mod.GapInStream(&conn.tcptuple, stream.dir, nbytes, conn.data)
|
|
return drop
|
|
}
|
|
|
|
func (tcp *TCP) Process(id *flows.FlowID, tcphdr *layers.TCP, pkt *protos.Packet) {
|
|
// This Recover should catch all exceptions in
|
|
// protocol modules.
|
|
defer logp.Recover("Process tcp exception")
|
|
|
|
tcp.expiredConns.notifyAll()
|
|
|
|
stream, created := tcp.getStream(pkt)
|
|
if stream.conn == nil {
|
|
return
|
|
}
|
|
|
|
conn := stream.conn
|
|
if id != nil {
|
|
id.AddConnectionID(uint64(conn.id))
|
|
}
|
|
|
|
if isDebug {
|
|
debugf("tcp flow id: %p", id)
|
|
}
|
|
|
|
if len(pkt.Payload) == 0 && !tcphdr.FIN {
|
|
// return early if packet is not interesting. Still need to find/create
|
|
// stream first in order to update the TCP stream timer
|
|
return
|
|
}
|
|
|
|
tcpStartSeq := tcphdr.Seq
|
|
tcpSeq := tcpStartSeq + uint32(len(pkt.Payload))
|
|
lastSeq := conn.lastSeq[stream.dir]
|
|
if isDebug {
|
|
debugf("pkt.start_seq=%v pkt.last_seq=%v stream.last_seq=%v (len=%d)",
|
|
tcpStartSeq, tcpSeq, lastSeq, len(pkt.Payload))
|
|
}
|
|
|
|
if len(pkt.Payload) > 0 && lastSeq != 0 {
|
|
if tcpSeqBeforeEq(tcpSeq, lastSeq) {
|
|
if isDebug {
|
|
debugf("Ignoring retransmitted segment. pkt.seq=%v len=%v stream.seq=%v",
|
|
tcphdr.Seq, len(pkt.Payload), lastSeq)
|
|
}
|
|
return
|
|
}
|
|
|
|
switch tcpSeqCompare(lastSeq, tcpStartSeq) {
|
|
case seqLT: // lastSeq < tcpStartSeq => Gap in tcp stream detected
|
|
if created {
|
|
break
|
|
}
|
|
|
|
gap := int(tcpStartSeq - lastSeq)
|
|
debugf("Gap in tcp stream. last_seq: %d, seq: %d, gap: %d", lastSeq, tcpStartSeq, gap)
|
|
drop := stream.gapInStream(gap)
|
|
if drop {
|
|
if isDebug {
|
|
debugf("Dropping connection state because of gap")
|
|
}
|
|
droppedBecauseOfGaps.Add(1)
|
|
|
|
// drop application layer connection state and
|
|
// update stream_id for app layer analysers using stream_id for lookups
|
|
conn.id = tcp.getID()
|
|
conn.data = nil
|
|
}
|
|
|
|
case seqGT:
|
|
// lastSeq > tcpStartSeq => overlapping TCP segment detected. shrink packet
|
|
delta := lastSeq - tcpStartSeq
|
|
|
|
if isDebug {
|
|
debugf("Overlapping tcp segment. last_seq %d, seq: %d, delta: %d",
|
|
lastSeq, tcpStartSeq, delta)
|
|
}
|
|
|
|
pkt.Payload = pkt.Payload[delta:]
|
|
tcphdr.Seq += delta
|
|
}
|
|
}
|
|
|
|
conn.lastSeq[stream.dir] = tcpSeq
|
|
stream.addPacket(pkt, tcphdr)
|
|
}
|
|
|
|
func (tcp *TCP) getStream(pkt *protos.Packet) (stream TCPStream, created bool) {
|
|
if conn := tcp.findStream(pkt.Tuple.Hashable()); conn != nil {
|
|
return TCPStream{conn: conn, dir: TCPDirectionOriginal}, false
|
|
}
|
|
|
|
if conn := tcp.findStream(pkt.Tuple.RevHashable()); conn != nil {
|
|
return TCPStream{conn: conn, dir: TCPDirectionReverse}, false
|
|
}
|
|
|
|
protocol := tcp.decideProtocol(&pkt.Tuple)
|
|
if protocol == protos.UnknownProtocol {
|
|
// don't follow
|
|
return TCPStream{}, false
|
|
}
|
|
|
|
var timeout time.Duration
|
|
mod := tcp.protocols.GetTCP(protocol)
|
|
if mod != nil {
|
|
timeout = mod.ConnectionTimeout()
|
|
}
|
|
|
|
if isDebug {
|
|
t := pkt.Tuple
|
|
debugf("Connection src[%s:%d] dst[%s:%d] doesn't exist, creating new",
|
|
t.SrcIP.String(), t.SrcPort,
|
|
t.DstIP.String(), t.DstPort)
|
|
}
|
|
|
|
conn := &TCPConnection{
|
|
id: tcp.getID(),
|
|
tuple: &pkt.Tuple,
|
|
protocol: protocol,
|
|
tcp: tcp}
|
|
conn.tcptuple = common.TCPTupleFromIPPort(conn.tuple, conn.id)
|
|
tcp.streams.PutWithTimeout(pkt.Tuple.Hashable(), conn, timeout)
|
|
return TCPStream{conn: conn, dir: TCPDirectionOriginal}, true
|
|
}
|
|
|
|
func tcpSeqCompare(seq1, seq2 uint32) seqCompare {
|
|
i := int32(seq1 - seq2)
|
|
switch {
|
|
case i == 0:
|
|
return seqEq
|
|
case i < 0:
|
|
return seqLT
|
|
default:
|
|
return seqGT
|
|
}
|
|
}
|
|
|
|
// tcpSeqBefore reports whether seq1 is strictly before seq2, judged by the
// sign of the wrapped 32-bit difference so that wrap-around is handled.
func tcpSeqBefore(seq1 uint32, seq2 uint32) bool {
	diff := int32(seq1 - seq2)
	return diff < 0
}
|
|
|
|
// tcpSeqBeforeEq reports whether seq1 is before or equal to seq2, judged by
// the sign of the wrapped 32-bit difference so that wrap-around is handled.
func tcpSeqBeforeEq(seq1 uint32, seq2 uint32) bool {
	diff := int32(seq1 - seq2)
	return diff <= 0
}
|
|
|
|
func buildPortsMap(plugins map[protos.Protocol]protos.TCPPlugin) (map[uint16]protos.Protocol, error) {
|
|
var res = map[uint16]protos.Protocol{}
|
|
|
|
for proto, protoPlugin := range plugins {
|
|
for _, port := range protoPlugin.GetPorts() {
|
|
oldProto, exists := res[uint16(port)]
|
|
if exists {
|
|
if oldProto == proto {
|
|
continue
|
|
}
|
|
return nil, fmt.Errorf("Duplicate port (%d) exists in %s and %s protocols",
|
|
port, oldProto, proto)
|
|
}
|
|
res[uint16(port)] = proto
|
|
}
|
|
}
|
|
|
|
return res, nil
|
|
}
|
|
|
|
// Creates and returns a new Tcp.
|
|
func NewTCP(p protos.Protocols) (*TCP, error) {
|
|
isDebug = logp.IsDebug("tcp")
|
|
|
|
portMap, err := buildPortsMap(p.GetAllTCP())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tcp := &TCP{
|
|
protocols: p,
|
|
portMap: portMap,
|
|
}
|
|
tcp.streams = common.NewCacheWithRemovalListener(
|
|
protos.DefaultTransactionExpiration,
|
|
protos.DefaultTransactionHashSize,
|
|
tcp.removalListener)
|
|
|
|
tcp.streams.StartJanitor(protos.DefaultTransactionExpiration)
|
|
if isDebug {
|
|
debugf("tcp", "Port map: %v", portMap)
|
|
}
|
|
|
|
return tcp, nil
|
|
}
|
|
|
|
func (tcp *TCP) removalListener(_ common.Key, value common.Value) {
|
|
conn := value.(*TCPConnection)
|
|
mod := conn.tcp.protocols.GetTCP(conn.protocol)
|
|
if mod != nil {
|
|
awareMod, ok := mod.(protos.ExpirationAwareTCPPlugin)
|
|
if ok {
|
|
tcp.expiredConns.add(awareMod, conn)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ec *expiredConnection) notify() {
|
|
ec.mod.Expired(&ec.conn.tcptuple, ec.conn.data)
|
|
}
|
|
|
|
func (eq *expirationQueue) add(mod protos.ExpirationAwareTCPPlugin, conn *TCPConnection) {
|
|
eq.mutex.Lock()
|
|
eq.conns = append(eq.conns, expiredConnection{
|
|
mod: mod,
|
|
conn: conn,
|
|
})
|
|
eq.mutex.Unlock()
|
|
}
|
|
|
|
func (eq *expirationQueue) getExpired() (conns []expiredConnection) {
|
|
eq.mutex.Lock()
|
|
conns, eq.conns = eq.conns, nil
|
|
eq.mutex.Unlock()
|
|
return conns
|
|
}
|
|
|
|
func (eq *expirationQueue) notifyAll() {
|
|
for _, expiration := range eq.getExpired() {
|
|
expiration.notify()
|
|
}
|
|
}
|