From f1b913c969cdaade8ff505116704140ba8f0b76f Mon Sep 17 00:00:00 2001 From: Leon Hwang Date: Sun, 24 Dec 2023 16:31:08 +0800 Subject: [PATCH] Fix an output bug When there is no packet data, it will panic. Meanwhile, refactor code of checking `sock_path` of Unix socket. And, add some comments. Signed-off-by: Leon Hwang --- bpf/sockdump.c | 44 ++++++++++++++++++++++--------------- internal/sockdump/output.go | 12 ++++++---- 2 files changed, 34 insertions(+), 22 deletions(-) diff --git a/bpf/sockdump.c b/bpf/sockdump.c index a0db612..ccb355e 100644 --- a/bpf/sockdump.c +++ b/bpf/sockdump.c @@ -68,10 +68,9 @@ __is_kernel_ge_6_0_0(void) static __always_inline bool __is_str_prefix(const char *str, const char *prefix, int siz) { - for (int i = 0; i < siz && prefix[i]; i++) { + for (int i = 0; i < siz && prefix[i]; i++) if (str[i] != prefix[i]) return false; - } return true; } @@ -82,11 +81,13 @@ __is_path_matched(__u64 *path) __u64 *sock_path = (__u64 *) cfg->sock_path; int i; - for (i = 0; i < UNIX_PATH_MAX / 8 && sock_path[i]; i++) { + // 1. Use __u64 to reduce iterations. + // 2. Use __is_str_prefix() to match the prefix of the path. + + for (i = 0; i < UNIX_PATH_MAX / 8 && sock_path[i]; i++) if (path[i] != sock_path[i]) return __is_str_prefix((const char *) &path[i], (const char *) &sock_path[i], 8); - } if (i == UNIX_PATH_MAX / 8) return __is_str_prefix((const char *) &path[i], @@ -102,20 +103,31 @@ match_path_of_usk(struct unix_sock *usk, __u64 *path) __u8 one_byte = 0; char *sock_path; + // Skip current capture if addr->len is zero. + addr = BPF_CORE_READ(usk, addr); if (!BPF_CORE_READ(addr, len)) return false; + // 1. Use offset instead of BPF_CORE_READ() to get the address of the path. + // 2. Check if it's "@/path/to/unix.sock". + sock_path = (char *) addr + SOCK_PATH_OFFSET; bpf_probe_read_kernel(&one_byte, 1, sock_path); - if (one_byte == 0) - bpf_probe_read_kernel_str(path, UNIX_PATH_MAX, sock_path + 1); - else + if (one_byte) bpf_probe_read_kernel_str(path, UNIX_PATH_MAX, sock_path); + else + bpf_probe_read_kernel_str(path, UNIX_PATH_MAX, sock_path + 1); return __is_path_matched(path); } +static __always_inline bool +__is_sock_path_matched(struct unix_sock *usk, __u64 *path) +{ + return usk && match_path_of_usk(usk, path); +} + static __always_inline void collect_data(void *ctx, struct packet *pkt, char *buf, __u32 len) { @@ -124,6 +136,9 @@ collect_data(void *ctx, struct packet *pkt, char *buf, __u32 len) pkt->flags = 0; pkt->len = len; + // It's necessary to check the maximum size of the segment. Otherwise, the + // verifier will complain about the out-of-bound access. + n = len > seg_size ? seg_size : len; if (n < SS_MAX_SEG_SIZE) bpf_probe_read(&pkt->data, n, buf); @@ -137,9 +152,9 @@ collect_data(void *ctx, struct packet *pkt, char *buf, __u32 len) static __noinline int __usk_sendmsg(void *ctx, struct socket *sock, struct msghdr *msg, size_t len) { + struct unix_sock *usk, *peer; const struct iovec *iov; struct upid numbers[1]; - struct unix_sock *usk; struct iov_iter *iter; struct packet *pkt; __u64 *path, nsegs; @@ -157,18 +172,11 @@ __usk_sendmsg(void *ctx, struct socket *sock, struct msghdr *msg, size_t len) path = (__u64 *) pkt->path; usk = bpf_skc_to_unix_sock(sock->sk); - if (!usk) + peer = usk ? bpf_skc_to_unix_sock(usk->peer) : NULL; + if (!__is_sock_path_matched(usk, path) && + !__is_sock_path_matched(peer, path)) return 0; - if (!match_path_of_usk(usk, path)) { - usk = bpf_skc_to_unix_sock(usk->peer); - if (!usk) - return 0; - - if (!match_path_of_usk(usk, path)) - return 0; - } - pkt->pid = pid; bpf_get_current_comm(&pkt->comm, sizeof(pkt->comm)); BPF_CORE_READ_INTO(&numbers, sock, sk, sk_peer_pid, numbers); diff --git a/internal/sockdump/output.go b/internal/sockdump/output.go index 2792517..77bb8cd 100644 --- a/internal/sockdump/output.go +++ b/internal/sockdump/output.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "encoding/hex" "fmt" + "log" "os" "time" "unsafe" @@ -85,7 +86,7 @@ func (o *Output) outputString(pkt *Packet, data []byte) { } if pkt.Flags != 0 { - fmt.Fprintf(o.w, "error") + fmt.Fprintln(o.w, "error") } else { fmt.Fprintf(o.w, "%s\n", nullTerminatedString(data)) } @@ -98,7 +99,7 @@ func (o *Output) outputHex(pkt *Packet, data []byte) { } if pkt.Flags != 0 { - fmt.Fprintf(o.w, "error") + fmt.Fprintln(o.w, "error") } else { fmt.Fprintln(o.w, hex.Dump(data)) } @@ -111,7 +112,7 @@ func (o *Output) outputHexString(pkt *Packet, data []byte) { } if pkt.Flags != 0 { - fmt.Fprintf(o.w, "error") + fmt.Fprintln(o.w, "error") } else { fmt.Fprintf(o.w, "%s\n", hex.EncodeToString(data)) } @@ -137,7 +138,7 @@ func (o *Output) outputPcap(pkt *Packet, data []byte) { Length: len(data), InterfaceIndex: 0, }, data); err != nil { - fmt.Fprintf(os.Stderr, "failed to write pcap packet: %v", err) + log.Printf("failed to write pcap packet: %v", err) } } @@ -152,5 +153,8 @@ func nullTerminated(s []byte) []byte { func nullTerminatedString(s []byte) string { s = nullTerminated(s) + if len(s) == 0 { + return "" + } return unsafe.String(&s[0], len(s)) }