blob: 215ac39afcaa07988c907e5a60d4685b6da61512 [file] [log] [blame]
xf.li86118912025-03-19 20:07:27 -07001import errno
2import os
3import selectors
4import signal
5import socket
6import struct
7import sys
8import threading
9import warnings
10
11from . import connection
12from . import process
13from .context import reduction
14from . import resource_tracker
15from . import spawn
16from . import util
17
# Public API of this module; each name is a bound method of the module-level
# ForkServer instance.
__all__ = [
    'ensure_running',
    'get_inherited_fds',
    'connect_to_new_process',
    'set_forkserver_preload',
]
20
21#
22#
23#
24
# Wire format used to send a child's pid and its return code back to the
# client: a native C long long ('q'), large enough for pid_t.
SIGNED_STRUCT = struct.Struct('q')
# Cap on the number of fds carried by one fork request; the server rejects
# requests that deliver more than this many.
MAXFDS_TO_SEND = 256
27
28#
29# Forkserver class
30#
31
class ForkServer(object):
    '''Client-side manager for a single forkserver process.

    The server is launched lazily by ensure_running() and asked to fork new
    children by connect_to_new_process().
    '''

    def __init__(self):
        # AF_UNIX address of the server's listening socket (None until started).
        self._forkserver_address = None
        # Write end of the "alive" pipe; the server exits once every copy of
        # this fd has been closed and its read end reaches EOF.
        self._forkserver_alive_fd = None
        # Pid of the server process launched by this process, or None.
        self._forkserver_pid = None
        # Extra fds handed to a process started by the forkserver; stays None
        # in processes that were not started that way.
        self._inherited_fds = None
        # Serializes launching and stopping the server.
        self._lock = threading.Lock()
        # Modules the server imports before it starts forking children.
        self._preload_modules = ['__main__']

    def _stop(self):
        # Method used by unit tests to stop the server
        with self._lock:
            self._stop_unlocked()

    def _stop_unlocked(self):
        '''Stop the server and reset state; caller must hold self._lock.'''
        if self._forkserver_pid is None:
            return

        # close the "alive" file descriptor asks the server to stop
        os.close(self._forkserver_alive_fd)
        self._forkserver_alive_fd = None

        os.waitpid(self._forkserver_pid, 0)
        self._forkserver_pid = None

        if not util.is_abstract_socket_namespace(self._forkserver_address):
            os.unlink(self._forkserver_address)
        self._forkserver_address = None

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Raises TypeError if any element of *modules_names* is not a str.
        '''
        # Validate the *new* names.  Checking self._preload_modules (the old
        # value) let invalid input through unnoticed.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w). The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* would exceed the per-request fd limit.
        '''
        self.ensure_running()
        # Four fds are always sent ahead of the caller's fds (see allfds).
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except BaseException:
                # Don't leak the first pipe if the second cannot be created
                # (for example when the process is out of file descriptors).
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                # Built inside the try so that a failure here still closes
                # every pipe end created above.
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          resource_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # The child-side ends now live in the server; drop our copies.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process. Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.
        '''
        with self._lock:
            resource_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                # Remove the dead server's socket file, as _stop_unlocked()
                # does; a fresh address is chosen for the new server below.
                if not util.is_abstract_socket_namespace(
                        self._forkserver_address):
                    try:
                        os.unlink(self._forkserver_address)
                    except FileNotFoundError:
                        pass
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                if not util.is_abstract_socket_namespace(address):
                    os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except:
                    os.close(alive_w)
                    raise
                finally:
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid
162
163#
164#
165#
166
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Optionally imports the modules named in *preload* (importing the
    parent's __main__ from *main_path* first when '__main__' is listed),
    then serves fork requests arriving on the listening AF_UNIX socket
    *listener_fd* until *alive_r* reaches EOF.

    Each request delivers the fds (child_r, child_w, alive fd, resource
    tracker fd, *inherited fds).  The server forks: the child runs
    _serve_one(), while the parent writes the child's pid to child_w and,
    after reaping the child, its return code.

    *sys_path* is accepted but not used by this function.

    Exits by raising SystemExit once all holders of the alive pipe's write
    end have closed it.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                # Preloading is best-effort; missing modules are skipped.
                pass

    util._close_stdin()

    # Self-pipe for signal wakeups; both ends non-blocking so neither the
    # wakeup write nor the drain read below can block.
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Remember the previous handlers so forked children can restore them.
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        _forkserver._forkserver_address = listener.getsockname()

        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Block until at least one registered fd is readable.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): under -O the assert (and the read) is
                    # skipped; SystemExit is raised either way.
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536) # exhaust
                    while True:
                        # Scan for child processes
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            # No children left to reap.
                            break
                        if pid == 0:
                            # Children exist but none has exited yet.
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            # Encode death-by-signal as a negative code.
                            if os.WIFSIGNALED(sts):
                                returncode = -os.WTERMSIG(sts)
                            else:
                                if not os.WIFEXITED(sts):
                                    raise AssertionError(
                                        "Child {0:n} status is {1:n}".format(
                                               pid,sts))
                                returncode = os.WEXITSTATUS(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client
                        # (ask for one extra so an oversized request is
                        # detectable)
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            # NOTE(review): RuntimeError is not caught by the
                            # OSError handler below, so it propagates out of
                            # main(), and the received fds are not closed.
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        child_r, child_w, *fds = fds
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child
                            code = 1
                            try:
                                # Drop the server's own resources before
                                # running the requested process.
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                # Never fall back into the server loop.
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # Keep child_w to report the exit code later;
                            # the rest belong to the child now.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # A client that disconnects mid-accept is ignored; any other
                # OS error is fatal.
                if e.errno != errno.ECONNABORTED:
                    raise
303
304
def _serve_one(child_r, fds, unused_fds, handlers):
    """Prepare a freshly forked forkserver child and run its process object.

    Restores the signal handlers in *handlers*, closes every fd in
    *unused_fds*, and records the fds received from the client: the alive fd,
    the resource tracker fd, then any inherited fds.  Returns the exit code
    produced by spawn._main() reading from *child_r*.
    """
    # Undo the server's signal plumbing so the child starts clean.
    signal.set_wakeup_fd(-1)
    for signum, action in handlers.items():
        signal.signal(signum, action)
    for stale_fd in unused_fds:
        os.close(stale_fd)

    alive_fd, tracker_fd, *inherited = fds
    _forkserver._forkserver_alive_fd = alive_fd
    resource_tracker._resource_tracker._fd = tracker_fd
    _forkserver._inherited_fds = inherited

    # Run process object received over pipe; a dup of child_r serves as the
    # parent sentinel.
    sentinel = os.dup(child_r)
    return spawn._main(child_r, sentinel)
322
323
324#
325# Read and write signed numbers
326#
327
def read_signed(fd):
    """Read one SIGNED_STRUCT-encoded integer from file descriptor *fd*.

    Keeps reading until the full record has arrived, because os.read() may
    return fewer bytes than requested.  Raises EOFError if the stream ends
    first.
    """
    needed = SIGNED_STRUCT.size
    pieces = []
    got = 0
    while got < needed:
        piece = os.read(fd, needed - got)
        if not piece:
            raise EOFError('unexpected EOF')
        pieces.append(piece)
        got += len(piece)
    (value,) = SIGNED_STRUCT.unpack(b''.join(pieces))
    return value
337
def write_signed(fd, n):
    """Write integer *n* to file descriptor *fd* as one SIGNED_STRUCT record.

    Loops until every byte is written, since os.write() may accept only part
    of the buffer.  Raises RuntimeError if os.write() reports zero bytes.
    """
    payload = SIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        written = os.write(fd, payload[offset:])
        if written == 0:
            raise RuntimeError('should not get here')
        offset += written
345
346#
347#
348#
349
# Module-level ForkServer instance; the public functions are its bound methods.
_forkserver = ForkServer()
(ensure_running,
 get_inherited_fds,
 connect_to_new_process,
 set_forkserver_preload) = (_forkserver.ensure_running,
                            _forkserver.get_inherited_fds,
                            _forkserver.connect_to_new_process,
                            _forkserver.set_forkserver_preload)