需要用到SubprocVecEnv
并行采样,代码写着写着遇到多进程导致的问题,还是连着遇到三个,居然不能马上想出来,真是老了(
类变量计数
想要记录某个端口是否被占用,心路历程:想在环境类里定义一个类变量->多线程/进程可能会有并发问题->咋样处理并发->上锁?不对啊这地址空间都不一样->多进程同步太麻烦!->作罢
多进程时复制
SubprocVecEnv应该传一个List[Callable]
,但是误把环境类实例丢了进去。遇到了错误:
1 2 3 4 5 6 7 8 9 Traceback (most recent call last): File "main.py", line 23, in <module> trainer.train() File "base.py", line 45, in train train_env = SubprocVecEnv([self.env_cls(i) for i in range(kira_conf.n_envs)]) ...... File "/home/illya/miniconda3/envs/3d-o/lib/python3.11/socket.py", line 274, in __getstate__ raise TypeError(f"cannot pickle {self.__class__.__name__!r} object") TypeError: cannot pickle 'socket' object
我还疑惑为什么在 pickle socket,忽然想到多进程会拷贝进程空间里的东西,显然是拷贝时存在一个 socket 连接并出错,于是顺利清醒(
多进程地址空间
有一个定义在模块里的 global setting。程序开始对其初始化(一些类成员变量仅在setup中定义),但是子进程中运行的环境类实例找不到 setting 中初始化的变量。
1 2 3 4 5 6 from config import kira_confif __name__ == "__main__" : kira_conf.setup(env_cls.__name__) trainer.train()
按理来说即使地址空间不同,按照正常的代码在文件上的执行过程,子进程也会收到一份初始化后的拷贝。
在子进程中将 conf 里的东西打印出来:
1 2 3 4 5 6 7 {'ROOT': PosixPath('config/parameters')} {'ROOT': PosixPath('config/parameters')} {'ROOT': PosixPath('config/parameters')} {'ROOT': PosixPath('config/parameters')} # Content shall exist {'ROOT': PosixPath('config/parameters'), 'PARAMS_RL_FILE': PosixPath('config/parameters/ShortKick/param.yaml'), 'server_ip': 'localhost', 'server_port': 4100}
子进程看到的 conf 只有一个实例化时带着的ROOT。
于是我将 setup 挪到 main 函数外:
1 2 3 4 5 6 from config import kira_confkira_conf.setup(env_cls.__name__) if __name__ == "__main__" : trainer.train()
能够正确完成初始化。
猜想:multiprocess创建了“完全隔离”的空间
查了查文档,总结如下:
Python的multiprocessing
模块提供了三种启动新进程的方法:fork
,forkserver
和spawn
。这三种方法在处理全局变量时的行为有所不同。
fork
: fork
方法是UNIX系统上的标准进程创建方式。它创建的子进程是父进程的完整复制品,包括父进程的全局状态。因此,使用fork
方法创建的子进程可以访问父进程中定义的全局变量。然而,由于子进程中的全局变量实际上是父进程中的全局变量的副本,因此,对全局变量的修改不会在父进程和子进程之间共享。
forkserver
和spawn
: 这两种方法在创建新进程时,会启动一个全新的Python解释器进程。这意味着新进程不会继承父进程的全局状态。因此,新进程不能访问在父进程中定义的全局变量。这对于保证父进程和子进程的隔离性很有帮助,但也使得数据共享变得更困难。
验证代码:
1 2 3 4 5 6 7 8 9 10 class Config : def __init__ (self ): self .a = 1 def setup (self ): self .b = 2 config = Config()
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 import osimport multiprocessingfrom config import configdef worker (i ): from config import config print (config) if __name__ == '__main__' : print (config) print ("--------------------------------------" ) multiprocessing.set_start_method('forkserver' ) num_process = 4 processes = [] config.setup() for i in range (num_process): p = multiprocessing.Process(target=worker, args=(i,)) p.start() processes.append(p) for p in processes: p.join()
输出如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 # fork <config.Config object at 0x7f56d2c3f1d0> -------------------------------------- 0: <config.Config object at 0x7f56d2c3f1d0> 1: <config.Config object at 0x7f56d2c3f1d0> 2: <config.Config object at 0x7f56d2c3f1d0> 3: <config.Config object at 0x7f56d2c3f1d0> # forkserver <config.Config object at 0x7f75c882f010> -------------------------------------- 1: <config.Config object at 0x7f5c8ba4ff90> 0: <config.Config object at 0x7f5c8ba4fed0> 2: <config.Config object at 0x7f5c8ba4ff10> 3: <config.Config object at 0x7f5c8ba4ff10> # spawn <config.Config object at 0x7f677a8ef050> -------------------------------------- 2: <config.Config object at 0x7f825c0ef0d0> 1: <config.Config object at 0x7fc2776b31d0> 0: <config.Config object at 0x7f4f8d88af10> 3: <config.Config object at 0x7fb512f370d0>
forkserver
这个比较有趣,总是后面输出的两三个进程的 config 有相同的地址。当然如果是fork
的话,就要涉及操作系统的什么写时复制了。
将 setup 放在 main 函数外并且打印 b
1 2 3 4 5 6 7 8 9 config.setup() def worker (i ): from config import config print (f'{i} : {config} ' ) print (f'{config.b} \n' ) if __name__ == '__main__' : ...
1 2 3 4 5 6 7 8 9 10 11 12 13 14 <config.Config object at 0x7f01772833d0> -------------------------------------- 2 2: <config.Config object at 0x7fb9f5117e90> 2 0: <config.Config object at 0x7fb9f5117e10> 2 1: <config.Config object at 0x7fb9f5117e90> 2 3: <config.Config object at 0x7fb9f5117e90> 2
因此可以推测是 forkserver/spawn 重新导入了父进程的模块,并执行了一次 main.py ,所以main函数外的config.setup()被执行。
来个小彩蛋,把worker改成:
1 2 3 def worker (i ): from config import config print (__name__)
1 2 3 4 5 6 <config.Config object at 0x7f4280ffb3d0> -------------------------------------- __mp_main__ __mp_main__ __mp_main__ __mp_main__
总结
贴一份SubProcVecEnv
的注释:
warning:
Only ‘forkserver’ and ‘spawn’ start methods are thread-safe, which is important when TensorFlow sessions or other non thread-safe libraries are used in the parent (see issue #217). However, compared to ‘fork’ they incur a small start-up cost and have restrictions on global variables. With those methods, users must wrap the code in an if __name__ == "__main__":
block.
For more information, see the multiprocessing documentation.