Skip to content

Instantly share code, notes, and snippets.

@cloudhan
Last active March 31, 2021 08:25
Show Gist options
  • Save cloudhan/9cfd12fc0c89ac0f4efce5270a38c62e to your computer and use it in GitHub Desktop.
Save cloudhan/9cfd12fc0c89ac0f4efce5270a38c62e to your computer and use it in GitHub Desktop.
/*************************************************************************
* Copyright (c) 2016-2019, NVIDIA CORPORATION. All rights reserved.
*
* See LICENSE.txt for license information
************************************************************************/
#include <arpa/inet.h>
#include <errno.h>
#include <ifaddrs.h>
#include <net/if.h>
#include <netdb.h>
#include <netinet/tcp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <unistd.h>
/* Stand-alone definitions so this file builds on its own. */
/* Defined so that the `#ifdef ENABLE_TRACE` sections further down are compiled in. */
#define ENABLE_TRACE 1
/* Tested by NCCLCHECK: while this is 0 the failing file/line is printed. */
#define ncclDebugNoWarn 0
/* Result code returned by the helpers in this file. */
typedef int ncclResult_t;
#define ncclSuccess 0
#define ncclSystemError 1
#define ncclInvalidArgument 2
struct netIf {
char prefix[64];
int port;
};
/* Prints "TRACE: ", then the printf-style message, then a newline, to stdout.
 * `place_holder` does not appear in the expansion, so whatever is passed there
 * is discarded. `pattern` is followed by a comma in the expansion; every call
 * site in this file passes at least one argument after it.
 */
#define TRACE(place_holder, pattern, ...) \
do { \
printf("TRACE: "); \
printf(pattern, __VA_ARGS__); \
printf("\n"); \
} while (0)
/* Prints "INFO: ", then the printf-style message, then a newline, to stdout.
 * `place_holder` does not appear in the expansion, so whatever is passed there
 * is discarded. `pattern` is followed by a comma in the expansion; every call
 * site in this file passes at least one argument after it.
 */
#define INFO(place_holder, pattern, ...) \
do { \
printf("INFO: "); \
printf(pattern, __VA_ARGS__); \
printf("\n"); \
} while (0)
/* Prints "WARN: ", then the printf-style message, then a newline, to stdout.
 * All arguments are forwarded to printf, so the first one is the format string.
 */
#define WARN(...) \
do { \
printf("WARN: "); \
printf(__VA_ARGS__); \
printf("\n"); \
} while (0)
/* Runs `call` through SYSCHECKVAL with a private result variable.
 * Because SYSCHECKVAL contains a `return ncclSystemError`, this macro may
 * return from the enclosing function; use it only where that is acceptable.
 * `name` must be a string literal (it is pasted between other literals).
 */
#define SYSCHECK(call, name) \
do { \
int retval; \
SYSCHECKVAL(call, name, retval); \
} while (false)
/* Runs `call` via SYSCHECKSYNC, leaving its result in `retval`.
 * If the final result is -1, prints a warning with strerror(errno) and
 * returns ncclSystemError from the enclosing function.
 * `name` must be a string literal (it is pasted between other literals).
 */
#define SYSCHECKVAL(call, name, retval) \
do { \
SYSCHECKSYNC(call, name, retval); \
if (retval == -1) { \
WARN("Call to " name " failed : %s", strerror(errno)); \
return ncclSystemError; \
} \
} while (false)
/* Evaluates `call` and stores the result in `retval`, repeating for as long as
 * the result is -1 with errno equal to EINTR, EWOULDBLOCK or EAGAIN. Each
 * repetition is logged through INFO; there is no delay between attempts.
 * Any other outcome (success or a different error) leaves the loop, and the
 * caller is expected to inspect `retval`.
 * `name` must be a string literal (it is pasted between other literals).
 */
#define SYSCHECKSYNC(call, name, retval) \
do { \
retval = call; \
if (retval == -1 && (errno == EINTR || errno == EWOULDBLOCK || errno == EAGAIN)) { \
INFO(NCCL_ALL, "Call to " name " returned %s, retrying", strerror(errno)); \
} else { \
break; \
} \
} while (true)
// Propagate errors up.
// Evaluates `call` exactly once. When the result is not ncclSuccess the
// file/line/result is printed (only while ncclDebugNoWarn == 0) and the result
// is returned from the *enclosing* function, which therefore has to return an
// ncclResult_t-compatible type.
// The expansion intentionally carries no trailing ';': the caller's own ';'
// completes the statement, which keeps `if (x) NCCLCHECK(y); else ...` valid.
#define NCCLCHECK(call) \
  do { \
    ncclResult_t res = (call); \
    if (res != ncclSuccess) { \
      /* Print the back trace*/ \
      if (ncclDebugNoWarn == 0) INFO(NCCL_ALL, "%s:%d -> %d", __FILE__, __LINE__, res); \
      return res; \
    } \
  } while (0)
/* Capacity of the user interface list parsed in findInterfaces(), and the
 * number of result slots main() allocates. */
#define MAX_IFS 16
/* Bytes reserved per interface name in the flat name buffer used by main(). */
#define MAX_IF_NAME_SIZE 16
#define SLEEP_INT 1000 // connection retry sleep interval in usec
/* Note: 2e4 is a floating-point literal; connectAddress() compares an int
 * retry counter against it. */
#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec)
#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s)
/* Parse a comma-separated list of "<prefix>[:<port>]" entries.
 *
 * string  : NUL-terminated list, e.g. "eth0,ib:1,mlx5_0:2"; NULL yields 0.
 * ifList  : output array with room for at least maxList entries.
 * maxList : capacity of ifList; parsing stops once it is full.
 *
 * For every non-empty entry the prefix is copied (NUL-terminated) and the
 * port is set to atoi() of the text after ':' or to -1 when there is no ':'.
 * Empty entries (",,", a leading ':') are skipped. A prefix longer than the
 * `prefix` field can hold is truncated instead of overrunning the field.
 *
 * Returns the number of entries written.
 */
int parseStringList(const char* string, struct netIf* ifList, int maxList) {
  // Nothing to parse, or nowhere to store a result: never touch ifList.
  if (!string || !ifList || maxList <= 0) return 0;

  // Highest number of characters a prefix may hold, keeping one byte for '\0'.
  const int maxPrefixLen = (int)sizeof(ifList[0].prefix) - 1;

  const char* ptr = string;
  int ifNum = 0;  // entries completed so far
  int ifC = 0;    // characters collected for the current prefix
  char c;
  do {
    c = *ptr;
    if (c == ':') {
      if (ifC > 0) {
        ifList[ifNum].prefix[ifC] = '\0';
        ifList[ifNum].port = atoi(ptr + 1);
        ifNum++;
        ifC = 0;
      }
      // Skip the port text; stop on the separator or on the terminator.
      while (c != ',' && c != '\0') c = *(++ptr);
    } else if (c == ',' || c == '\0') {
      if (ifC > 0) {
        ifList[ifNum].prefix[ifC] = '\0';
        ifList[ifNum].port = -1;
        ifNum++;
        ifC = 0;
      }
    } else if (ifC < maxPrefixLen) {
      ifList[ifNum].prefix[ifC] = c;
      ifC++;
    }
    // else: the prefix is already full, drop the extra character.
    ptr++;
  } while (ifNum < maxList && c);
  return ifNum;
}
/* Compare an interface name against a reference string.
 * Exact mode compares the whole of `string` including its terminator;
 * otherwise `ref` only has to be a prefix of `string`.
 */
static bool matchIf(const char* string, const char* ref, bool matchExact) {
  size_t compareLen;
  if (matchExact) {
    // +1 so the terminating '\0' takes part in the comparison
    compareLen = strlen(string) + 1;
  } else {
    compareLen = strlen(ref);
  }
  return !strncmp(string, ref, compareLen);
}
/* Two ports match when they are equal or when either one is the wildcard -1. */
static bool matchPort(const int port1, const int port2) {
  return port1 == -1 || port2 == -1 || port1 == port2;
}
bool matchIfList(const char* string, int port, struct netIf* ifList, int listSize, bool matchExact) {
// Make an exception for the case where no user list is defined
if (listSize == 0) return true;
for (int i = 0; i < listSize; i++) {
if (matchIf(string, ifList[i].prefix, matchExact) && matchPort(port, ifList[i].port)) {
return true;
}
}
return false;
}
/* Common socket address storage structure for IPv4/IPv6.
 * All three members overlay the same bytes; code in this file reads
 * `sa.sa_family` to decide whether `sin` or `sin6` is the meaningful view.
 */
union socketAddress {
struct sockaddr sa;
struct sockaddr_in sin;
struct sockaddr_in6 sin6;
};
/* Format a string representation of a (struct sockaddr *) socket address using getnameinfo()
 *
 * Output: "IPv4/IPv6 address<port>", always numeric (no name lookups).
 *
 * saddr : address to format; only AF_INET and AF_INET6 are supported.
 * buf   : caller-supplied output buffer. No size is passed in, so it must be
 *         large enough for the formatted text (callers in this file pass
 *         1024-byte buffers).
 *
 * Returns NULL when either pointer is NULL, otherwise `buf`. `buf` is left as
 * an empty string when the family is unsupported or getnameinfo() fails, so
 * the result is always safe to print.
 */
