Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 67 additions & 8 deletions audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ const (
// Netlink groups.
const (
	// NetlinkGroupNone is group 0, which is not used.
	NetlinkGroupNone = iota
	// NetlinkGroupReadLog is the "best effort" read only socket, defined in
	// the kernel as AUDIT_NLGRP_READLOG.
	NetlinkGroupReadLog
)

// WaitMode is a flag to control the behavior of methods that abstract
Expand Down Expand Up @@ -427,16 +427,11 @@ func (c *AuditClient) Receive(nonBlocking bool) (*RawAuditMessage, error) {
// become no-ops.
func (c *AuditClient) Close() error {
var err error

// Only unregister and close the socket once.
c.closeOnce.Do(func() {
if c.clearPIDOnClose {
// Unregister from the kernel for a clean exit.
status := AuditStatus{
Mask: AuditStatusPID,
PID: 0,
}
err = c.set(status, NoWait)
err = c.closeAndUnsetPid()
}

err = errors.Join(err, c.Netlink.Close())
Expand Down Expand Up @@ -505,6 +500,70 @@ func (c *AuditClient) getReply(seq uint32) (*syscall.NetlinkMessage, error) {
return &msg, nil
}

// closeAndUnsetPid asks the audit subsystem to clear the registered audit PID
// by sending an AuditSet message with Mask=AuditStatusPID and PID=0. Despite
// its name it does not close the netlink socket; the caller is responsible
// for that.
//
// This is a sort of isolated refactor, meant to deal with the deadlocks that
// can happen when we're not careful with blocking operations throughout a lot
// of this code. The message is sent with SendNoWait and no reply is awaited;
// the auditd code (used as the reference implementation) doesn't wait for a
// response when unsetting the audit pid either.
//
// It returns nil once a send succeeds, the first unexpected send or receive
// error as-is, or an error wrapping the last send error when every attempt
// was interrupted or would have blocked.
func (c *AuditClient) closeAndUnsetPid() error {
	const (
		// maxSendAttempts is largely arbitrary, and provides a buffer for
		// either transient errors (EINTR) or retries after draining.
		maxSendAttempts = 5
		// maxDrainReads bounds how many messages are discarded per drain pass
		// so that, ideally, there is enough room to attempt the send again.
		maxDrainReads = 10000
	)

	msg := syscall.NetlinkMessage{
		Header: syscall.NlMsghdr{
			Type:  AuditSet,
			Flags: syscall.NLM_F_REQUEST,
		},
		Data: AuditStatus{
			Mask: AuditStatusPID,
			PID:  0,
		}.toWireFormat(),
	}

	// If our request to unset the PID would block, then try to drain events
	// from the netlink socket, resend, try again.
	// In netlink, EAGAIN usually indicates our read buffer is full.
	var lastErr error
sendLoop:
	for attempt := 0; attempt < maxSendAttempts; attempt++ {
		_, sendErr := c.Netlink.SendNoWait(msg)
		switch {
		case sendErr == nil:
			return nil
		case errors.Is(sendErr, syscall.EINTR):
			// Got a transient interrupt, try again.
			lastErr = sendErr
			continue
		case errors.Is(sendErr, syscall.EAGAIN):
			// Send would block: discard pending messages, then resend.
			lastErr = sendErr
			for reads := 0; reads < maxDrainReads; reads++ {
				// noParse throws the payload away; only the error matters.
				// NOTE(review): the first argument is presumably the
				// non-blocking flag - confirm against NetlinkReceiver.
				_, recvErr := c.Netlink.Receive(true, noParse)
				switch {
				case recvErr == nil,
					errors.Is(recvErr, syscall.EINTR),
					errors.Is(recvErr, syscall.ENOBUFS):
					// Keep receiving, try to read more data.
					continue
				case errors.Is(recvErr, syscall.EAGAIN):
					// Receive would block, try to send again.
					continue sendLoop
				default:
					// If receive returns any other error, just return that.
					return recvErr
				}
			}
			// Drain budget exhausted without the receive side reporting
			// EAGAIN; fall through to the next send attempt.
		default:
			// If send returns any other error, just return that.
			return sendErr
		}
	}

	// We may not want to treat this as a hard error: it's not a massive
	// problem if this fails, since the kernel will unset the PID if it can't
	// communicate with the process, so this is largely for neatness.
	return fmt.Errorf("could not unset pid from audit after retries: %w", lastErr)
}

// noParse is a parser stub for closeAndUnsetPid: it ignores the received
// bytes and reports no messages and no error, so drained data is discarded.
func noParse(_ []byte) ([]syscall.NetlinkMessage, error) {
	var none []syscall.NetlinkMessage
	return none, nil
}

func (c *AuditClient) set(status AuditStatus, mode WaitMode) error {
msg := syscall.NetlinkMessage{
Header: syscall.NlMsghdr{
Expand Down Expand Up @@ -560,7 +619,7 @@ func parseNetlinkAuditMessage(buf []byte) ([]syscall.NetlinkMessage, error) {
// https://github.com/linux-audit/audit-kernel/blob/v4.7/include/uapi/linux/audit.h#L318-L325
type AuditStatusMask uint32

// Mask types for AuditStatus. Originally defined in the kernel's audit.h.
const (
AuditStatusEnabled AuditStatusMask = 1 << iota
AuditStatusFailure
Expand Down
104 changes: 104 additions & 0 deletions audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@ import (
"io"
"os"
"runtime"
"slices"
"sync"
"syscall"
"testing"
"testing/quick"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/go-libaudit/v2/rule"
"github.com/elastic/go-libaudit/v2/rule/flags"
Expand All @@ -55,6 +58,107 @@ var (
// -a always,exit -S open,truncate -F dir=/etc -F success=0
const testRule = `BAAAAAIAAAACAAAABAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAGsAAABoAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAEAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAQAAAAvZXRj`

// TestNetlinkIface is a mock netlink client for testing close behavior. Its
// results are scripted: every Send/SendNoWait call consumes the next entry of
// sendStack and every Receive call consumes the next entry of recvStack, and
// the consumed entry becomes that call's error.
type TestNetlinkIface struct {
	// recvStack holds the errors for successive Receive calls. Despite the
	// name it is consumed first-in, first-out.
	recvStack []error
	// sendStack holds the errors for successive Send and SendNoWait calls,
	// which share this one queue. It is also consumed first-in, first-out.
	sendStack []error
}

// Close is a no-op that always succeeds.
func (*TestNetlinkIface) Close() error {
	return nil
}

// Send ignores the message and returns sequence number 0 together with the
// next scripted send error. Once sendStack is empty it returns a descriptive
// error instead of panicking.
func (tn *TestNetlinkIface) Send(_ syscall.NetlinkMessage) (uint32, error) {
	return 0, popScriptedErr(&tn.sendStack, "Send")
}

// SendNoWait behaves exactly like Send and draws from the same sendStack.
func (tn *TestNetlinkIface) SendNoWait(_ syscall.NetlinkMessage) (uint32, error) {
	return 0, popScriptedErr(&tn.sendStack, "SendNoWait")
}

// popScriptedErr removes and returns the first entry of *queue. When the queue
// is already empty it returns an error naming op, so that a test whose script
// is too short fails with a readable message rather than an index panic.
func popScriptedErr(queue *[]error, op string) error {
	if len(*queue) == 0 {
		return errors.New("TestNetlinkIface: " + op + " called with no scripted results left")
	}
	head := (*queue)[0]
	*queue = slices.Delete(*queue, 0, 1)
	return head
}

// Receive ignores both arguments and returns no messages together with the
// next scripted error from recvStack (consumed first-in, first-out). Once
// recvStack is empty it returns a descriptive error instead of panicking, so a
// test whose script is too short fails with a readable message.
func (tn *TestNetlinkIface) Receive(_ bool, _ NetlinkParser) ([]syscall.NetlinkMessage, error) {
	if len(tn.recvStack) == 0 {
		return nil, errors.New("TestNetlinkIface: Receive called with no scripted results left")
	}
	head := tn.recvStack[0]
	tn.recvStack = slices.Delete(tn.recvStack, 0, 1)
	return nil, head
}

// TestCloseBehavior drives AuditClient.Close against a scripted mock netlink
// client and checks which error (if any) the unset-PID retry logic surfaces.
// The cases use EWOULDBLOCK and EAGAIN interchangeably; they are the same
// errno value on Linux.
func TestCloseBehavior(t *testing.T) {
	testCases := []struct {
		name string
		cfg  *TestNetlinkIface
		err  error
	}{
		{
			name: "retry",
			cfg: &TestNetlinkIface{
				// first send would block, the resend succeeds
				sendStack: []error{syscall.EWOULDBLOCK, nil, nil},
				// drain keeps reading through ENOBUFS until the socket is empty
				recvStack: []error{syscall.ENOBUFS, syscall.ENOBUFS, syscall.EAGAIN},
			},
			err: nil,
		},
		{
			name: "repeated-send-fail",
			cfg: &TestNetlinkIface{
				// three sends would block before one finally succeeds
				sendStack: []error{syscall.EWOULDBLOCK, syscall.EWOULDBLOCK, syscall.EWOULDBLOCK, nil},
				// each blocked send triggers a drain that ends on EWOULDBLOCK
				recvStack: []error{syscall.EWOULDBLOCK, syscall.EWOULDBLOCK, nil, syscall.EWOULDBLOCK, nil, syscall.EWOULDBLOCK, nil},
			},
			err: nil,
		},
		{
			name: "transient-eintr-send",
			cfg: &TestNetlinkIface{
				// first send is interrupted, the immediate retry succeeds
				sendStack: []error{syscall.EINTR, nil, nil},
				// not expected to be consumed: EINTR retries without draining
				recvStack: []error{syscall.EAGAIN},
			},
			err: nil,
		},
		{
			name: "fail-recv-error",
			cfg: &TestNetlinkIface{
				// first send would block, forcing a drain
				sendStack: []error{syscall.EWOULDBLOCK, nil, nil},
				// drain hits an unexpected error, which must be returned
				recvStack: []error{syscall.ENOBUFS, syscall.ENOBUFS, syscall.EBADFD},
			},
			err: syscall.EBADFD,
		},
		{
			name: "fail-send-error",
			cfg: &TestNetlinkIface{
				// first send would block, the resend fails with an unexpected error
				sendStack: []error{syscall.EWOULDBLOCK, syscall.EBADFD, nil},
				// drain finds the socket empty straight away
				recvStack: []error{syscall.EAGAIN, syscall.EAGAIN},
			},
			err: syscall.EBADFD,
		},
	}

	for _, test := range testCases {
		t.Run(test.name, func(t *testing.T) {
			testClient := AuditClient{
				Netlink:         test.cfg,
				pendingAcks:     []uint32{},
				clearPIDOnClose: true,
				closeOnce:       sync.Once{},
			}

			err := testClient.Close()
			// %v (not %s) so that a nil expectation prints cleanly, and include
			// the error actually received to make failures diagnosable.
			require.True(t, errors.Is(err, test.err), "expected error %v, got %v", test.err, err)
		})
	}
}

func TestAuditClientGetStatus(t *testing.T) {
if os.Geteuid() != 0 {
t.Skip("must be root to get audit status")
Expand Down
12 changes: 11 additions & 1 deletion netlink.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
// in the message and an error if it occurred.
type NetlinkSender interface {
	// Send sends msg and returns the sequence number used in the message and
	// an error if one occurred.
	Send(msg syscall.NetlinkMessage) (uint32, error)
	// SendNoWait is the non-blocking counterpart of Send; NetlinkClient
	// implements it by sending with MSG_DONTWAIT.
	SendNoWait(msg syscall.NetlinkMessage) (uint32, error)
}

// NetlinkReceiver receives data from the netlink socket and uses the provided
Expand Down Expand Up @@ -126,17 +127,26 @@ func getPortID(fd int) (uint32, error) {
return addr.Pid, nil
}

// SendNoWait sends msg on the netlink socket in non-blocking mode by passing
// MSG_DONTWAIT to the underlying send. Apart from that flag its behavior is
// identical to Send: it returns the sequence number used and any send error.
func (c *NetlinkClient) SendNoWait(msg syscall.NetlinkMessage) (uint32, error) {
	seq, err := c.send(msg, syscall.MSG_DONTWAIT)
	return seq, err
}

// Send transmits a netlink message with no extra send flags. It returns the
// sequence number that was written into the message, plus any error from the
// send. Leaving the header PID at zero lets it be filled in automatically,
// which is the recommended usage.
func (c *NetlinkClient) Send(msg syscall.NetlinkMessage) (uint32, error) {
	const noFlags = 0
	return c.send(msg, noFlags)
}

// send is the shared implementation behind Send and SendNoWait. It fills in
// the header PID from c.pid when the caller left it at zero, assigns the next
// sequence number (atomically incremented), serializes the message and writes
// it to the socket with the given sendto flags (0 or syscall.MSG_DONTWAIT from
// the callers in this file).
//
// msg is received by value, so the PID and sequence number written here are
// not visible in the caller's header; the sequence number is returned instead,
// together with the error from Sendto.
func (c *NetlinkClient) send(msg syscall.NetlinkMessage, flags int) (uint32, error) {
	if msg.Header.Pid == 0 {
		msg.Header.Pid = c.pid
	}

	msg.Header.Seq = atomic.AddUint32(&c.seq, 1)
	// Zero-value destination address.
	// NOTE(review): presumably this addresses the kernel - confirm.
	to := &syscall.SockaddrNetlink{}
	// flags must be forwarded here; passing a literal 0 would silently turn
	// SendNoWait into a blocking send.
	return msg.Header.Seq, syscall.Sendto(c.fd, serialize(msg), flags, to)
}

func serialize(msg syscall.NetlinkMessage) []byte {
Expand Down
Loading