Skip to content

Commit 3b21bc9

Browse files
Federico GiovanardiJens-G
authored andcommitted
Support socket activation by fd passing
Client: cpp Patch: Federico Giovanardi This closes #3211
1 parent 06bc195 commit 3b21bc9

File tree

6 files changed

+147
-32
lines changed

6 files changed

+147
-32
lines changed

lib/cpp/src/thrift/transport/TServerSocket.cpp

Lines changed: 57 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ TServerSocket::TServerSocket(int port)
117117
listening_(false),
118118
interruptSockWriter_(THRIFT_INVALID_SOCKET),
119119
interruptSockReader_(THRIFT_INVALID_SOCKET),
120-
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
120+
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
121+
boundSocketType_(SocketType::NONE) {
121122
}
122123

123124
TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
@@ -136,7 +137,8 @@ TServerSocket::TServerSocket(int port, int sendTimeout, int recvTimeout)
136137
listening_(false),
137138
interruptSockWriter_(THRIFT_INVALID_SOCKET),
138139
interruptSockReader_(THRIFT_INVALID_SOCKET),
139-
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
140+
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
141+
boundSocketType_(SocketType::NONE) {
140142
}
141143

142144
TServerSocket::TServerSocket(const string& address, int port)
@@ -156,7 +158,8 @@ TServerSocket::TServerSocket(const string& address, int port)
156158
listening_(false),
157159
interruptSockWriter_(THRIFT_INVALID_SOCKET),
158160
interruptSockReader_(THRIFT_INVALID_SOCKET),
159-
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
161+
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
162+
boundSocketType_(SocketType::NONE) {
160163
}
161164

162165
TServerSocket::TServerSocket(const string& path)
@@ -176,7 +179,28 @@ TServerSocket::TServerSocket(const string& path)
176179
listening_(false),
177180
interruptSockWriter_(THRIFT_INVALID_SOCKET),
178181
interruptSockReader_(THRIFT_INVALID_SOCKET),
179-
childInterruptSockWriter_(THRIFT_INVALID_SOCKET) {
182+
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
183+
boundSocketType_(SocketType::NONE) {
184+
}
185+
TServerSocket::TServerSocket(THRIFT_SOCKET sock,SocketType socketType)
186+
: interruptableChildren_(true),
187+
port_(0),
188+
path_(),
189+
serverSocket_(sock),
190+
acceptBacklog_(DEFAULT_BACKLOG),
191+
sendTimeout_(0),
192+
recvTimeout_(0),
193+
accTimeout_(-1),
194+
retryLimit_(0),
195+
retryDelay_(0),
196+
tcpSendBuffer_(0),
197+
tcpRecvBuffer_(0),
198+
keepAlive_(false),
199+
listening_(false),
200+
interruptSockWriter_(THRIFT_INVALID_SOCKET),
201+
interruptSockReader_(THRIFT_INVALID_SOCKET),
202+
childInterruptSockWriter_(THRIFT_INVALID_SOCKET),
203+
boundSocketType_(socketType) {
180204
}
181205

182206
TServerSocket::~TServerSocket() {
@@ -439,7 +463,8 @@ void TServerSocket::listen() {
439463
if (isUnixDomainSocket()) {
440464
// -- Unix Domain Socket -- //
441465

442-
serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
466+
if (serverSocket_ == THRIFT_INVALID_SOCKET)
467+
serverSocket_ = socket(PF_UNIX, SOCK_STREAM, IPPROTO_IP);
443468

444469
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
445470
int errno_copy = THRIFT_GET_SOCKET_ERROR;
@@ -471,6 +496,8 @@ void TServerSocket::listen() {
471496
throw TTransportException(TTransportException::NOT_OPEN,
472497
" Unix Domain socket path not supported");
473498
#endif
499+
} else if( boundSocketType_ != SocketType::NONE){
500+
// -- Socket is already bound
474501
} else {
475502
// -- TCP socket -- //
476503

@@ -516,25 +543,31 @@ void TServerSocket::listen() {
516543
// use short circuit evaluation here to only sleep if we need to
517544
} while ((retries++ < retryLimit_) && (THRIFT_SLEEP_SEC(retryDelay_) == 0));
518545

519-
// retrieve bind info
520-
if (port_ == 0 && retries <= retryLimit_) {
521-
struct sockaddr_storage sa;
522-
socklen_t len = sizeof(sa);
523-
std::memset(&sa, 0, len);
524-
if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
525-
errno_copy = THRIFT_GET_SOCKET_ERROR;
526-
GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
546+
} // TCP socket //
547+
548+
// retrieve bind info
549+
if ((port_ == 0 || path_.empty() ) && retries <= retryLimit_) {
550+
struct sockaddr_storage sa;
551+
socklen_t len = sizeof(sa);
552+
std::memset(&sa, 0, len);
553+
if (::getsockname(serverSocket_, reinterpret_cast<struct sockaddr*>(&sa), &len) < 0) {
554+
errno_copy = THRIFT_GET_SOCKET_ERROR;
555+
GlobalOutput.perror("TServerSocket::getPort() getsockname() ", errno_copy);
556+
} else {
557+
if (sa.ss_family == AF_INET6) {
558+
const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
559+
port_ = ntohs(sin->sin6_port);
560+
} else if (sa.ss_family == AF_INET) {
561+
const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
562+
port_ = ntohs(sin->sin_port);
563+
} else if (sa.ss_family == AF_UNIX) {
564+
const auto* sin = reinterpret_cast<const struct sockaddr_un*>(&sa);
565+
path_ = sin->sun_path;
527566
} else {
528-
if (sa.ss_family == AF_INET6) {
529-
const auto* sin = reinterpret_cast<const struct sockaddr_in6*>(&sa);
530-
port_ = ntohs(sin->sin6_port);
531-
} else {
532-
const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&sa);
533-
port_ = ntohs(sin->sin_port);
534-
}
567+
GlobalOutput.perror("TServerSocket::getPort() getsockname() unhandled socket type",EINVAL);
535568
}
536569
}
537-
} // TCP socket //
570+
}
538571

539572
// throw error if socket still wasn't created successfully
540573
if (serverSocket_ == THRIFT_INVALID_SOCKET) {
@@ -569,7 +602,7 @@ void TServerSocket::listen() {
569602
listenCallback_(serverSocket_);
570603

571604
// Call listen
572-
if (-1 == ::listen(serverSocket_, acceptBacklog_)) {
605+
if (boundSocketType_ == SocketType::NONE && -1 == ::listen(serverSocket_, acceptBacklog_)) {
573606
errno_copy = THRIFT_GET_SOCKET_ERROR;
574607
GlobalOutput.perror("TServerSocket::listen() listen() ", errno_copy);
575608
close();
@@ -734,7 +767,8 @@ void TServerSocket::close() {
734767
concurrency::Guard g(rwMutex_);
735768
if (serverSocket_ != THRIFT_INVALID_SOCKET) {
736769
shutdown(serverSocket_, THRIFT_SHUT_RDWR);
737-
::THRIFT_CLOSESOCKET(serverSocket_);
770+
if( boundSocketType_ == SocketType::NONE) //Do not close the server socket if it owned by systemd
771+
::THRIFT_CLOSESOCKET(serverSocket_);
738772
}
739773
if (interruptSockWriter_ != THRIFT_INVALID_SOCKET) {
740774
::THRIFT_CLOSESOCKET(interruptSockWriter_);

lib/cpp/src/thrift/transport/TServerSocket.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ namespace transport {
4040

4141
class TSocket;
4242

43+
enum class SocketType {
44+
NONE,
45+
INET,
46+
INET6,
47+
UNIX
48+
};
49+
4350
/**
4451
* Server socket implementation of TServerTransport. Wrapper around a unix
4552
* socket listen and accept calls.
@@ -82,6 +89,14 @@ class TServerSocket : public TServerTransport {
8289
*/
8390
TServerSocket(const std::string& path);
8491

92+
/**
93+
* Constructor used for to initialize from an already bound unix socket.
94+
* Useful for socket activation on systemd.
95+
*
96+
* @param fd
97+
*/
98+
TServerSocket(THRIFT_SOCKET sock,SocketType socketType);
99+
85100
~TServerSocket() override;
86101

87102

@@ -172,6 +187,7 @@ class TServerSocket : public TServerTransport {
172187

173188
socket_func_t listenCallback_;
174189
socket_func_t acceptCallback_;
190+
SocketType boundSocketType_;
175191
};
176192
}
177193
}

test/cpp/src/TestServer.cpp

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <thrift/server/TSimpleServer.h>
3232
#include <thrift/server/TThreadPoolServer.h>
3333
#include <thrift/server/TThreadedServer.h>
34+
#include <thrift/transport/PlatformSocket.h>
3435
#include <thrift/transport/THttpServer.h>
3536
#include <thrift/transport/THttpTransport.h>
3637
#include <thrift/transport/TNonblockingSSLServerSocket.h>
@@ -54,14 +55,21 @@
5455
#ifdef HAVE_SIGNAL_H
5556
#include <signal.h>
5657
#endif
58+
#ifdef HAVE_SYS_SOCKET_H
59+
#include <sys/socket.h>
60+
#endif
61+
#ifdef HAVE_SYS_UN_H
62+
#include <sys/un.h>
63+
#endif
5764

5865
#include <iostream>
59-
#include <stdexcept>
66+
#include <memory>
6067
#include <sstream>
68+
#include <stdexcept>
6169

6270
#include <boost/algorithm/string.hpp>
63-
#include <boost/program_options.hpp>
6471
#include <boost/filesystem.hpp>
72+
#include <boost/program_options.hpp>
6573

6674
#if _WIN32
6775
#include <thrift/windows/TWinsockSingleton.h>
@@ -570,6 +578,47 @@ class TestHandlerAsync : public ThriftTestCobSvIf {
570578
std::shared_ptr<TestHandler> _delegate;
571579
};
572580

581+
struct DomainSocketFd {
582+
THRIFT_SOCKET socket_fd;
583+
std::string path;
584+
DomainSocketFd(const std::string& path) : path(path) {
585+
#ifdef HAVE_SYS_UN_H
586+
unlink(path.c_str());
587+
socket_fd = socket(AF_UNIX, SOCK_STREAM, IPPROTO_IP);
588+
if (socket_fd == -1) {
589+
std::ostringstream os;
590+
os << "Cannot create domain socket: " << strerror(errno);
591+
throw std::runtime_error(os.str());
592+
}
593+
if (path.size() > sizeof(sockaddr_un::sun_path) - 1)
594+
throw std::runtime_error("Path size on domain socket too big");
595+
struct sockaddr_un sa;
596+
memset(&sa, 0, sizeof(sa));
597+
sa.sun_family = AF_UNIX;
598+
strcpy(sa.sun_path, path.c_str());
599+
int rv = bind(socket_fd, (struct sockaddr*)&sa, sizeof(sa));
600+
if (rv == -1) {
601+
std::ostringstream os;
602+
os << "Cannot bind domain socket: " << strerror(errno);
603+
throw std::runtime_error(os.str());
604+
}
605+
606+
rv = ::listen(socket_fd, 16);
607+
if (rv == -1) {
608+
std::ostringstream os;
609+
os << "Cannot listen on domain socket: " << strerror(errno);
610+
throw std::runtime_error(os.str());
611+
}
612+
#else
613+
throw std::runtime_error("Cannot create a domain socket without AF_UNIX");
614+
#endif
615+
}
616+
~DomainSocketFd() {
617+
::THRIFT_CLOSESOCKET(socket_fd);
618+
unlink(path.c_str());
619+
}
620+
};
621+
573622
namespace po = boost::program_options;
574623

575624
int main(int argc, char** argv) {
@@ -589,6 +638,8 @@ int main(int argc, char** argv) {
589638
string server_type = "simple";
590639
string domain_socket = "";
591640
bool abstract_namespace = false;
641+
bool emulate_socketactivation = false;
642+
std::unique_ptr<DomainSocketFd> domain_socket_fd;
592643
size_t workers = 4;
593644
int string_limit = 0;
594645
int container_limit = 0;
@@ -599,6 +650,7 @@ int main(int argc, char** argv) {
599650
("port", po::value<int>(&port)->default_value(port), "Port number to listen")
600651
("domain-socket", po::value<string>(&domain_socket) ->default_value(domain_socket), "Unix Domain Socket (e.g. /tmp/ThriftTest.thrift)")
601652
("abstract-namespace", "Create the domain socket in the Abstract Namespace (no connection with filesystem pathnames)")
653+
("emulate-socketactivation","Open the socket from the tester program and pass the library an already open fd")
602654
("server-type", po::value<string>(&server_type)->default_value(server_type), "type of server, \"simple\", \"thread-pool\", \"threaded\", or \"nonblocking\"")
603655
("transport", po::value<string>(&transport_type)->default_value(transport_type), "transport: buffered, framed, http, websocket, zlib")
604656
("protocol", po::value<string>(&protocol_type)->default_value(protocol_type), "protocol: binary, compact, header, json, multi, multic, multih, multij")
@@ -678,6 +730,9 @@ int main(int argc, char** argv) {
678730
if (vm.count("abstract-namespace")) {
679731
abstract_namespace = true;
680732
}
733+
if (vm.count("emulate-socketactivation")) {
734+
emulate_socketactivation = true;
735+
}
681736

682737
// Dispatcher
683738
std::shared_ptr<TProtocolFactory> protocolFactory;
@@ -727,8 +782,16 @@ int main(int argc, char** argv) {
727782
abstract_socket += domain_socket;
728783
serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(abstract_socket));
729784
} else {
730-
unlink(domain_socket.c_str());
731-
serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
785+
if (emulate_socketactivation) {
786+
unlink(domain_socket.c_str());
787+
// open and bind the socket
788+
domain_socket_fd.reset(new DomainSocketFd(domain_socket));
789+
serverSocket = std::shared_ptr<TServerSocket>(
790+
new TServerSocket(domain_socket_fd->socket_fd, SocketType::UNIX));
791+
} else {
792+
unlink(domain_socket.c_str());
793+
serverSocket = std::shared_ptr<TServerSocket>(new TServerSocket(domain_socket));
794+
}
732795
}
733796
port = 0;
734797
} else {

test/crossrunner/run.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def _get_domain_port(self):
306306
return port if ok else self._get_domain_port()
307307

308308
def alloc_port(self, socket_type):
309-
if socket_type in ('domain', 'abstract'):
309+
if socket_type in ('domain', 'abstract','domain-socketactivated'):
310310
return self._get_domain_port()
311311
else:
312312
return self._get_tcp_port()
@@ -323,7 +323,7 @@ def free_port(self, socket_type, port):
323323
self._log.debug('free_port')
324324
self._lock.acquire()
325325
try:
326-
if socket_type == 'domain':
326+
if socket_type in ['domain','domain-socketactivated']:
327327
self._dom_ports.remove(port)
328328
path = domain_socket_path(port)
329329
if os.path.exists(path):

test/crossrunner/test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,11 @@ def abs_if_exists(arg):
5959
return cmd
6060

6161
def _socket_args(self, socket, port):
62+
support_socket_activation = self.kind == 'server' and sys.platform != "win32"
6263
return {
6364
'ip-ssl': ['--ssl'],
6465
'domain': ['--domain-socket=%s' % domain_socket_path(port)],
66+
'domain-socketactivated': (['--emulate-socketactivation'] if support_socket_activation else []) + ['--domain-socket=%s' % domain_socket_path(port)],
6567
'abstract': ['--abstract-namespace', '--domain-socket=%s' % domain_socket_path(port)],
6668
}.get(socket, None)
6769

test/tests.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -404,13 +404,13 @@
404404
"buffered",
405405
"http",
406406
"framed",
407-
"zlib",
408-
"websocket"
407+
"zlib"
409408
],
410409
"sockets": [
411410
"ip",
412411
"ip-ssl",
413-
"domain"
412+
"domain",
413+
"domain-socketactivated"
414414
],
415415
"protocols": [
416416
"compact",

0 commit comments

Comments
 (0)