Skip to content

Instantly share code, notes, and snippets.

@asmichi
Last active January 30, 2021 13:47
Show Gist options
  • Save asmichi/80e65892aa68409ea785a5ca74c8e3d0 to your computer and use it in GitHub Desktop.
Save asmichi/80e65892aa68409ea785a5ca74c8e3d0 to your computer and use it in GitHub Desktop.
Sending file descriptors through Unix domain sockets (SCM_RIGHTS)
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdio>
#include <cstdio>
#include <cstring>
#include <sys/socket.h>
#include <unistd.h>
// Largest number of descriptors a single SendWithFd call is allowed to attach.
constexpr int SocketMaxFdsPerCall{1};

// Control-message storage big enough for one SCM_RIGHTS message holding up to
// SocketMaxFdsPerCall descriptors, aligned so CMSG_* macros may treat it as a
// cmsghdr.
struct CmsgFds
{
    static constexpr std::size_t BufferSize{CMSG_SPACE(sizeof(int) * SocketMaxFdsPerCall)};
    alignas(alignof(cmsghdr)) char Buffer[BufferSize];
};

// Sends `len` bytes of `buf` on socket `fd`, attaching the `fdCount` descriptors in `fds`.
ssize_t SendWithFd(int fd, const void* buf, std::size_t len, const int* fds, std::size_t fdCount);
// Receives up to `len` bytes into `buf` from socket `fd`; received descriptors are printed and closed.
ssize_t RecvWithFd(int fd, void* buf, std::size_t len);
// Demo: sends one of a socketpair's own descriptors across the pair, then reads
// the payload back one byte per recvmsg call so the output shows which call the
// descriptor arrives with. Returns 0 on success, 1 if any socket call fails.
int main()
{
    int socks[2];
    if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0, socks) != 0)
    {
        perror("socketpair");
        return 1;
    }

    int exitCode = 0;
    char buf[4]{};
    // Any open descriptor can be passed; socks[0] is simply the one at hand.
    int fds[1] = { socks[0] };
    const ssize_t sentBytes = SendWithFd(socks[0], buf, sizeof(buf), fds, 1);
    if (sentBytes < 0)
    {
        // Nothing was sent, so reading from socks[1] would block forever.
        perror("sendmsg");
        exitCode = 1;
    }
    else
    {
        // Read exactly what was sent (sendmsg may send fewer bytes than asked);
        // asking for more would block because socks[0] is still open.
        ssize_t totalReceived = 0;
        while (totalReceived < sentBytes)
        {
            const ssize_t receivedBytes = RecvWithFd(socks[1], buf, 1);
            if (receivedBytes < 0)
            {
                perror("recvmsg");
                exitCode = 1;
                break;
            }
            if (receivedBytes == 0)
            {
                // Peer closed the connection; nothing more will arrive.
                break;
            }
            totalReceived += receivedBytes;
        }
    }

    close(socks[0]);
    close(socks[1]);
    return exitCode;
}
// Sends `len` bytes from `buf` over the connected socket `fd`, attaching the
// `fdCount` descriptors in `fds` as one SCM_RIGHTS control message.
//
// Parameters:
//   fd      - socket to send on.
//   buf/len - payload; it must be non-empty for the control message to be
//             delivered on a stream socket.
//   fds     - array of at least fdCount descriptors (may be null if fdCount == 0).
//   fdCount - number of descriptors to pass, at most SocketMaxFdsPerCall. When
//             it is 0, no control message is attached.
//
// Returns the result of sendmsg(2): bytes sent, or -1 with errno set. Returns
// -1 with errno = EINVAL, without sending anything, if fdCount exceeds
// SocketMaxFdsPerCall.
ssize_t SendWithFd(int fd, const void* buf, std::size_t len, const int* fds, std::size_t fdCount)
{
    // Checked at run time, not just asserted: under NDEBUG an oversized
    // fdCount would overflow cmsgFds.Buffer in the memcpy below.
    if (fdCount > static_cast<std::size_t>(SocketMaxFdsPerCall))
    {
        errno = EINVAL;
        return -1;
    }

    iovec iov{};
    // iovec::iov_base is non-const void*; sendmsg only reads from it.
    iov.iov_base = const_cast<void*>(buf);
    iov.iov_len = len;

    msghdr msg{};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_control = nullptr;
    msg.msg_controllen = 0;
    msg.msg_flags = 0;

    CmsgFds cmsgFds{};
    if (fdCount > 0)
    {
        msg.msg_control = cmsgFds.Buffer;
        // Size the control area for exactly fdCount descriptors rather than the
        // whole buffer. On Linux, a zero-filled tail with room for another
        // cmsghdr is parsed as a malformed header, and sendmsg fails with EINVAL.
        msg.msg_controllen = CMSG_SPACE(sizeof(int) * fdCount);

        cmsghdr* const pcmsghdr = CMSG_FIRSTHDR(&msg);
        pcmsghdr->cmsg_len = CMSG_LEN(sizeof(int) * fdCount);
        pcmsghdr->cmsg_level = SOL_SOCKET;
        pcmsghdr->cmsg_type = SCM_RIGHTS;
        // memcpy because CMSG_DATA is not guaranteed to be int-aligned.
        std::memcpy(CMSG_DATA(pcmsghdr), fds, sizeof(int) * fdCount);
    }

    std::printf("Sending %zu fd(s)...\n", fdCount);
    return sendmsg(fd, &msg, 0);
}
// Receives up to `len` bytes from socket `fd` into `buf` with a single
// recvmsg(2) call. Every descriptor delivered in an SCM_RIGHTS control message
// is printed and then closed immediately; other control messages are ignored.
//
// Returns the number of bytes received, or -1 with errno set by recvmsg.
ssize_t RecvWithFd(int fd, void* buf, std::size_t len)
{
    iovec iov{};
    iov.iov_base = buf;
    iov.iov_len = len;

    CmsgFds cmsgFds{};
    msghdr msg{};
    msg.msg_name = nullptr;
    msg.msg_namelen = 0;
    msg.msg_iov = &iov;
    msg.msg_iovlen = 1;
    msg.msg_control = cmsgFds.Buffer;
    msg.msg_controllen = CmsgFds::BufferSize;
    msg.msg_flags = 0;

    // MSG_CMSG_CLOEXEC marks received descriptors close-on-exec as they are installed.
    const ssize_t receivedBytes = recvmsg(fd, &msg, MSG_CMSG_CLOEXEC);
    if (receivedBytes == -1)
    {
        return -1;
    }

    for (cmsghdr* pcmsghdr = CMSG_FIRSTHDR(&msg); pcmsghdr != nullptr; pcmsghdr = CMSG_NXTHDR(&msg, pcmsghdr))
    {
        // Only SCM_RIGHTS carries descriptors. Skip anything else: relying on an
        // assert here would, under NDEBUG, treat the payload of an unrelated
        // control message as descriptors and close them.
        if (pcmsghdr->cmsg_level != SOL_SOCKET || pcmsghdr->cmsg_type != SCM_RIGHTS)
        {
            continue;
        }

        const unsigned char* const cmsgdata = CMSG_DATA(pcmsghdr);
        const std::size_t headerLen =
            static_cast<std::size_t>(cmsgdata - reinterpret_cast<const unsigned char*>(pcmsghdr));
        const std::size_t dataLen = pcmsghdr->cmsg_len - headerLen;
        const std::size_t fdCount = dataLen / sizeof(int);
        for (std::size_t i = 0; i < fdCount; i++)
        {
            int receivedFd;
            // memcpy because CMSG_DATA is not guaranteed to be int-aligned.
            std::memcpy(&receivedFd, cmsgdata + sizeof(int) * i, sizeof(int));
            std::printf("Received fd %d\n", receivedFd);
            close(receivedFd);
        }
    }
    return receivedBytes;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment