由SubprocVecEnv遇到的多进程经典问题

需要用到SubprocVecEnv并行采样,代码写着写着遇到多进程导致的问题,还是连着遇到三个,居然不能马上想出来,真是老了(

类变量计数

想要记录某个端口是否被占用,心路历程:想在环境类里定义一个类变量->多线程/进程可能会有并发问题->咋样处理并发->上锁?不对啊这地址空间都不一样->多进程同步太麻烦!->作罢

多进程时复制

SubprocVecEnv应该传一个List[Callable],但是误把环境类实例丢了进去。遇到了错误:

1
2
3
4
5
6
7
8
9
Traceback (most recent call last):
File "main.py", line 23, in <module>
trainer.train()
File "base.py", line 45, in train
train_env = SubprocVecEnv([self.env_cls(i) for i in range(kira_conf.n_envs)])
......
File "/home/illya/miniconda3/envs/3d-o/lib/python3.11/socket.py", line 274, in __getstate__
raise TypeError(f"cannot pickle {self.__class__.__name__!r} object")
TypeError: cannot pickle 'socket' object

我还疑惑为什么在 pickle socket,忽然想到多进程会拷贝进程空间里的东西,显然是拷贝时存在一个 socket 连接并出错,于是顺利清醒(

多进程地址空间

有一个定义在模块里的 global setting。程序开始对其初始化(一些类成员变量仅在setup中定义),但是子进程中运行的环境类实例找不到 setting 中初始化的变量。

1
2
3
4
5
6
from config import kira_conf

if __name__ == "__main__":
kira_conf.setup(env_cls.__name__)

trainer.train() # MultiProcess inside

按理来说即使地址空间不同,按照正常的代码在文件上的执行过程,子进程也会收到一份初始化后的拷贝。

在子进程中将 conf 里的东西打印出来:

1
2
3
4
5
6
7
{'ROOT': PosixPath('config/parameters')}
{'ROOT': PosixPath('config/parameters')}
{'ROOT': PosixPath('config/parameters')}
{'ROOT': PosixPath('config/parameters')}

# Content shall exist
{'ROOT': PosixPath('config/parameters'), 'PARAMS_RL_FILE': PosixPath('config/parameters/ShortKick/param.yaml'), 'server_ip': 'localhost', 'server_port': 4100}

子进程看到的 conf 只有一个实例化时带着的ROOT。

于是我将 setup 挪到 main 函数外:

1
2
3
4
5
6
from config import kira_conf

kira_conf.setup(env_cls.__name__)
if __name__ == "__main__":

trainer.train() # MultiProcess

能够正确完成初始化。

猜想:multiprocess创建了“完全隔离”的空间

查了查文档,总结如下:

Python的multiprocessing模块提供了三种启动新进程的方法:forkforkserverspawn。这三种方法在处理全局变量时的行为有所不同。

  1. forkfork方法是UNIX系统上的标准进程创建方式。它创建的子进程是父进程的完整复制品,包括父进程的全局状态。因此,使用fork方法创建的子进程可以访问父进程中定义的全局变量。然而,由于子进程中的全局变量实际上是父进程中的全局变量的副本,因此,对全局变量的修改不会在父进程和子进程之间共享。
  2. forkserverspawn: 这两种方法在创建新进程时,会启动一个全新的Python解释器进程。这意味着新进程不会继承父进程的全局状态。因此,新进程不能访问在父进程中定义的全局变量。这对于保证父进程和子进程的隔离性很有帮助,但也使得数据共享变得更困难。

验证代码:

1
2
3
4
5
6
7
8
9
10
# config.py

class Config:
def __init__(self):
self.a = 1

def setup(self):
self.b = 2

config = Config()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# main.py
import os
import multiprocessing
from config import config

def worker(i):
from config import config
print(config)

if __name__ == '__main__':

print(config)
print("--------------------------------------")

multiprocessing.set_start_method('forkserver')
num_process = 4
processes = []

config.setup()

for i in range(num_process):
p = multiprocessing.Process(target=worker, args=(i,))
p.start()
processes.append(p)

for p in processes:
p.join()

输出如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# fork
<config.Config object at 0x7f56d2c3f1d0>
--------------------------------------
0: <config.Config object at 0x7f56d2c3f1d0>
1: <config.Config object at 0x7f56d2c3f1d0>
2: <config.Config object at 0x7f56d2c3f1d0>
3: <config.Config object at 0x7f56d2c3f1d0>

# forkserver
<config.Config object at 0x7f75c882f010>
--------------------------------------
1: <config.Config object at 0x7f5c8ba4ff90>
0: <config.Config object at 0x7f5c8ba4fed0>
2: <config.Config object at 0x7f5c8ba4ff10>
3: <config.Config object at 0x7f5c8ba4ff10>

# spawn
<config.Config object at 0x7f677a8ef050>
--------------------------------------
2: <config.Config object at 0x7f825c0ef0d0>
1: <config.Config object at 0x7fc2776b31d0>
0: <config.Config object at 0x7f4f8d88af10>
3: <config.Config object at 0x7fb512f370d0>

forkserver这个比较有趣,总是后面输出的两三个进程的 config 有相同的地址。当然如果是fork的话,就要涉及操作系统的什么写时复制了。

将 setup 放在 main 函数外并且打印 b

1
2
3
4
5
6
7
8
9
config.setup()

def worker(i):
from config import config
print(f'{i}: {config}')
print(f'{config.b}\n')

if __name__ == '__main__':
...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
<config.Config object at 0x7f01772833d0>
--------------------------------------
2
2: <config.Config object at 0x7fb9f5117e90>
2

0: <config.Config object at 0x7fb9f5117e10>
2

1: <config.Config object at 0x7fb9f5117e90>
2

3: <config.Config object at 0x7fb9f5117e90>
2

因此可以推测是 forkserver/spawn 重新导入了父进程的模块,并执行了一次 main.py,所以main函数外的config.setup()被执行。

来个小彩蛋,把worker改成:

1
2
3
def worker(i):
from config import config
print(__name__)
1
2
3
4
5
6
<config.Config object at 0x7f4280ffb3d0>
--------------------------------------
__mp_main__
__mp_main__
__mp_main__
__mp_main__

总结

贴一份SubProcVecEnv的注释:

warning:

Only ‘forkserver’ and ‘spawn’ start methods are thread-safe, which is important when TensorFlow sessions or other non thread-safe libraries are used in the parent (see issue #217). However, compared to ‘fork’ they incur a small start-up cost and have restrictions on global variables. With those methods, users must wrap the code in an if __name__ == "__main__": block.

For more information, see the multiprocessing documentation.