Skip to content

Commit

Permalink
Fix an output bug
Browse files Browse the repository at this point in the history
When there is no packet data, it will panic.

Meanwhile, refactor code of checking `sock_path` of Unix socket.

And, add some comments.

Signed-off-by: Leon Hwang <hffilwlqm@gmail.com>
  • Loading branch information
Asphaltt committed Dec 24, 2023
1 parent 6d5ef76 commit f1b913c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 22 deletions.
44 changes: 26 additions & 18 deletions bpf/sockdump.c
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,9 @@ __is_kernel_ge_6_0_0(void)
static __always_inline bool
__is_str_prefix(const char *str, const char *prefix, int siz)
{
for (int i = 0; i < siz && prefix[i]; i++) {
for (int i = 0; i < siz && prefix[i]; i++)
if (str[i] != prefix[i])
return false;
}

return true;
}
Expand All @@ -82,11 +81,13 @@ __is_path_matched(__u64 *path)
__u64 *sock_path = (__u64 *) cfg->sock_path;
int i;

for (i = 0; i < UNIX_PATH_MAX / 8 && sock_path[i]; i++) {
// 1. Use __u64 to reduce iterations.
// 2. Use __is_str_prefix() to match the prefix of the path.

for (i = 0; i < UNIX_PATH_MAX / 8 && sock_path[i]; i++)
if (path[i] != sock_path[i])
return __is_str_prefix((const char *) &path[i],
(const char *) &sock_path[i], 8);
}

if (i == UNIX_PATH_MAX / 8)
return __is_str_prefix((const char *) &path[i],
Expand All @@ -102,20 +103,31 @@ match_path_of_usk(struct unix_sock *usk, __u64 *path)
__u8 one_byte = 0;
char *sock_path;

// Skip current capture if addr->len is zero.

addr = BPF_CORE_READ(usk, addr);
if (!BPF_CORE_READ(addr, len))
return false;

// 1. Use offset instead of BPF_CORE_READ() to get the address of the path.
// 2. Check if it's "@/path/to/unix.sock".

sock_path = (char *) addr + SOCK_PATH_OFFSET;
bpf_probe_read_kernel(&one_byte, 1, sock_path);
if (one_byte == 0)
bpf_probe_read_kernel_str(path, UNIX_PATH_MAX, sock_path + 1);
else
if (one_byte)
bpf_probe_read_kernel_str(path, UNIX_PATH_MAX, sock_path);
else
bpf_probe_read_kernel_str(path, UNIX_PATH_MAX, sock_path + 1);

return __is_path_matched(path);
}

static __always_inline bool
__is_sock_path_matched(struct unix_sock *usk, __u64 *path)
{
return usk && match_path_of_usk(usk, path);
}

static __always_inline void
collect_data(void *ctx, struct packet *pkt, char *buf, __u32 len)
{
Expand All @@ -124,6 +136,9 @@ collect_data(void *ctx, struct packet *pkt, char *buf, __u32 len)
pkt->flags = 0;
pkt->len = len;

// It's necessary to check the maximum size of the segment. Otherwise, the
// verifier will complain about the out-of-bound access.

n = len > seg_size ? seg_size : len;
if (n < SS_MAX_SEG_SIZE)
bpf_probe_read(&pkt->data, n, buf);
Expand All @@ -137,9 +152,9 @@ collect_data(void *ctx, struct packet *pkt, char *buf, __u32 len)
static __noinline int
__usk_sendmsg(void *ctx, struct socket *sock, struct msghdr *msg, size_t len)
{
struct unix_sock *usk, *peer;
const struct iovec *iov;
struct upid numbers[1];
struct unix_sock *usk;
struct iov_iter *iter;
struct packet *pkt;
__u64 *path, nsegs;
Expand All @@ -157,18 +172,11 @@ __usk_sendmsg(void *ctx, struct socket *sock, struct msghdr *msg, size_t len)
path = (__u64 *) pkt->path;

usk = bpf_skc_to_unix_sock(sock->sk);
if (!usk)
peer = usk ? bpf_skc_to_unix_sock(usk->peer) : NULL;
if (!__is_sock_path_matched(usk, path) &&
!__is_sock_path_matched(peer, path))
return 0;

if (!match_path_of_usk(usk, path)) {
usk = bpf_skc_to_unix_sock(usk->peer);
if (!usk)
return 0;

if (!match_path_of_usk(usk, path))
return 0;
}

pkt->pid = pid;
bpf_get_current_comm(&pkt->comm, sizeof(pkt->comm));
BPF_CORE_READ_INTO(&numbers, sock, sk, sk_peer_pid, numbers);
Expand Down
12 changes: 8 additions & 4 deletions internal/sockdump/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"encoding/binary"
"encoding/hex"
"fmt"
"log"
"os"
"time"
"unsafe"
Expand Down Expand Up @@ -85,7 +86,7 @@ func (o *Output) outputString(pkt *Packet, data []byte) {
}

if pkt.Flags != 0 {
fmt.Fprintf(o.w, "error")
fmt.Fprintln(o.w, "error")
} else {
fmt.Fprintf(o.w, "%s\n", nullTerminatedString(data))
}
Expand All @@ -98,7 +99,7 @@ func (o *Output) outputHex(pkt *Packet, data []byte) {
}

if pkt.Flags != 0 {
fmt.Fprintf(o.w, "error")
fmt.Fprintln(o.w, "error")
} else {
fmt.Fprintln(o.w, hex.Dump(data))
}
Expand All @@ -111,7 +112,7 @@ func (o *Output) outputHexString(pkt *Packet, data []byte) {
}

if pkt.Flags != 0 {
fmt.Fprintf(o.w, "error")
fmt.Fprintln(o.w, "error")
} else {
fmt.Fprintf(o.w, "%s\n", hex.EncodeToString(data))
}
Expand All @@ -137,7 +138,7 @@ func (o *Output) outputPcap(pkt *Packet, data []byte) {
Length: len(data),
InterfaceIndex: 0,
}, data); err != nil {
fmt.Fprintf(os.Stderr, "failed to write pcap packet: %v", err)
log.Printf("failed to write pcap packet: %v", err)
}
}

Expand All @@ -152,5 +153,8 @@ func nullTerminated(s []byte) []byte {

func nullTerminatedString(s []byte) string {
s = nullTerminated(s)
if len(s) == 0 {
return ""
}
return unsafe.String(&s[0], len(s))
}

0 comments on commit f1b913c

Please sign in to comment.