Require explicit opt-in to initialize the pool in a child process.

Change: in the current implementation, pools created in child processes CANNOT share data with the main process. An exception is therefore raised by default when attempting to create a pool in a child process; it can be suppressed with the `allow_child_init` parameter.
This commit is contained in:
2026-02-21 12:42:37 +08:00
parent cd335c1b3f
commit 6bb8873b48
2 changed files with 167 additions and 2 deletions

View File

@ -52,15 +52,26 @@ class MultiProcessingSharedPool:
The singleton instance. The singleton instance.
_lock : threading.Lock _lock : threading.Lock
Lock for thread-safe singleton instantiation. Lock for thread-safe singleton instantiation.
_allow_child_init : bool
Whether to allow initialization in child processes.
""" """
_instance: Optional["MultiProcessingSharedPool"] = None _instance: Optional["MultiProcessingSharedPool"] = None
_lock = threading.Lock() _lock = threading.Lock()
_allow_child_init: bool = False
def __new__(cls) -> "MultiProcessingSharedPool": def __new__(cls, allow_child_init: bool = False) -> "MultiProcessingSharedPool":
""" """
Ensure singleton instance creation. Ensure singleton instance creation.
Parameters
----------
allow_child_init : bool, optional
If True, allows the pool to be initialized in a child process.
Default is False because data created in child processes won't be
shared with the parent process (each process has its own Manager).
Only set to True if you understand this limitation.
Returns Returns
------- -------
MultiProcessingSharedPool MultiProcessingSharedPool
@ -71,13 +82,20 @@ class MultiProcessingSharedPool:
if cls._instance is None: if cls._instance is None:
cls._instance = super(MultiProcessingSharedPool, cls).__new__(cls) cls._instance = super(MultiProcessingSharedPool, cls).__new__(cls)
cls._instance._initialized = False cls._instance._initialized = False
cls._allow_child_init = allow_child_init
return cls._instance return cls._instance
def __init__(self): def __init__(self, allow_child_init: bool = False):
""" """
Initialize the MultiProcessingSharedPool instance. Initialize the MultiProcessingSharedPool instance.
Uses a flag to ensure initialization only happens once. Uses a flag to ensure initialization only happens once.
Parameters
----------
allow_child_init : bool, optional
If True, allows the pool to be initialized in a child process.
Note: This parameter only takes effect on the first instantiation.
""" """
if getattr(self, "_initialized", False): if getattr(self, "_initialized", False):
return return
@ -103,6 +121,16 @@ class MultiProcessingSharedPool:
""" """
return cls() return cls()
def _is_child_process(self) -> bool:
    """
    Report whether the current process is a child process.

    Returns
    -------
    bool
        True when ``multiprocessing.parent_process()`` returns a parent,
        False when it returns None. If ``parent_process`` is unavailable
        (Python older than 3.8), the result is True when the current pid
        differs from ``self._owner_pid``.
    """
    try:
        # Only available on Python 3.8+; older versions raise AttributeError.
        parent = multiprocessing.parent_process()
    except AttributeError:
        # Legacy fallback: a pid other than the recorded owner pid means
        # this code is running somewhere other than the owning process.
        return os.getpid() != self._owner_pid
    return parent is not None
def _ensure_initialized(self): def _ensure_initialized(self):
""" """
Lazy initialization of the multiprocessing Manager and shared dictionary. Lazy initialization of the multiprocessing Manager and shared dictionary.
@ -111,10 +139,24 @@ class MultiProcessingSharedPool:
------ ------
RuntimeError RuntimeError
If the multiprocessing Manager fails to start. If the multiprocessing Manager fails to start.
RuntimeError
If attempting to initialize in a child process without allow_child_init=True.
""" """
if self._shared_dict is None: if self._shared_dict is None:
with self._init_lock: with self._init_lock:
if self._shared_dict is None: if self._shared_dict is None:
# Check if we are in a child process
if self._is_child_process() and not self._allow_child_init:
raise RuntimeError(
"Cannot initialize MultiProcessingSharedPool in a child process. "
"Data created in child processes won't be shared with the parent process "
"because each process has its own Manager instance. "
"To suppress this error and proceed anyway, create the pool with "
"MultiProcessingSharedPool(allow_child_init=True). "
"Best practice: Initialize the pool in the main process before "
"starting any child processes."
)
try: try:
# Use the default context for manager # Use the default context for manager
self._manager = multiprocessing.Manager() self._manager = multiprocessing.Manager()

View File

@ -0,0 +1,123 @@
"""
子进程创建 pool 的行为控制测试
"""
import multiprocessing
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
# Helper functions for the child-process pool creation tests
def worker_create_pool_in_child(result_queue):
    """
    Child-process worker: try to create a pool with default arguments.

    The attempt is expected to fail. The outcome is reported through
    ``result_queue`` as ``("success", "created")`` when the pool could be
    created and written to, or ``("error", <message>)`` when a
    ``RuntimeError`` was raised.
    """
    try:
        # Drop the singleton so this behaves like a first-time creation
        # inside the child process.
        MultiProcessingSharedPool._instance = None
        child_pool = MultiProcessingSharedPool()
        child_pool.put("child_key", "child_value")
        result_queue.put(("success", "created"))
    except RuntimeError as exc:
        result_queue.put(("error", str(exc)))
def worker_create_pool_with_allow(result_queue):
    """
    Child-process worker: create a pool with ``allow_child_init=True``.

    The outcome is reported through ``result_queue`` as
    ``("success", "created_with_allow")`` when the pool could be created
    and written to, or ``("error", <message>)`` when a ``RuntimeError``
    was raised.
    """
    try:
        # Drop the singleton so this behaves like a first-time creation
        # inside the child process.
        MultiProcessingSharedPool._instance = None
        child_pool = MultiProcessingSharedPool(allow_child_init=True)
        child_pool.put("child_key", "child_value")
        result_queue.put(("success", "created_with_allow"))
    except RuntimeError as exc:
        result_queue.put(("error", str(exc)))
def worker_get_data(key, result_queue):
    """
    Child-process worker: read ``key`` from the shared pool.

    The value returned by the pool's ``get`` is placed on ``result_queue``.
    """
    shared_pool = MultiProcessingSharedPool()
    result_queue.put(shared_pool.get(key))
class TestChildProcessCreation:
    """Tests for the guard that controls pool creation in child processes."""

    # Upper bound, in seconds, on how long a test waits for a child process.
    # Without it a child that dies before reporting would hang the suite.
    _CHILD_TIMEOUT = 30

    @staticmethod
    def _reset_singleton():
        """Drop the singleton and restore the default ``_allow_child_init``."""
        MultiProcessingSharedPool._instance = None
        MultiProcessingSharedPool._allow_child_init = False

    @classmethod
    def _run_in_child(cls, target, *args):
        """
        Run ``target(*args, result_queue)`` in a child process.

        Parameters
        ----------
        target : callable
            Worker function; it receives ``args`` followed by a queue and
            is expected to put exactly one item on that queue.
        *args
            Positional arguments passed to ``target`` before the queue.

        Returns
        -------
        object
            The first item the child put on the queue.

        Raises
        ------
        queue.Empty
            If the child reports nothing within ``_CHILD_TIMEOUT`` seconds.
        """
        result_queue = multiprocessing.Queue()
        proc = multiprocessing.Process(target=target, args=(*args, result_queue))
        proc.start()
        try:
            # Read the queue BEFORE joining: joining a process that still
            # has buffered queue items can deadlock, and a bounded wait
            # turns a crashed child into a test failure instead of a hang.
            return result_queue.get(timeout=cls._CHILD_TIMEOUT)
        finally:
            proc.join(timeout=cls._CHILD_TIMEOUT)
            if proc.is_alive():
                # Never leave a stuck child behind for later tests.
                proc.terminate()
                proc.join()

    def test_child_process_creation_blocked_by_default(self):
        """By default, creating a pool in a child process is rejected."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        status, message = self._run_in_child(worker_create_pool_in_child)

        assert status == "error"
        assert (
            "Cannot initialize MultiProcessingSharedPool in a child process" in message
        )

    def test_child_process_creation_allowed_with_flag(self):
        """``allow_child_init=True`` lets a child process create a pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        status, message = self._run_in_child(worker_create_pool_with_allow)

        assert status == "success"
        assert message == "created_with_allow"

    def test_allow_child_init_parameter(self):
        """The ``allow_child_init`` argument is recorded on the class."""
        try:
            # Fresh singleton: the default must leave child init disabled.
            MultiProcessingSharedPool._instance = None
            MultiProcessingSharedPool()
            assert MultiProcessingSharedPool._allow_child_init is False

            # Fresh singleton again: an explicit opt-in must be recorded.
            self._reset_singleton()
            MultiProcessingSharedPool(allow_child_init=True)
            assert MultiProcessingSharedPool._allow_child_init is True
        finally:
            # Restore defaults even when an assertion fails so that later
            # tests do not inherit a modified singleton.
            self._reset_singleton()

    def test_best_practice_main_process_init(self):
        """
        Best practice: initialize in the main process, then use in a child.

        NOTE(review): presumably this relies on the child seeing the
        parent's already-initialized singleton (e.g. the ``fork`` start
        method) -- confirm the behaviour under ``spawn``.
        """
        self._reset_singleton()
        try:
            # The main process initializes the pool first.
            main_pool = MultiProcessingSharedPool()
            main_pool._ensure_initialized()
            main_pool.put("shared_key", "shared_value")

            # The child only reads; it does not create a new pool.
            result = self._run_in_child(worker_get_data, "shared_key")

            assert result == "shared_value"
        finally:
            self._reset_singleton()