socket.h 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #pragma once
  2. #include <poll.h>
  3. #include <sys/socket.h>
  4. #include <sys/stat.h>
  5. #include <sys/types.h>
  6. #include <sys/un.h>
  7. #include <unistd.h>
  8. #include <cstddef>
  9. #include <cstdio>
  10. #include <cstring>
  11. #include <string>
  12. #include <libshm/alloc_info.h>
  13. #include <libshm/err.h>
  14. class Socket {
  15. public:
  16. int socket_fd;
  17. Socket(const Socket& other) = delete;
  18. protected:
  19. Socket() {
  20. SYSCHECK_ERR_RETURN_NEG1(socket_fd = socket(AF_UNIX, SOCK_STREAM, 0));
  21. }
  22. Socket(Socket&& other) noexcept : socket_fd(other.socket_fd) {
  23. other.socket_fd = -1;
  24. };
  25. explicit Socket(int fd) : socket_fd(fd) {}
  26. virtual ~Socket() {
  27. if (socket_fd != -1)
  28. close(socket_fd);
  29. }
  30. struct sockaddr_un prepare_address(const char* path) {
  31. struct sockaddr_un address;
  32. address.sun_family = AF_UNIX;
  33. strcpy(address.sun_path, path);
  34. return address;
  35. }
  36. // Implemented based on https://man7.org/linux/man-pages/man7/unix.7.html
  37. size_t address_length(struct sockaddr_un address) {
  38. return offsetof(sockaddr_un, sun_path) + strlen(address.sun_path) + 1;
  39. }
  40. void recv(void* _buffer, size_t num_bytes) {
  41. char* buffer = (char*)_buffer;
  42. size_t bytes_received = 0;
  43. ssize_t step_received;
  44. struct pollfd pfd = {};
  45. pfd.fd = socket_fd;
  46. pfd.events = POLLIN;
  47. while (bytes_received < num_bytes) {
  48. SYSCHECK_ERR_RETURN_NEG1(poll(&pfd, 1, 1000));
  49. if (pfd.revents & POLLIN) {
  50. SYSCHECK_ERR_RETURN_NEG1(
  51. step_received =
  52. ::read(socket_fd, buffer, num_bytes - bytes_received));
  53. TORCH_CHECK(step_received != 0, "Other end has closed the connection");
  54. bytes_received += step_received;
  55. buffer += step_received;
  56. } else if (pfd.revents & (POLLERR | POLLHUP)) {
  57. TORCH_CHECK(false, "An error occurred while waiting for the data");
  58. } else {
  59. TORCH_CHECK(false, "Shared memory manager connection has timed out");
  60. }
  61. }
  62. }
  63. void send(const void* _buffer, size_t num_bytes) {
  64. const char* buffer = (const char*)_buffer;
  65. size_t bytes_sent = 0;
  66. ssize_t step_sent;
  67. while (bytes_sent < num_bytes) {
  68. SYSCHECK_ERR_RETURN_NEG1(
  69. step_sent = ::write(socket_fd, buffer, num_bytes));
  70. bytes_sent += step_sent;
  71. buffer += step_sent;
  72. }
  73. }
  74. };
  75. class ManagerSocket : public Socket {
  76. public:
  77. explicit ManagerSocket(int fd) : Socket(fd) {}
  78. AllocInfo receive() {
  79. AllocInfo info;
  80. recv(&info, sizeof(info));
  81. return info;
  82. }
  83. void confirm() {
  84. send("OK", 2);
  85. }
  86. };
  87. class ManagerServerSocket : public Socket {
  88. public:
  89. explicit ManagerServerSocket(const std::string& path) {
  90. socket_path = path;
  91. try {
  92. struct sockaddr_un address = prepare_address(path.c_str());
  93. size_t len = address_length(address);
  94. SYSCHECK_ERR_RETURN_NEG1(
  95. bind(socket_fd, (struct sockaddr*)&address, len));
  96. SYSCHECK_ERR_RETURN_NEG1(listen(socket_fd, 10));
  97. } catch (std::exception&) {
  98. SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
  99. throw;
  100. }
  101. }
  102. void remove() {
  103. struct stat file_stat;
  104. if (fstat(socket_fd, &file_stat) == 0)
  105. SYSCHECK_ERR_RETURN_NEG1(unlink(socket_path.c_str()));
  106. }
  107. ~ManagerServerSocket() override {
  108. unlink(socket_path.c_str());
  109. }
  110. ManagerSocket accept() {
  111. int client_fd;
  112. struct sockaddr_un addr;
  113. socklen_t addr_len = sizeof(addr);
  114. SYSCHECK_ERR_RETURN_NEG1(
  115. client_fd = ::accept(socket_fd, (struct sockaddr*)&addr, &addr_len));
  116. return ManagerSocket(client_fd);
  117. }
  118. std::string socket_path;
  119. };
  120. class ClientSocket : public Socket {
  121. public:
  122. explicit ClientSocket(const std::string& path) {
  123. try {
  124. struct sockaddr_un address = prepare_address(path.c_str());
  125. size_t len = address_length(address);
  126. SYSCHECK_ERR_RETURN_NEG1(
  127. connect(socket_fd, (struct sockaddr*)&address, len));
  128. } catch (std::exception&) {
  129. SYSCHECK_ERR_RETURN_NEG1(close(socket_fd));
  130. throw;
  131. }
  132. }
  133. void register_allocation(AllocInfo& info) {
  134. char buffer[3] = {0, 0, 0};
  135. send(&info, sizeof(info));
  136. recv(buffer, 2);
  137. TORCH_CHECK(
  138. strcmp(buffer, "OK") == 0,
  139. "Shared memory manager didn't respond with an OK");
  140. }
  141. void register_deallocation(AllocInfo& info) {
  142. send(&info, sizeof(info));
  143. }
  144. };