From e52ba1f05d2125d0fa066a89bf7d6f1a07039129 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=2E=20Mert=20Y=C4=B1ld=C4=B1ran?= <mehmet@up9.com>
Date: Mon, 8 Aug 2022 03:48:19 -0700
Subject: [PATCH] Add `AF_PACKET` support (#1052)

* Add `AF_PACKET` support

* Update `.gitignore`

* Support both `libpcap` and `AF_PACKET` at the same time

* Fix linter errors

* Fix a bug that was introduced while fixing a linter error

* Revert the changes related to `MaxBufferedPages` prefixed consts

* #run_acceptance_tests

* #run_acceptance_tests

* Revert channel buffer size #run_acceptance_tests

* Revert "Revert channel buffer size #run_acceptance_tests"

This reverts commit e62c3844cd6d4808fd2266ecdbd2106ae950fc6a.

* Increase `cy.wait` from `500` to `1000` #run_acceptance_tests

* Fix the `pcapHandle` handle

* Revert "Increase `cy.wait` from `500` to `1000` #run_acceptance_tests"

This reverts commit 938c550e7226578fe2cd290c2531edf4330a28ed.

* #run_acceptance_tests

* Handle the merge conflicts

* Add `AF_XDP` support

* Implement `Close()` of `AF_XDP` and fix linter errors

* Fix `NewIPProtoProgram` function and internet protocol number

* Pipe the packet stream from every network interface using `*pcapgo.NgReader` and `*pcapgo.NgWriter`

Implement `SetDecoder` and `SetBPF` methods.

* Fix `NewNgReader` call

* Implement `Stats` method

* Rebroadcast to the XDP socket

* Add `-packet-capture` flag and make `AF_PACKET`, `AF_XDP` optional

* #run_acceptance_tests

* Fix `newAfXdpHandle` method

* #run_acceptance_tests

* Update tap/xdp/ipproto.c

Co-authored-by: Nimrod Gilboa Markevich <59927337+nimrod-up9@users.noreply.github.com>

* Update tap/xdp/ipproto.c

Co-authored-by: Nimrod Gilboa Markevich <59927337+nimrod-up9@users.noreply.github.com>

* Update tap/xdp/ipproto.c

Co-authored-by: Nimrod Gilboa Markevich <59927337+nimrod-up9@users.noreply.github.com>

* Fix several issues

* Update tap/xdp/ipproto.c

Co-authored-by: Nimrod Gilboa Markevich <59927337+nimrod-up9@users.noreply.github.com>

* Fix `ipproto.c`

* Remove `AF_XDP`

* Comment on frameSize

Co-authored-by: Nimrod Gilboa Markevich <59927337+nimrod-up9@users.noreply.github.com>
---
 .gitignore                            |   1 +
 Dockerfile                            |   3 +-
 cli/config/configStructs/tapConfig.go |   1 +
 tap/go.mod                            |   2 +-
 tap/passive_tapper.go                 |  19 ++--
 tap/source/handle_af_packet.go        | 152 ++++++++++++++++++++++++++
 tap/source/handle_pcap.go             |  97 ++++++++++++++++
 tap/source/netns_packet_source.go     |   8 +-
 tap/source/packet_source_manager.go   |  21 ++--
 tap/source/tcp_packet_source.go       | 111 ++++++++++---------
 10 files changed, 335 insertions(+), 80 deletions(-)
 create mode 100644 tap/source/handle_af_packet.go
 create mode 100644 tap/source/handle_pcap.go

diff --git a/.gitignore b/.gitignore
index 95357b3c5..55adc15fd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -53,6 +53,7 @@ tap/extensions/*/expect
 **/node_modules/**
 **/dist/**
 *.editorconfig
+ui/up9-mizu-common-0.0.0.tgz
 
 # Ignore *.log files
 *.log
diff --git a/Dockerfile b/Dockerfile
index 9c7beb205..fcc95f2cd 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -104,8 +104,7 @@ ARG BUILD_TIMESTAMP
 ARG VER=0.0
 
 WORKDIR /app/tap/tlstapper
-
-RUN rm tlstapper_bpf*
+RUN rm *_bpfel_*
 RUN GOARCH=${BUILDARCH} go generate tls_tapper.go
 
 WORKDIR /app/agent-build
diff --git a/cli/config/configStructs/tapConfig.go b/cli/config/configStructs/tapConfig.go
index 79615e0ec..edf10c510 100644
--- a/cli/config/configStructs/tapConfig.go
+++ b/cli/config/configStructs/tapConfig.go
@@ -51,6 +51,7 @@ type TapConfig struct {
 	TapperResources       shared.Resources `yaml:"tapper-resources"`
 	ServiceMesh           bool             `yaml:"service-mesh" default:"false"`
 	Tls                   bool             `yaml:"tls" default:"false"`
+	PacketCapture         string           `yaml:"packet-capture" default:"libpcap"`
 	Profiler              bool             `yaml:"profiler" default:"false"`
 	MaxLiveStreams        int              `yaml:"max-live-streams" default:"500"`
 }
diff --git a/tap/go.mod b/tap/go.mod
index 4407610c2..fe9aecd57 100644
--- a/tap/go.mod
+++ b/tap/go.mod
@@ -16,6 +16,7 @@ require (
 	github.com/up9inc/mizu/tap/api v0.0.0
 	github.com/up9inc/mizu/tap/dbgctl v0.0.0
 	github.com/vishvananda/netns v0.0.0-20211101163701-50045581ed74
+	golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd
 	k8s.io/api v0.23.3
 )
 
@@ -33,7 +34,6 @@ require (
 	github.com/tklauser/go-sysconf v0.3.10 // indirect
 	github.com/tklauser/numcpus v0.4.0 // indirect
 	github.com/yusufpapurcu/wmi v1.2.2 // indirect
-	golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd // indirect
 	golang.org/x/sys v0.0.0-20220207234003-57398862261d // indirect
 	golang.org/x/text v0.3.7 // indirect
 	gopkg.in/inf.v0 v0.9.1 // indirect
diff --git a/tap/passive_tapper.go b/tap/passive_tapper.go
index e27925736..1add57fd5 100644
--- a/tap/passive_tapper.go
+++ b/tap/passive_tapper.go
@@ -51,11 +51,13 @@ var maxLiveStreams = flag.Int("max-live-streams", 500, "Maximum live streams to
 var iface = flag.String("i", "en0", "Interface to read packets from")
 var fname = flag.String("r", "", "Filename to read from, overrides -i")
 var snaplen = flag.Int("s", 65536, "Snap length (number of bytes max to read per packet")
-var tstype = flag.String("timestamp_type", "", "Type of timestamps to use")
+var targetSizeMb = flag.Int("target-size-mb", 8, "AF_PACKET target block size (MB)")
+var tstype = flag.String("timestamp-type", "", "Type of timestamps to use")
 var promisc = flag.Bool("promisc", true, "Set promiscuous mode")
 var staleTimeoutSeconds = flag.Int("staletimout", 120, "Max time in seconds to keep connections which don't transmit data")
 var servicemesh = flag.Bool("servicemesh", false, "Record decrypted traffic if the cluster is configured with a service mesh and with mtls")
 var tls = flag.Bool("tls", false, "Enable TLS tapper")
+var packetCapture = flag.String("packet-capture", "libpcap", "Packet capture backend. Possible values: libpcap, af_packet")
 
 var memprofile = flag.String("memprofile", "", "Write memory profile")
 
@@ -210,16 +212,17 @@ func initializePacketSources() error {
 	}
 
 	behaviour := source.TcpPacketSourceBehaviour{
-		SnapLength:  *snaplen,
-		Promisc:     *promisc,
-		Tstype:      *tstype,
-		DecoderName: *decoder,
-		Lazy:        *lazy,
-		BpfFilter:   bpffilter,
+		SnapLength:   *snaplen,
+		TargetSizeMb: *targetSizeMb,
+		Promisc:      *promisc,
+		Tstype:       *tstype,
+		DecoderName:  *decoder,
+		Lazy:         *lazy,
+		BpfFilter:    bpffilter,
 	}
 
 	var err error
-	packetSourceManager, err = source.NewPacketSourceManager(*procfs, *fname, *iface, *servicemesh, tapTargets, behaviour, !*nodefrag, mainPacketInputChan)
+	packetSourceManager, err = source.NewPacketSourceManager(*procfs, *fname, *iface, *servicemesh, tapTargets, behaviour, !*nodefrag, *packetCapture, mainPacketInputChan)
 	return err
 }
 
diff --git a/tap/source/handle_af_packet.go b/tap/source/handle_af_packet.go
new file mode 100644
index 000000000..953cc7428
--- /dev/null
+++ b/tap/source/handle_af_packet.go
@@ -0,0 +1,152 @@
+package source
+
+import (
+	"fmt"
+	"os"
+	"time"
+
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/afpacket"
+	"github.com/google/gopacket/layers"
+	"github.com/google/gopacket/pcap"
+	"golang.org/x/net/bpf"
+)
+
+type afPacketHandle struct {
+	source        gopacket.ZeroCopyPacketDataSource
+	capture       *afpacket.TPacket
+	decoder       gopacket.Decoder
+	decodeOptions gopacket.DecodeOptions
+}
+
+func (h *afPacketHandle) NextPacket() (packet gopacket.Packet, err error) {
+	var data []byte
+	var ci gopacket.CaptureInfo
+	data, ci, err = h.source.ZeroCopyReadPacketData() // NOTE(review): slice is presumably valid only until the next read — confirm consumers don't retain it
+	if err != nil {
+		return
+	}
+
+	packet = gopacket.NewPacket(data, h.decoder, h.decodeOptions)
+	m := packet.Metadata()
+	m.CaptureInfo = ci
+	m.Truncated = m.Truncated || ci.CaptureLength < ci.Length // captured fewer bytes than the packet's original length
+	return
+}
+
+func (h *afPacketHandle) SetDecoder(decoder gopacket.Decoder, lazy bool, noCopy bool) {
+	h.decoder = decoder
+	h.decodeOptions = gopacket.DecodeOptions{Lazy: lazy, NoCopy: noCopy}
+}
+
+func (h *afPacketHandle) SetBPF(expr string) (err error) {
+	var pcapBPF []pcap.BPFInstruction
+	pcapBPF, err = pcap.CompileBPFFilter(layers.LinkTypeEthernet, 65535, expr)
+	if err != nil {
+		return
+	}
+	bpfIns := []bpf.RawInstruction{}
+	for _, ins := range pcapBPF {
+		bpfIns2 := bpf.RawInstruction{
+			Op: ins.Code,
+			Jt: ins.Jt,
+			Jf: ins.Jf,
+			K:  ins.K,
+		}
+		bpfIns = append(bpfIns, bpfIns2)
+	}
+	err = h.capture.SetBPF(bpfIns)
+	return
+}
+
+func (h *afPacketHandle) LinkType() layers.LinkType {
+	return layers.LinkTypeEthernet
+}
+
+func (h *afPacketHandle) Stats() (packetsReceived uint, packetsDropped uint, err error) {
+	var stats afpacket.SocketStatsV3
+	_, stats, err = h.capture.SocketStats()
+	packetsReceived = stats.Packets()
+	packetsDropped = stats.Drops()
+	return
+}
+
+func (h *afPacketHandle) Close() (err error) {
+	h.capture.Close()
+	return
+}
+
+func newAfpacketHandle(device string, targetSizeMb int, snaplen int) (handle Handle, err error) {
+	snaplen -= 1
+	if snaplen < 0 {
+		snaplen = 0
+	}
+	szFrame, szBlock, numBlocks, err := afpacketComputeSize(targetSizeMb, snaplen, os.Getpagesize())
+	if err != nil {
+		return
+	}
+	var capture *afpacket.TPacket
+	capture, err = newAfpacket(device, szFrame, szBlock, numBlocks, false, pcap.BlockForever)
+	if err != nil {
+		return
+	}
+	handle = &afPacketHandle{
+		capture: capture,
+		source:  gopacket.ZeroCopyPacketDataSource(capture),
+	}
+	return
+}
+
+func newAfpacket(device string, snaplen int, block_size int, num_blocks int,
+	useVLAN bool, timeout time.Duration) (*afpacket.TPacket, error) {
+
+	var h *afpacket.TPacket
+	var err error
+
+	if device == "any" {
+		h, err = afpacket.NewTPacket(
+			afpacket.OptFrameSize(snaplen),
+			afpacket.OptBlockSize(block_size),
+			afpacket.OptNumBlocks(num_blocks),
+			afpacket.OptAddVLANHeader(useVLAN),
+			afpacket.OptPollTimeout(timeout),
+			afpacket.SocketRaw,
+			afpacket.TPacketVersion3)
+	} else {
+		h, err = afpacket.NewTPacket(
+			afpacket.OptInterface(device),
+			afpacket.OptFrameSize(snaplen),
+			afpacket.OptBlockSize(block_size),
+			afpacket.OptNumBlocks(num_blocks),
+			afpacket.OptAddVLANHeader(useVLAN),
+			afpacket.OptPollTimeout(timeout),
+			afpacket.SocketRaw,
+			afpacket.TPacketVersion3)
+	}
+	return h, err
+}
+
+// afpacketComputeSize computes frameSize, blockSize and numBlocks such that the
+// allocated mmap buffer is close to but not larger than targetSizeMb.
+// The restriction is that blockSize must be divisible by both the
+// frame size and page size. A non-positive snaplen or pageSize is an error.
+func afpacketComputeSize(targetSizeMb int, snaplen int, pageSize int) (
+	frameSize int, blockSize int, numBlocks int, err error) {
+	if snaplen <= 0 || pageSize <= 0 { // would otherwise divide by zero below
+		return 0, 0, 0, fmt.Errorf("invalid snaplen %d or page size %d", snaplen, pageSize)
+	}
+	// frameSize calculation was taken from gopacket's afpacket.go
+	if snaplen < pageSize {
+		frameSize = pageSize / (pageSize / snaplen)
+	} else {
+		frameSize = (snaplen/pageSize + 1) * pageSize
+	}
+
+	// 128 is the default from the gopacket library so just use that
+	blockSize = frameSize * 128
+	numBlocks = (targetSizeMb * 1024 * 1024) / blockSize
+	if numBlocks <= 0 { // also rejects a negative targetSizeMb
+		return 0, 0, 0, fmt.Errorf("Interface buffersize is too small")
+	}
+	return frameSize, blockSize, numBlocks, nil
+}
diff --git a/tap/source/handle_pcap.go b/tap/source/handle_pcap.go
new file mode 100644
index 000000000..7f4ce40f3
--- /dev/null
+++ b/tap/source/handle_pcap.go
@@ -0,0 +1,97 @@
+package source
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/google/gopacket"
+	"github.com/google/gopacket/layers"
+	"github.com/google/gopacket/pcap"
+)
+
+type pcapHandle struct {
+	source  *gopacket.PacketSource
+	capture *pcap.Handle
+}
+
+func (h *pcapHandle) NextPacket() (packet gopacket.Packet, err error) {
+	return h.source.NextPacket()
+}
+func (h *pcapHandle) SetDecoder(decoder gopacket.Decoder, lazy bool, noCopy bool) {
+	h.source = gopacket.NewPacketSource(h.capture, decoder)
+	h.source.Lazy = lazy
+	h.source.NoCopy = noCopy
+}
+
+func (h *pcapHandle) SetBPF(expr string) (err error) {
+	return h.capture.SetBPFFilter(expr)
+}
+
+func (h *pcapHandle) LinkType() layers.LinkType {
+	return h.capture.LinkType()
+}
+
+func (h *pcapHandle) Stats() (packetsReceived uint, packetsDropped uint, err error) {
+	stats, statsErr := h.capture.Stats()
+	if statsErr != nil { // stats may be nil on failure, so it must not be dereferenced
+		return 0, 0, statsErr
+	}
+	return uint(stats.PacketsReceived), uint(stats.PacketsDropped), nil
+}
+
+func (h *pcapHandle) Close() (err error) {
+	h.capture.Close()
+	return
+}
+
+func newPcapHandle(filename string, device string, snaplen int, promisc bool, tstype string) (handle Handle, err error) {
+	var capture *pcap.Handle
+
+	if filename != "" {
+		if capture, err = pcap.OpenOffline(filename); err != nil {
+			err = fmt.Errorf("PCAP OpenOffline error: %v", err)
+			return
+		}
+	} else {
+		// This is a little complicated because we want to allow all possible options
+		// for creating the packet capture handle... instead of all this you can
+		// just call pcap.OpenLive if you want a simple handle.
+		var inactive *pcap.InactiveHandle
+		inactive, err = pcap.NewInactiveHandle(device)
+		if err != nil {
+			err = fmt.Errorf("could not create: %v", err)
+			return
+		}
+		defer inactive.CleanUp()
+		if err = inactive.SetSnapLen(snaplen); err != nil {
+			err = fmt.Errorf("could not set snap length: %v", err)
+			return
+		} else if err = inactive.SetPromisc(promisc); err != nil {
+			err = fmt.Errorf("could not set promisc mode: %v", err)
+			return
+		} else if err = inactive.SetTimeout(time.Second); err != nil {
+			err = fmt.Errorf("could not set timeout: %v", err)
+			return
+		}
+		if tstype != "" {
+			var t pcap.TimestampSource
+			if t, err = pcap.TimestampSourceFromString(tstype); err != nil {
+				err = fmt.Errorf("supported timestamp types: %v", inactive.SupportedTimestamps())
+				return
+			} else if err = inactive.SetTimestampSource(t); err != nil {
+				err = fmt.Errorf("supported timestamp types: %v", inactive.SupportedTimestamps())
+				return
+			}
+		}
+		if capture, err = inactive.Activate(); err != nil {
+			err = fmt.Errorf("PCAP Activate error: %v", err)
+			return
+		}
+	}
+
+	handle = &pcapHandle{
+		capture: capture,
+	}
+
+	return
+}
diff --git a/tap/source/netns_packet_source.go b/tap/source/netns_packet_source.go
index 41f9968d0..c290e20d8 100644
--- a/tap/source/netns_packet_source.go
+++ b/tap/source/netns_packet_source.go
@@ -9,7 +9,7 @@ import (
 	"github.com/vishvananda/netns"
 )
 
-func newNetnsPacketSource(procfs string, pid string, interfaceName string,
+func newNetnsPacketSource(procfs string, pid string, interfaceName string, packetCapture string,
 	behaviour TcpPacketSourceBehaviour, origin api.Capture) (*tcpPacketSource, error) {
 	nsh, err := netns.GetFromPath(fmt.Sprintf("%s/%s/ns/net", procfs, pid))
 
@@ -18,7 +18,7 @@ func newNetnsPacketSource(procfs string, pid string, interfaceName string,
 		return nil, err
 	}
 
-	src, err := newPacketSourceFromNetnsHandle(pid, nsh, interfaceName, behaviour, origin)
+	src, err := newPacketSourceFromNetnsHandle(pid, nsh, interfaceName, packetCapture, behaviour, origin)
 
 	if err != nil {
 		logger.Log.Errorf("Error starting netns packet source for %s - %w", pid, err)
@@ -28,7 +28,7 @@ func newNetnsPacketSource(procfs string, pid string, interfaceName string,
 	return src, nil
 }
 
-func newPacketSourceFromNetnsHandle(pid string, nsh netns.NsHandle, interfaceName string,
+func newPacketSourceFromNetnsHandle(pid string, nsh netns.NsHandle, interfaceName string, packetCapture string,
 	behaviour TcpPacketSourceBehaviour, origin api.Capture) (*tcpPacketSource, error) {
 
 	done := make(chan *tcpPacketSource)
@@ -58,7 +58,7 @@ func newPacketSourceFromNetnsHandle(pid string, nsh netns.NsHandle, interfaceNam
 		}
 
 		name := fmt.Sprintf("netns-%s-%s", pid, interfaceName)
-		src, err := newTcpPacketSource(name, "", interfaceName, behaviour, origin)
+		src, err := newTcpPacketSource(name, "", interfaceName, packetCapture, behaviour, origin)
 
 		if err != nil {
 			logger.Log.Errorf("Error listening to PID %s - %w", pid, err)
diff --git a/tap/source/packet_source_manager.go b/tap/source/packet_source_manager.go
index 28ae85879..2274aec04 100644
--- a/tap/source/packet_source_manager.go
+++ b/tap/source/packet_source_manager.go
@@ -16,6 +16,7 @@ type PacketSourceManagerConfig struct {
 	mtls          bool
 	procfs        string
 	interfaceName string
+	packetCapture string
 	behaviour     TcpPacketSourceBehaviour
 }
 
@@ -25,8 +26,9 @@ type PacketSourceManager struct {
 }
 
 func NewPacketSourceManager(procfs string, filename string, interfaceName string,
-	mtls bool, pods []v1.Pod, behaviour TcpPacketSourceBehaviour, ipdefrag bool, packets chan<- TcpPacketInfo) (*PacketSourceManager, error) {
-	hostSource, err := newHostPacketSource(filename, interfaceName, behaviour)
+	mtls bool, pods []v1.Pod, behaviour TcpPacketSourceBehaviour, ipdefrag bool,
+	packetCapture string, packets chan<- TcpPacketInfo) (*PacketSourceManager, error) {
+	hostSource, err := newHostPacketSource(filename, interfaceName, packetCapture, behaviour)
 	if err != nil {
 		return nil, err
 	}
@@ -41,6 +43,7 @@ func NewPacketSourceManager(procfs string, filename string, interfaceName string
 		mtls:          mtls,
 		procfs:        procfs,
 		interfaceName: interfaceName,
+		packetCapture: packetCapture,
 		behaviour:     behaviour,
 	}
 
@@ -48,7 +51,7 @@ func NewPacketSourceManager(procfs string, filename string, interfaceName string
 	return sourceManager, nil
 }
 
-func newHostPacketSource(filename string, interfaceName string,
+func newHostPacketSource(filename string, interfaceName string, packetCapture string,
 	behaviour TcpPacketSourceBehaviour) (*tcpPacketSource, error) {
 	var name string
 	if filename == "" {
@@ -57,7 +60,7 @@ func newHostPacketSource(filename string, interfaceName string,
 		name = fmt.Sprintf("file-%s", filename)
 	}
 
-	source, err := newTcpPacketSource(name, filename, interfaceName, behaviour, api.Pcap)
+	source, err := newTcpPacketSource(name, filename, interfaceName, packetCapture, behaviour, api.Pcap)
 	if err != nil {
 		return nil, err
 	}
@@ -67,14 +70,14 @@ func newHostPacketSource(filename string, interfaceName string,
 
 func (m *PacketSourceManager) UpdatePods(pods []v1.Pod, ipdefrag bool, packets chan<- TcpPacketInfo) {
 	if m.config.mtls {
-		m.updateMtlsPods(m.config.procfs, pods, m.config.interfaceName, m.config.behaviour, ipdefrag, packets)
+		m.updateMtlsPods(m.config.procfs, pods, m.config.interfaceName, m.config.packetCapture, m.config.behaviour, ipdefrag, packets)
 	}
 
 	m.setBPFFilter(pods)
 }
 
 func (m *PacketSourceManager) updateMtlsPods(procfs string, pods []v1.Pod,
-	interfaceName string, behaviour TcpPacketSourceBehaviour, ipdefrag bool, packets chan<- TcpPacketInfo) {
+	interfaceName string, packetCapture string, behaviour TcpPacketSourceBehaviour, ipdefrag bool, packets chan<- TcpPacketInfo) {
 
 	relevantPids := m.getRelevantPids(procfs, pods)
 	logger.Log.Infof("Updating mtls pods (new: %v) (current: %v)", relevantPids, m.sources)
@@ -88,7 +91,7 @@ func (m *PacketSourceManager) updateMtlsPods(procfs string, pods []v1.Pod,
 
 	for pid, origin := range relevantPids {
 		if _, ok := m.sources[pid]; !ok {
-			source, err := newNetnsPacketSource(procfs, pid, interfaceName, behaviour, origin)
+			source, err := newNetnsPacketSource(procfs, pid, interfaceName, packetCapture, behaviour, origin)
 
 			if err == nil {
 				go source.readPackets(ipdefrag, packets)
@@ -165,12 +168,12 @@ func (m *PacketSourceManager) Stats() string {
 	result := ""
 
 	for _, source := range m.sources {
-		stats, err := source.Stats()
+		packetsReceived, packetsDropped, err := source.Stats()
 
 		if err != nil {
 			result = result + fmt.Sprintf("[%s: err:%s]", source.String(), err)
 		} else {
-			result = result + fmt.Sprintf("[%s: rec: %d dropped: %d]", source.String(), stats.PacketsReceived, stats.PacketsDropped)
+			result = result + fmt.Sprintf("[%s: rec: %d dropped: %d]", source.String(), packetsReceived, packetsDropped)
 		}
 	}
 
diff --git a/tap/source/tcp_packet_source.go b/tap/source/tcp_packet_source.go
index 70a41bc40..d24b53354 100644
--- a/tap/source/tcp_packet_source.go
+++ b/tap/source/tcp_packet_source.go
@@ -3,21 +3,27 @@ package source
 import (
 	"fmt"
 	"io"
-	"time"
 
 	"github.com/google/gopacket"
 	"github.com/google/gopacket/ip4defrag"
 	"github.com/google/gopacket/layers"
-	"github.com/google/gopacket/pcap"
 	"github.com/up9inc/mizu/logger"
 	"github.com/up9inc/mizu/tap/api"
 	"github.com/up9inc/mizu/tap/dbgctl"
 	"github.com/up9inc/mizu/tap/diagnose"
 )
 
+type Handle interface {
+	NextPacket() (packet gopacket.Packet, err error)
+	SetDecoder(decoder gopacket.Decoder, lazy bool, noCopy bool)
+	SetBPF(expr string) (err error)
+	LinkType() layers.LinkType
+	Stats() (packetsReceived uint, packetsDropped uint, err error)
+	Close() (err error)
+}
+
 type tcpPacketSource struct {
-	source    *gopacket.PacketSource
-	handle    *pcap.Handle
+	Handle    Handle
 	defragger *ip4defrag.IPv4Defragmenter
 	Behaviour *TcpPacketSourceBehaviour
 	name      string
@@ -25,12 +31,13 @@ type tcpPacketSource struct {
 }
 
 type TcpPacketSourceBehaviour struct {
-	SnapLength  int
-	Promisc     bool
-	Tstype      string
-	DecoderName string
-	Lazy        bool
-	BpfFilter   string
+	SnapLength   int
+	TargetSizeMb int
+	Promisc      bool
+	Tstype       string
+	DecoderName  string
+	Lazy         bool
+	BpfFilter    string
 }
 
 type TcpPacketInfo struct {
@@ -38,7 +45,7 @@ type TcpPacketInfo struct {
 	Source *tcpPacketSource
 }
 
-func newTcpPacketSource(name, filename string, interfaceName string,
+func newTcpPacketSource(name, filename string, interfaceName string, packetCapture string,
 	behaviour TcpPacketSourceBehaviour, origin api.Capture) (*tcpPacketSource, error) {
 	var err error
 
@@ -49,55 +56,47 @@ func newTcpPacketSource(name, filename string, interfaceName string,
 		Origin:    origin,
 	}
 
-	if filename != "" {
-		if result.handle, err = pcap.OpenOffline(filename); err != nil {
-			return result, fmt.Errorf("PCAP OpenOffline error: %v", err)
-		}
-	} else {
-		// This is a little complicated because we want to allow all possible options
-		// for creating the packet capture handle... instead of all this you can
-		// just call pcap.OpenLive if you want a simple handle.
-		inactive, err := pcap.NewInactiveHandle(interfaceName)
+	switch packetCapture {
+	case "af_packet":
+		result.Handle, err = newAfpacketHandle(
+			interfaceName,
+			behaviour.TargetSizeMb,
+			behaviour.SnapLength,
+		)
 		if err != nil {
-			return result, fmt.Errorf("could not create: %v", err)
+			return nil, err
 		}
-		defer inactive.CleanUp()
-		if err = inactive.SetSnapLen(behaviour.SnapLength); err != nil {
-			return result, fmt.Errorf("could not set snap length: %v", err)
-		} else if err = inactive.SetPromisc(behaviour.Promisc); err != nil {
-			return result, fmt.Errorf("could not set promisc mode: %v", err)
-		} else if err = inactive.SetTimeout(time.Second); err != nil {
-			return result, fmt.Errorf("could not set timeout: %v", err)
-		}
-		if behaviour.Tstype != "" {
-			if t, err := pcap.TimestampSourceFromString(behaviour.Tstype); err != nil {
-				return result, fmt.Errorf("supported timestamp types: %v", inactive.SupportedTimestamps())
-			} else if err := inactive.SetTimestampSource(t); err != nil {
-				return result, fmt.Errorf("supported timestamp types: %v", inactive.SupportedTimestamps())
-			}
-		}
-		if result.handle, err = inactive.Activate(); err != nil {
-			return result, fmt.Errorf("PCAP Activate error: %v", err)
-		}
-	}
-	if behaviour.BpfFilter != "" {
-		logger.Log.Infof("Using BPF filter %q", behaviour.BpfFilter)
-		if err = result.handle.SetBPFFilter(behaviour.BpfFilter); err != nil {
-			return nil, fmt.Errorf("BPF filter error: %v", err)
+		logger.Log.Infof("Using AF_PACKET socket as the capture source")
+	default:
+		result.Handle, err = newPcapHandle(
+			filename,
+			interfaceName,
+			behaviour.SnapLength,
+			behaviour.Promisc,
+			behaviour.Tstype,
+		)
+		if err != nil {
+			return nil, err
 		}
+		logger.Log.Infof("Using libpcap as the capture source")
 	}
 
-	var dec gopacket.Decoder
+	var decoder gopacket.Decoder
 	var ok bool
 	if behaviour.DecoderName == "" {
-		behaviour.DecoderName = result.handle.LinkType().String()
+		behaviour.DecoderName = result.Handle.LinkType().String()
 	}
-	if dec, ok = gopacket.DecodersByLayerName[behaviour.DecoderName]; !ok {
+	if decoder, ok = gopacket.DecodersByLayerName[behaviour.DecoderName]; !ok {
 		return nil, fmt.Errorf("no decoder named %v", behaviour.DecoderName)
 	}
-	result.source = gopacket.NewPacketSource(result.handle, dec)
-	result.source.Lazy = behaviour.Lazy
-	result.source.NoCopy = true
+	result.Handle.SetDecoder(decoder, behaviour.Lazy, true)
+
+	if behaviour.BpfFilter != "" {
+		logger.Log.Infof("Using BPF filter %q", behaviour.BpfFilter)
+		if err = result.setBPFFilter(behaviour.BpfFilter); err != nil {
+			return nil, fmt.Errorf("BPF filter error: %v", err)
+		}
+	}
 
 	return result, nil
 }
@@ -107,17 +106,17 @@ func (source *tcpPacketSource) String() string {
 }
 
 func (source *tcpPacketSource) setBPFFilter(expr string) (err error) {
-	return source.handle.SetBPFFilter(expr)
+	return source.Handle.SetBPF(expr)
 }
 
 func (source *tcpPacketSource) close() {
-	if source.handle != nil {
-		source.handle.Close()
+	if source.Handle != nil {
+		source.Handle.Close()
 	}
 }
 
-func (source *tcpPacketSource) Stats() (stat *pcap.Stats, err error) {
-	return source.handle.Stats()
+func (source *tcpPacketSource) Stats() (packetsReceived uint, packetsDropped uint, err error) {
+	return source.Handle.Stats()
 }
 
 func (source *tcpPacketSource) readPackets(ipdefrag bool, packets chan<- TcpPacketInfo) {
@@ -127,7 +126,7 @@ func (source *tcpPacketSource) readPackets(ipdefrag bool, packets chan<- TcpPack
 	logger.Log.Infof("Start reading packets from %v", source.name)
 
 	for {
-		packet, err := source.source.NextPacket()
+		packet, err := source.Handle.NextPacket()
 
 		if err == io.EOF {
 			logger.Log.Infof("Got EOF while reading packets from %v", source.name)
-- 
GitLab