Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 60 additions & 8 deletions audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
// Netlink groups.
const (
	// NetlinkGroupNone is group 0, which is not used.
	NetlinkGroupNone = iota
	// NetlinkGroupReadLog is the "best effort" read only socket group,
	// defined in the kernel as AUDIT_NLGRP_READLOG.
	NetlinkGroupReadLog
)

// WaitMode is a flag to control the behavior of methods that abstract
Expand Down Expand Up @@ -427,16 +427,11 @@
// become no-ops.
func (c *AuditClient) Close() error {
var err error

// Only unregister and close the socket once.
c.closeOnce.Do(func() {
if c.clearPIDOnClose {
// Unregister from the kernel for a clean exit.
status := AuditStatus{
Mask: AuditStatusPID,
PID: 0,
}
err = c.set(status, NoWait)
err = c.closeAndUnsetPid()
}

err = errors.Join(err, c.Netlink.Close())
Expand Down Expand Up @@ -505,6 +500,63 @@
return &msg, nil
}

// closeAndUnsetPid unsets our PID from the audit subsystem without waiting for a reply; the caller (Close) closes the socket afterwards.
// This is a sort of isolated refactor, meant to deal with the deadlocks that can happen when we're not careful with blocking operations throughout a lot of this code.
func (c *AuditClient) closeAndUnsetPid() error {
msg := syscall.NetlinkMessage{
Header: syscall.NlMsghdr{
Type: AuditSet,
Flags: syscall.NLM_F_REQUEST,
},
Data: AuditStatus{
Mask: AuditStatusPID,
PID: 0,
}.toWireFormat(),
}

noParse := func(bytes []byte) ([]syscall.NetlinkMessage, error) {
return nil, nil
}

// If our request to unset the PID would block, then try to drain events from
// the netlink socket, resend, try again.
// In netlink, EAGAIN usually indicates our read buffer is full.
// The auditd code (which I'm using as a reference implementation) doesn't wait for a response when unsetting the audit pid.
maxLoop := 5
for i := 0; i < maxLoop; i++ {
_, err := c.Netlink.SendNoWait(msg)
// if we get an interrupt, retry the send
if err == nil {
return nil
} else if errors.Is(err, syscall.EINTR) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(*NetlinkClient).SendNoWait is a thin wrapper around syscall.Sendto. No functions in syscall ever wrap errors, and given go1compat, this will not change (go1compat says that syscall is not covered, but that relates to changes forced by changes in the underlying OS; any change to the types returned by syscalls would neither be likely from the Go team, nor quietly accepted by external reviewers).

tl;dr: No errors.Is is required here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so, even if syscall.* isn't going to be wrapping anything, I feel like it's safer to assume that c.Netlink.* might wrap something in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, it shouldn't, but it feels a little safer to be defensive. Also the linter complains if you don't use errors.Is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My preferred approach would be to ensure that the behaviour is covered by tests, and then leave it without the extra weight. This would find regressions and ensure that bloat is not added to SendNoWait. I think the linter is wrong here. However, I'm not an owner, so it doesn't really matter.

// got interrupt, try again
continue
} else if errors.Is(err, syscall.EAGAIN) {

Check warning on line 534 in audit.go

View workflow job for this annotation

GitHub Actions / lint

early-return: if c { ... } else { ... return } can be simplified to if !c { ... return } ... (revive)
maxRecv := 10000
// send would block, try to drain the receive socket
for i := 0; i < maxRecv; i++ {
_, err = c.Netlink.Receive(true, noParse)
if errors.Is(err, syscall.EAGAIN) {
// receive would block, try to send again
break
} else if err == nil || errors.Is(err, syscall.EINTR) || errors.Is(err, syscall.ENOBUFS) {
// retry the receive
continue
} else {
// if we have another kind of error, just bail and return that error.
return err
}
}
} else {
// if we get another error from the send, return that up
return err
}

}
// we may not want to treat this as a hard error?
return fmt.Errorf("could not unset pid from audit after retries")
}

func (c *AuditClient) set(status AuditStatus, mode WaitMode) error {
msg := syscall.NetlinkMessage{
Header: syscall.NlMsghdr{
Expand Down Expand Up @@ -560,7 +612,7 @@
// https://github.com/linux-audit/audit-kernel/blob/v4.7/include/uapi/linux/audit.h#L318-L325
type AuditStatusMask uint32

// Mask types for AuditStatus.
// Mask types for AuditStatus. Originally defined in the kernel at audit.h
const (
AuditStatusEnabled AuditStatusMask = 1 << iota
AuditStatusFailure
Expand Down
12 changes: 11 additions & 1 deletion netlink.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
// in the message and an error if it occurred.
type NetlinkSender interface {
// Send sends msg and returns the sequence number used in the message and
// an error if one occurred.
Send(msg syscall.NetlinkMessage) (uint32, error)
// SendNoWait has the same signature as Send; the NetlinkClient
// implementation performs the send with MSG_DONTWAIT so that it does not
// block.
SendNoWait(msg syscall.NetlinkMessage) (uint32, error)
}

// NetlinkReceiver receives data from the netlink socket and uses the provided
Expand Down Expand Up @@ -126,17 +127,26 @@ func getPortID(fd int) (uint32, error) {
return addr.Pid, nil
}

// SendNoWait transmits msg on the netlink socket in non-blocking mode by
// passing MSG_DONTWAIT to the underlying send. Apart from that flag, its
// behavior and return values are identical to Send.
func (c *NetlinkClient) SendNoWait(msg syscall.NetlinkMessage) (uint32, error) {
	seq, err := c.send(msg, syscall.MSG_DONTWAIT)
	return seq, err
}

// Send transmits msg on the netlink socket with no extra send flags. It
// returns the sequence number assigned to the message along with any error
// that occurred. Leaving the message PID unset is recommended; it is then
// populated automatically.
func (c *NetlinkClient) Send(msg syscall.NetlinkMessage) (uint32, error) {
	const noFlags = 0
	return c.send(msg, noFlags)
}

// send is the shared implementation behind Send and SendNoWait. It fills in
// the header PID from c.pid when the caller left it at zero, assigns the next
// sequence number, and writes the serialized message to the netlink socket
// using the supplied sendto flags. It returns the sequence number that was
// placed in the message and any error reported by syscall.Sendto.
func (c *NetlinkClient) send(msg syscall.NetlinkMessage, flags int) (uint32, error) {
	if msg.Header.Pid == 0 {
		msg.Header.Pid = c.pid
	}

	// Sequence numbers are allocated atomically from c.seq.
	msg.Header.Seq = atomic.AddUint32(&c.seq, 1)

	// flags must be forwarded to Sendto; hard-coding 0 here would silently
	// drop MSG_DONTWAIT and turn SendNoWait back into a plain Send.
	to := &syscall.SockaddrNetlink{}
	return msg.Header.Seq, syscall.Sendto(c.fd, serialize(msg), flags, to)
}

func serialize(msg syscall.NetlinkMessage) []byte {
Expand Down
Loading