static inline const char* socketToString(struct sockaddr* saddr, char* buf) {
  if (buf == NULL || saddr == NULL) return NULL;
  buf[0] = '\0';

  // Hand getnameinfo() the size of the concrete address type.
  socklen_t salen;
  if (saddr->sa_family == AF_INET) {
    salen = sizeof(struct sockaddr_in);
  } else if (saddr->sa_family == AF_INET6) {
    salen = sizeof(struct sockaddr_in6);
  } else {
    return buf;
  }

  char host[NI_MAXHOST], service[NI_MAXSERV];
  // On failure host/service are not filled in, so do not format them.
  if (getnameinfo(saddr, salen, host, sizeof(host), service, sizeof(service), NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
    return buf;
  }
  sprintf(buf, "%s<%s>", host, service);
  return buf;
}
/* Return the port of a socket address in host byte order.
 * Any family other than AF_INET is read through the IPv6 layout.
 */
static inline uint16_t socketToPort(struct sockaddr* saddr) {
  uint16_t netOrderPort;
  if (saddr->sa_family == AF_INET) {
    netOrderPort = ((struct sockaddr_in*)saddr)->sin_port;
  } else {
    netOrderPort = ((struct sockaddr_in6*)saddr)->sin6_port;
  }
  return ntohs(netOrderPort);
}
/* Allow the user to force the IPv4/IPv6 interface selection.
 * Reads NCCL_SOCKET_FAMILY: "AF_INET" selects IPv4, "AF_INET6" selects IPv6.
 * Returns -1 (no forced family, first one found is used) when the variable
 * is unset or holds any other value.
 */
static inline int envSocketFamily(void) {
  const char* requested = getenv("NCCL_SOCKET_FAMILY");
  if (requested == NULL) return -1;

  INFO(NCCL_ENV, "NCCL_SOCKET_FAMILY set by environment to %s", requested);
  if (strcmp(requested, "AF_INET") == 0) return AF_INET;    // IPv4
  if (strcmp(requested, "AF_INET6") == 0) return AF_INET6;  // IPv6
  return -1;
}
static int findInterfaces(const char* prefixList, char* names, union socketAddress* addrs, int sock_family, int maxIfNameSize, int maxIfs) {
#ifdef ENABLE_TRACE
char line[1024];
#endif
struct netIf userIfs[MAX_IFS];
bool searchNot = prefixList && prefixList[0] == '^';
if (searchNot) prefixList++;
bool searchExact = prefixList && prefixList[0] == '=';
if (searchExact) prefixList++;
int nUserIfs = parseStringList(prefixList, userIfs, MAX_IFS);
int found = 0;
struct ifaddrs *interfaces, *interface;
getifaddrs(&interfaces);
for (interface = interfaces; interface && found < maxIfs; interface = interface->ifa_next) {
if (interface->ifa_addr == NULL) continue;
/* We only support IPv4 & IPv6 */
int family = interface->ifa_addr->sa_family;
if (family != AF_INET && family != AF_INET6) continue;
TRACE(NCCL_INIT | NCCL_NET, "Found interface %s:%s", interface->ifa_name, socketToString(interface->ifa_addr, line));
/* Allow the caller to force the socket family type */
if (sock_family != -1 && family != sock_family) continue;
/* We also need to skip IPv6 loopback interfaces */
if (family == AF_INET6) {
struct sockaddr_in6* sa = (struct sockaddr_in6*)(interface->ifa_addr);
if (IN6_IS_ADDR_LOOPBACK(&sa->sin6_addr)) continue;
}
// check against user specified interfaces
if (!(matchIfList(interface->ifa_name, -1, userIfs, nUserIfs, searchExact) ^ searchNot)) {
continue;
}
// Check that this interface has not already been saved
// getifaddrs() normal order appears to be; IPv4, IPv6 Global, IPv6 Link
bool duplicate = false;
for (int i = 0; i < found; i++) {
if (strcmp(interface->ifa_name, names + i * maxIfNameSize) == 0) {
duplicate = true;
break;
}
}
if (!duplicate) {
// Store the interface name
strncpy(names + found * maxIfNameSize, interface->ifa_name, maxIfNameSize);
// Store the IP address
int salen = (family == AF_INET) ? sizeof(sockaddr_in) : sizeof(sockaddr_in6);
memcpy(addrs + found, interface->ifa_addr, salen);
found++;
}
}
freeifaddrs(interfaces);
return found;
}
/* Tell whether `remote` lies in the subnet of the local interface entry.
 *
 * local_if : one entry of a getifaddrs() list (passed by value).
 * remote   : address to test.
 *
 * Both addresses are masked with the interface netmask and compared; for
 * IPv6 the scope ids must be equal as well. Returns false when the families
 * differ, when the family is neither AF_INET nor AF_INET6, or when the entry
 * carries no address or no netmask (getifaddrs() may leave either NULL).
 */
static bool matchSubnet(struct ifaddrs local_if, union socketAddress* remote) {
  // Without an address and a netmask there is no subnet to compare against.
  if (local_if.ifa_addr == NULL || local_if.ifa_netmask == NULL || remote == NULL) {
    return false;
  }

  /* Check family first */
  int family = local_if.ifa_addr->sa_family;
  if (family != remote->sa.sa_family) {
    return false;
  }

  if (family == AF_INET) {
    const struct sockaddr_in* local_addr = (const struct sockaddr_in*)(local_if.ifa_addr);
    const struct sockaddr_in* mask = (const struct sockaddr_in*)(local_if.ifa_netmask);
    const struct sockaddr_in* remote_addr = &remote->sin;
    in_addr_t local_subnet = local_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
    in_addr_t remote_subnet = remote_addr->sin_addr.s_addr & mask->sin_addr.s_addr;
    return local_subnet == remote_subnet;
  }

  if (family == AF_INET6) {
    const struct sockaddr_in6* local_addr = (const struct sockaddr_in6*)(local_if.ifa_addr);
    const struct sockaddr_in6* mask = (const struct sockaddr_in6*)(local_if.ifa_netmask);
    const struct sockaddr_in6* remote_addr = &remote->sin6;
    const unsigned char* local_bytes = local_addr->sin6_addr.s6_addr;
    const unsigned char* mask_bytes = mask->sin6_addr.s6_addr;
    const unsigned char* remote_bytes = remote_addr->sin6_addr.s6_addr;

    bool same = true;
    // An IPv6 address is 16 bytes; network byte order is big-endian.
    for (int c = 0; c < 16; c++) {
      if ((local_bytes[c] & mask_bytes[c]) != (remote_bytes[c] & mask_bytes[c])) {
        same = false;
        break;
      }
    }
    // At last, we need to compare scope id
    // Two Link-type addresses can have the same subnet address even though they are not in the same scope
    // For Global type, this field is 0, so a comparison wouldn't matter
    return same && (local_addr->sin6_scope_id == remote_addr->sin6_scope_id);
  }

  WARN("Net : Unsupported address family type");
  return false;
}
/* Find a local interface that is in the same subnet as `remoteAddr`.
 *
 * ifNames       : output, flat buffer of slots of ifNameMaxSize bytes; the
 *                 stored name is NUL-terminated (truncated if needed).
 * localAddrs    : output, address of the matching interface.
 * remoteAddr    : address whose subnet is looked for.
 * ifNameMaxSize : size in bytes of one slot in `ifNames`.
 * maxIfs        : capacity of the output arrays.
 *
 * The scan stops at the first match, so at most one interface is reported.
 * Returns the number of interfaces stored (0 or 1); 0 as well when the
 * interface list cannot be obtained.
 */
static int findInterfaceMatchSubnet(char* ifNames, union socketAddress* localAddrs, union socketAddress* remoteAddr, int ifNameMaxSize,
                                    int maxIfs) {
  char line[1024];    // scratch for the local address text
  char line_a[1024];  // scratch for the remote address text
  int found = 0;

  // getifaddrs() leaves the list pointer unspecified on failure: bail out
  // instead of walking (and later freeing) an uninitialized pointer.
  struct ifaddrs* interfaces = NULL;
  if (getifaddrs(&interfaces) == -1) {
    WARN("Call to getifaddrs failed : %s", strerror(errno));
    return 0;
  }

  for (struct ifaddrs* interface = interfaces; interface && !found; interface = interface->ifa_next) {
    if (interface->ifa_addr == NULL) continue;

    /* We only support IPv4 & IPv6 */
    int family = interface->ifa_addr->sa_family;
    if (family != AF_INET && family != AF_INET6) continue;

    // keep only interfaces that share the remote address' subnet
    if (!matchSubnet(*interface, remoteAddr)) {
      continue;
    }

    // Store the local IP address
    int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
    memcpy(localAddrs + found, interface->ifa_addr, salen);

    // Store the interface name; strncpy() does not terminate on truncation,
    // so force a terminator into the last byte of the slot.
    char* slot = ifNames + found * ifNameMaxSize;
    strncpy(slot, interface->ifa_name, ifNameMaxSize);
    if (ifNameMaxSize > 0) slot[ifNameMaxSize - 1] = '\0';

    TRACE(NCCL_INIT | NCCL_NET, "NET : Found interface %s:%s in the same subnet as remote address %s", interface->ifa_name,
          socketToString(&(localAddrs[found].sa), line), socketToString(&(remoteAddr->sa), line_a));
    found++;
    if (found == maxIfs) break;
  }

  if (found == 0) {
    WARN("Net : No interface found in the same subnet as remote address %s", socketToString(&(remoteAddr->sa), line_a));
  }
  freeifaddrs(interfaces);
  return found;
}
/* Build a socket address from a textual "<host>:<port>" description.
 *
 * ua           : output address.
 * ip_port_pair : either "<IPv4_or_hostname>[:<port>]" (resolved with
 *                getaddrinfo(), first result used) or
 *                "[<IPv6>[%<ifname>]]:<port>" (parsed with inet_pton(); a
 *                "%<ifname>" suffix selects link scope on that interface).
 *
 * In the bracketed form the character right after ']' is skipped and the
 * rest is read as the port; a missing port yields port 0.
 *
 * Returns ncclSuccess, or ncclInvalidArgument when the string is NULL/too
 * short, cannot be parsed or resolved, has over-long components, or resolves
 * to an unsupported family.
 */
static ncclResult_t GetSocketAddrFromString(union socketAddress* ua, const char* ip_port_pair) {
  if (!(ip_port_pair && strlen(ip_port_pair) > 1)) {
    WARN("Net : string is null");
    return ncclInvalidArgument;
  }

  /* Construct the sockaddress structure */
  if (ip_port_pair[0] != '[') {
    struct netIf ni;
    // parse <ip_or_hostname>:<port> string, expect one pair
    if (parseStringList(ip_port_pair, &ni, 1) != 1) {
      WARN("Net : No valid <IPv4_or_hostname>:<port> pair found");
      return ncclInvalidArgument;
    }

    struct addrinfo hints;
    struct addrinfo* p = NULL;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    int rv = getaddrinfo(ni.prefix, NULL, &hints, &p);
    if (rv != 0) {
      WARN("Net : error encountered when getting address info : %s", gai_strerror(rv));
      return ncclInvalidArgument;
    }

    // use the first
    ncclResult_t result = ncclSuccess;
    if (p->ai_family == AF_INET) {
      struct sockaddr_in* sin = &ua->sin;
      memcpy(sin, p->ai_addr, sizeof(struct sockaddr_in));
      sin->sin_family = AF_INET;       // IPv4
      sin->sin_port = htons(ni.port);  // port
    } else if (p->ai_family == AF_INET6) {
      struct sockaddr_in6* sin6 = &ua->sin6;
      memcpy(sin6, p->ai_addr, sizeof(struct sockaddr_in6));
      sin6->sin6_family = AF_INET6;      // IPv6
      sin6->sin6_port = htons(ni.port);  // port
      sin6->sin6_flowinfo = 0;           // needed by IPv6, but possibly obsolete
      sin6->sin6_scope_id = 0;           // should be global scope, set to 0
    } else {
      WARN("Net : unsupported IP family");
      result = ncclInvalidArgument;
    }
    freeaddrinfo(p);  // released on every path, including the error one
    return result;
  }

  // Bracketed IPv6 form: locate the closing ']' and an optional '%' before it.
  const int len = (int)strlen(ip_port_pair);
  int closeIdx = 1;
  int zoneIdx = -1;
  for (; closeIdx < len; closeIdx++) {
    if (ip_port_pair[closeIdx] == '%') zoneIdx = closeIdx;
    if (ip_port_pair[closeIdx] == ']') break;
  }
  if (closeIdx == len) {
    WARN("Net : No valid [IPv6]:port pair found");
    return ncclInvalidArgument;
  }

  const bool global_scope = (zoneIdx == -1);  // If no % found, global scope; otherwise, link scope
  const int ipLen = (global_scope ? closeIdx : zoneIdx) - 1;
  const int ifLen = global_scope ? 0 : closeIdx - zoneIdx - 1;
  // The port text starts two characters after ']' (the ':' is skipped). It is
  // empty when the string ends earlier, so nothing past the terminator is read.
  const int portStart = closeIdx + 2;
  const int portLen = (portStart < len) ? len - portStart : 0;

  char ip_str[NI_MAXHOST], port_str[NI_MAXSERV], if_name[IFNAMSIZ];
  if (ipLen >= (int)sizeof(ip_str) || ifLen >= (int)sizeof(if_name) || portLen >= (int)sizeof(port_str)) {
    WARN("Net : [IPv6]:port pair has an over-long component");
    return ncclInvalidArgument;
  }
  memset(ip_str, '\0', sizeof(ip_str));
  memset(port_str, '\0', sizeof(port_str));
  memset(if_name, '\0', sizeof(if_name));
  memcpy(ip_str, ip_port_pair + 1, ipLen);
  if (portLen > 0) memcpy(port_str, ip_port_pair + portStart, portLen);
  if (ifLen > 0) memcpy(if_name, ip_port_pair + zoneIdx + 1, ifLen);  // If not global scope, we need the intf name
  int port = atoi(port_str);

  struct sockaddr_in6* sin6 = &ua->sin6;
  sin6->sin6_family = AF_INET6;  // IPv6
  if (inet_pton(AF_INET6, ip_str, &(sin6->sin6_addr)) != 1) {  // IP address
    WARN("Net : invalid IPv6 address %s", ip_str);
    return ncclInvalidArgument;
  }
  sin6->sin6_port = htons(port);                                     // port
  sin6->sin6_flowinfo = 0;                                           // needed by IPv6, but possibly obsolete
  sin6->sin6_scope_id = global_scope ? 0 : if_nametoindex(if_name);  // 0 if global scope; intf index if link scope
  return ncclSuccess;
}
static int findInterfaces(char* ifNames, union socketAddress* ifAddrs, int ifNameMaxSize, int maxIfs) {
static int shownIfName = 0;
int nIfs = 0;
// Allow user to force the INET socket family selection
int sock_family = envSocketFamily();
// User specified interface
char* env = getenv("NCCL_SOCKET_IFNAME");
if (env && strlen(env) > 1) {
INFO(NCCL_ENV, "NCCL_SOCKET_IFNAME set by environment to %s", env);
// Specified by user : find or fail
if (shownIfName++ == 0) INFO(NCCL_NET, "NCCL_SOCKET_IFNAME set to %s", env);
nIfs = findInterfaces(env, ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
} else {
// Try to automatically pick the right one
// Start with IB
nIfs = findInterfaces("ib", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// else see if we can get some hint from COMM ID
if (nIfs == 0) {
char* commId = getenv("NCCL_COMM_ID");
if (commId && strlen(commId) > 1) {
INFO(NCCL_ENV, "NCCL_COMM_ID set by environment to %s", commId);
// Try to find interface that is in the same subnet as the IP in comm id
union socketAddress idAddr;
GetSocketAddrFromString(&idAddr, commId);
nIfs = findInterfaceMatchSubnet(ifNames, ifAddrs, &idAddr, ifNameMaxSize, maxIfs);
}
}
// Then look for anything else (but not docker or lo)
if (nIfs == 0) nIfs = findInterfaces("^docker,lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
// Finally look for docker, then lo.
if (nIfs == 0) nIfs = findInterfaces("docker", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
if (nIfs == 0) nIfs = findInterfaces("lo", ifNames, ifAddrs, sock_family, ifNameMaxSize, maxIfs);
}
return nIfs;
}
/* Create a TCP socket, bind it to `localAddr` and put it in listen mode.
 *
 * fd        : output, the listening descriptor (written only on success).
 * localAddr : in/out; the address to bind to. A port of 0 lets the system
 *             choose one. On success it is overwritten with the bound
 *             address as reported by getsockname().
 *
 * When a non-zero port is requested, SO_REUSEADDR (and SO_REUSEPORT where
 * available) are enabled first. Socket option names are not bit flags, so
 * each option gets its own setsockopt() call.
 *
 * Returns ncclSuccess, or ncclSystemError after printing a warning; the
 * socket is closed again on every failure path.
 */
static ncclResult_t createListenSocket(int* fd, union socketAddress* localAddr) {
  /* IPv4/IPv6 support */
  int family = localAddr->sa.sa_family;
  int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);
  // Declared up front so that no `goto fail` jumps over an initialization.
  socklen_t size = salen;
  const char* failedCall = NULL;
  int rc = 0;

  /* Create socket and bind it to a port */
  int sockfd = socket(family, SOCK_STREAM, 0);
  if (sockfd == -1) {
    WARN("Net : Socket creation failed : %s", strerror(errno));
    return ncclSystemError;
  }

  if (socketToPort(&localAddr->sa)) {
    // Port is forced by env. Make sure we get the port.
    int opt = 1;
    SYSCHECKSYNC(setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt", rc);
    if (rc == -1) {
      failedCall = "setsockopt";
      goto fail;
    }
#if defined(SO_REUSEPORT)
    SYSCHECKSYNC(setsockopt(sockfd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt", rc);
    if (rc == -1) {
      failedCall = "setsockopt";
      goto fail;
    }
#endif
  }

  // localAddr port should be 0 (Any port)
  SYSCHECKSYNC(bind(sockfd, &localAddr->sa, salen), "bind", rc);
  if (rc == -1) {
    failedCall = "bind";
    goto fail;
  }

  /* Get the assigned Port */
  SYSCHECKSYNC(getsockname(sockfd, &localAddr->sa, &size), "getsockname", rc);
  if (rc == -1) {
    failedCall = "getsockname";
    goto fail;
  }

#ifdef ENABLE_TRACE
  {
    char line[1024];
    TRACE(NCCL_INIT | NCCL_NET, "Listening on socket %s", socketToString(&localAddr->sa, line));
  }
#endif

  /* Put the socket in listen mode
   * NB: The backlog will be silently truncated to the value in /proc/sys/net/core/somaxconn
   */
  SYSCHECKSYNC(listen(sockfd, 16384), "listen", rc);
  if (rc == -1) {
    failedCall = "listen";
    goto fail;
  }

  *fd = sockfd;
  return ncclSuccess;

fail:
  // Report first (close() may change errno), then release the descriptor.
  WARN("Call to %s failed : %s", failedCall, strerror(errno));
  close(sockfd);
  return ncclSystemError;
}
/* Create a TCP socket (with TCP_NODELAY) and connect it to `remoteAddr`.
 *
 * fd         : output; the connected descriptor on success, -1 on failure.
 * remoteAddr : peer address (AF_INET or AF_INET6).
 *
 * connect() is retried, sleeping SLEEP_INT usec between attempts, while it
 * fails with ECONNREFUSED (fewer than RETRY_REFUSED_TIMES attempts) or
 * ETIMEDOUT (fewer than RETRY_TIMEDOUT_TIMES attempts).
 *
 * Returns ncclSuccess, or ncclSystemError after printing a warning; the
 * socket is closed again on every failure path.
 */
static ncclResult_t connectAddress(int* fd, union socketAddress* remoteAddr) {
  /* IPv4/IPv6 support */
  int family = remoteAddr->sa.sa_family;
  int salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6);

  /* Connect to a hostname / port */
  *fd = socket(family, SOCK_STREAM, 0);
  if (*fd == -1) {
    WARN("Net : Socket creation failed : %s", strerror(errno));
    return ncclSystemError;
  }

  int ret;
  const int one = 1;
  SYSCHECKSYNC(setsockopt(*fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt", ret);
  if (ret == -1) {
    WARN("Call to setsockopt failed : %s", strerror(errno));
    close(*fd);
    *fd = -1;
    return ncclSystemError;
  }

  /* const int bufsize = 128*1024;
  SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_SNDBUF, (char*)&bufsize, sizeof(int)), "setsockopt");
  SYSCHECK(setsockopt(*fd, SOL_SOCKET, SO_RCVBUF, (char*)&bufsize, sizeof(int)), "setsockopt");*/

  char line[1024];
#ifdef ENABLE_TRACE
  TRACE(NCCL_INIT | NCCL_NET, "Connecting to socket %s", socketToString(&remoteAddr->sa, line));
#endif

  int timedout_retries = 0;
  int refused_retries = 0;
  for (;;) {
    SYSCHECKSYNC(connect(*fd, &remoteAddr->sa, salen), "connect", ret);
    if (ret == 0) return ncclSuccess;

    // Each error kind has its own retry budget; anything else is final.
    bool retry = (errno == ECONNREFUSED && ++refused_retries < RETRY_REFUSED_TIMES) ||
                 (errno == ETIMEDOUT && ++timedout_retries < RETRY_TIMEDOUT_TIMES);
    if (!retry) break;

    if (refused_retries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno));
    usleep(SLEEP_INT);
  }

  // Capture errno before formatting the address, which may overwrite it.
  const int connectErrno = errno;
  WARN("Connect to %s failed : %s", socketToString(&remoteAddr->sa, line), strerror(connectErrno));
  close(*fd);
  *fd = -1;
  return ncclSystemError;
}
#define NCCL_SOCKET_SEND 0
#define NCCL_SOCKET_RECV 1
/* Move as much of a buffer as currently possible over a socket.
 *
 * op     : NCCL_SOCKET_SEND or NCCL_SOCKET_RECV.
 * fd     : connected socket.
 * ptr    : start of the buffer.
 * size   : total number of bytes to transfer.
 * offset : in/out, bytes already transferred; advanced by this call.
 * block  : non-zero for blocking calls, 0 to use MSG_DONTWAIT.
 *
 * Loops while the last call transferred something and the buffer is not yet
 * complete. EINTR / EWOULDBLOCK / EAGAIN count as "nothing transferred" and
 * end the loop without an error, so *offset may still be < size on return.
 *
 * Returns ncclSuccess, or ncclSystemError when the peer closed the
 * connection during a receive or the call failed with another error.
 */
static ncclResult_t socketProgressOpt(int op, int fd, void* ptr, int size, int* offset, int block) {
  int bytes = 0;
  char* data = (char*)ptr;
  const int flags = block ? 0 : MSG_DONTWAIT;
  do {
    if (op == NCCL_SOCKET_RECV) bytes = recv(fd, data + (*offset), size - (*offset), flags);
    if (op == NCCL_SOCKET_SEND) bytes = send(fd, data + (*offset), size - (*offset), flags);
    if (op == NCCL_SOCKET_RECV && bytes == 0) {
      WARN("Net : Connection closed by remote peer");
      return ncclSystemError;
    }
    if (bytes == -1) {
      if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) {
        // Name the call that actually failed rather than always "recv".
        WARN("Call to %s failed : %s", op == NCCL_SOCKET_RECV ? "recv" : "send", strerror(errno));
        return ncclSystemError;
      } else {
        bytes = 0;
      }
    }
    (*offset) += bytes;
  } while (bytes > 0 && (*offset) < size);
  return ncclSuccess;
}
/* Non-blocking variant of socketProgressOpt(): transfers what it can right now. */
static ncclResult_t socketProgress(int op, int fd, void* ptr, int size, int* offset) {
  const int block = 0;  // 0 selects MSG_DONTWAIT in socketProgressOpt()
  ncclResult_t status = socketProgressOpt(op, fd, ptr, size, offset, block);
  return status;
}
/* Keep issuing blocking transfers until *offset reaches size.
 * Any error from socketProgressOpt() is propagated through NCCLCHECK.
 */
static ncclResult_t socketWait(int op, int fd, void* ptr, int size, int* offset) {
  for (; *offset < size;) {
    NCCLCHECK(socketProgressOpt(op, fd, ptr, size, offset, 1));
  }
  return ncclSuccess;
}
/* Send exactly `size` bytes from `ptr` over `fd`, blocking until done. */
static ncclResult_t socketSend(int fd, void* ptr, int size) {
  int bytesSent = 0;  // progress counter advanced by socketWait()
  NCCLCHECK(socketWait(NCCL_SOCKET_SEND, fd, ptr, size, &bytesSent));
  return ncclSuccess;
}
/* Receive exactly `size` bytes into `ptr` from `fd`, blocking until done. */
static ncclResult_t socketReceive(int fd, void* ptr, int size) {
  int bytesReceived = 0;  // progress counter advanced by socketWait()
  NCCLCHECK(socketWait(NCCL_SOCKET_RECV, fd, ptr, size, &bytesReceived));
  return ncclSuccess;
}
int main() {
char names[MAX_IF_NAME_SIZE*MAX_IFS];
union socketAddress addrs[MAX_IFS];
int ncclNetIfs = findInterfaces(names, addrs, MAX_IF_NAME_SIZE, MAX_IFS);
printf("\n======== Result ========\n");
printf("Found %d interface.\n", ncclNetIfs);
for(int i=0; i<ncclNetIfs; i++) {
printf(" - %d: %s\n", i, &names[MAX_IF_NAME_SIZE * i]);
}
}
@cloudhan
Copy link
Author

cloudhan commented Sep 4, 2020

If you want to know how the f*** NCCL picks a network interface.

g++ nccl_socket.cpp -o nccl_socket

# Run with some additional environment variables and see whether the result matches your expectation.
NCCL_COMM_ID=10.1.23.45 ./nccl_socket

available env:

NCCL_SOCKET_FAMILY
NCCL_SOCKET_IFNAME
NCCL_COMM_ID

NCCL_COMM_ID is missing from the documentation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment