Initial commit.
This commit is contained in:
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
46
test/conftest.py
Normal file
46
test/conftest.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""
|
||||
Pytest 配置文件
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import multiprocessing
|
||||
|
||||
|
||||
def pytest_configure(config):
    """Pick the multiprocessing start method once for the whole test session.

    ``fork`` is faster on Linux; ``spawn`` is more stable on Windows/macOS.
    """
    try:
        multiprocessing.set_start_method("fork", force=True)
    except RuntimeError:
        pass  # the start method was already fixed — nothing to do
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
def reset_shared_pool():
    """Wipe the shared pool before every test (and again after it).

    Autouse fixture: guarantees each test starts from an empty pool.
    """
    from mpsp.mpsp import MultiProcessingSharedPool

    shared = MultiProcessingSharedPool()
    shared.clear()

    yield

    # Leave the pool empty for whatever runs next as well.
    shared.clear()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def shared_pool():
    """Session-scoped fixture handing out the shared-pool instance.

    The pool is a singleton, so every request yields the same object.
    """
    from mpsp.mpsp import MultiProcessingSharedPool

    return MultiProcessingSharedPool()
|
||||
324
test/test_basic.py
Normal file
324
test/test_basic.py
Normal file
@ -0,0 +1,324 @@
|
||||
"""
|
||||
基础功能测试 - 验证核心 API 正确性
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from mpsp.mpsp import MultiProcessingSharedPool
|
||||
|
||||
|
||||
class TestBasicOperations:
    """Tests for the core put/get/remove/pop/size/keys/clear API."""

    def test_singleton_pattern(self):
        """Repeated construction and get_instance() must return the same object."""
        pool1 = MultiProcessingSharedPool()
        pool2 = MultiProcessingSharedPool.get_instance()
        pool3 = MultiProcessingSharedPool()

        assert pool1 is pool2
        assert pool2 is pool3
        assert pool1 is pool3

    def test_put_and_get_basic_types(self):
        """Round-trip storage of primitive types."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # int
        pool.put("int_key", 42)
        assert pool.get("int_key") == 42

        # float (compared with a tolerance, not ==)
        pool.put("float_key", 3.14159)
        assert abs(pool.get("float_key") - 3.14159) < 1e-10

        # str
        pool.put("str_key", "hello mpsp")
        assert pool.get("str_key") == "hello mpsp"

        # bool (identity check to ensure it stays a real bool)
        pool.put("bool_key", True)
        assert pool.get("bool_key") is True

        # None
        pool.put("none_key", None)
        assert pool.get("none_key") is None

    def test_put_and_get_collections(self):
        """Round-trip storage of container types."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # list
        test_list = [1, 2, 3, "a", "b", "c"]
        pool.put("list_key", test_list)
        assert pool.get("list_key") == test_list

        # dict (including one nested level)
        test_dict = {"name": "test", "value": 100, "nested": {"a": 1}}
        pool.put("dict_key", test_dict)
        assert pool.get("dict_key") == test_dict

        # tuple
        test_tuple = (1, 2, 3)
        pool.put("tuple_key", test_tuple)
        assert pool.get("tuple_key") == test_tuple

        # set
        test_set = {1, 2, 3}
        pool.put("set_key", test_set)
        assert pool.get("set_key") == test_set

    def test_exists(self):
        """exists() reflects key presence before/after put and remove."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert not pool.exists("nonexistent_key")

        pool.put("existing_key", "value")
        assert pool.exists("existing_key")

        pool.remove("existing_key")
        assert not pool.exists("existing_key")

    def test_remove(self):
        """remove() returns True for present keys and False for missing ones."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Remove an existing key.
        pool.put("key_to_remove", "value")
        assert pool.remove("key_to_remove") is True
        assert not pool.exists("key_to_remove")

        # Removing a missing key reports False rather than raising.
        assert pool.remove("nonexistent_key") is False

    def test_pop(self):
        """pop() returns and deletes the value; missing keys yield the default."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Pop an existing key.
        pool.put("key_to_pop", "popped_value")
        value = pool.pop("key_to_pop")
        assert value == "popped_value"
        assert not pool.exists("key_to_pop")

        # Pop a missing key with an explicit default.
        default_value = pool.pop("nonexistent_key", "default")
        assert default_value == "default"

        # Pop a missing key without a default falls back to None.
        assert pool.pop("nonexistent_key") is None

    def test_size(self):
        """size() tracks puts, removes and clear."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.size() == 0

        pool.put("key1", "value1")
        assert pool.size() == 1

        pool.put("key2", "value2")
        pool.put("key3", "value3")
        assert pool.size() == 3

        pool.remove("key1")
        assert pool.size() == 2

        pool.clear()
        assert pool.size() == 0

    def test_keys(self):
        """keys() lists every stored label (compared as a set: order unspecified)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.keys() == []

        pool.put("key1", "value1")
        pool.put("key2", "value2")
        pool.put("key3", "value3")

        keys = pool.keys()
        assert len(keys) == 3
        assert set(keys) == {"key1", "key2", "key3"}

    def test_clear(self):
        """clear() removes every entry."""
        pool = MultiProcessingSharedPool()

        pool.put("key1", "value1")
        pool.put("key2", "value2")

        pool.clear()

        assert pool.size() == 0
        assert pool.keys() == []
        assert not pool.exists("key1")
|
||||
|
||||
|
||||
class TestDictInterface:
    """Tests for the dict-style dunder interface of the pool."""

    def test_getitem(self):
        """__getitem__ returns the value; missing keys raise KeyError."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("test_key", "test_value")
        assert pool["test_key"] == "test_value"

        # Accessing a missing key must raise KeyError (unlike get()).
        with pytest.raises(KeyError):
            _ = pool["nonexistent_key"]

    def test_setitem(self):
        """__setitem__ stores a value retrievable via get()."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool["new_key"] = "new_value"
        assert pool.get("new_key") == "new_value"

    def test_delitem(self):
        """__delitem__ removes a key; deleting a missing key raises KeyError."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("key_to_delete", "value")
        del pool["key_to_delete"]
        assert not pool.exists("key_to_delete")

        # Deleting a missing key must raise KeyError (unlike remove()).
        with pytest.raises(KeyError):
            del pool["nonexistent_key"]

    def test_contains(self):
        """__contains__ supports the ``in`` operator."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("existing_key", "value")

        assert "existing_key" in pool
        assert "nonexistent_key" not in pool
|
||||
|
||||
|
||||
class TestContextManager:
    """Tests for ``with`` statement support."""

    def test_context_manager(self):
        """Data written inside a ``with`` block survives after exit.

        Exiting the context must not close the underlying manager, so the
        value stays visible through another handle to the singleton.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        with MultiProcessingSharedPool() as p:
            p.put("ctx_key", "ctx_value")
            assert p.get("ctx_key") == "ctx_value"

        # After the context exits the data should still be there
        # (the manager is not shut down by __exit__).
        assert pool.get("ctx_key") == "ctx_value"
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Tests for invalid-argument handling and default values."""

    def test_invalid_label_type_put(self):
        """put() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()

        with pytest.raises(TypeError, match="Label must be a string"):
            pool.put(123, "value")

        with pytest.raises(TypeError, match="Label must be a string"):
            pool.put(None, "value")

    def test_invalid_label_type_get(self):
        """get() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()

        with pytest.raises(TypeError, match="Label must be a string"):
            pool.get(123)

    def test_invalid_label_type_pop(self):
        """pop() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()

        with pytest.raises(TypeError, match="Label must be a string"):
            pool.pop(123)

    def test_get_with_default(self):
        """get() falls back to the default only when the key is missing."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Missing key: default is returned.
        assert pool.get("nonexistent", "default") == "default"
        assert pool.get("nonexistent", None) is None
        assert pool.get("nonexistent") is None

        # Present key: the stored value wins over the default.
        pool.put("existing", "real_value")
        assert pool.get("existing", "default") == "real_value"

    def test_special_label_names(self):
        """Labels with unusual but legal string content round-trip correctly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Empty string is a valid label.
        pool.put("", "empty_string_key")
        assert pool.get("") == "empty_string_key"

        # Labels containing whitespace, punctuation, mixed case and unicode.
        special_keys = [
            "key with spaces",
            "key\twith\ttabs",
            "key\nwith\nnewlines",
            "key/with/slashes",
            "key.with.dots",
            "key:with:colons",
            "UPPERCASE_KEY",
            "mixedCase_Key",
            "unicode_中文_key",
            "emoji_😀_key",
        ]

        for key in special_keys:
            pool.put(key, f"value_for_{key}")
            assert pool.get(key) == f"value_for_{key}"
|
||||
|
||||
|
||||
class TestOverwrite:
    """Tests for overwriting existing keys."""

    def test_overwrite_value(self):
        """A second put() on the same key replaces the value (any type)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("key", "original_value")
        assert pool.get("key") == "original_value"

        pool.put("key", "new_value")
        assert pool.get("key") == "new_value"

        # Overwriting with a value of a different type must also work.
        pool.put("key", 12345)
        assert pool.get("key") == 12345

    def test_overwrite_with_none(self):
        """Overwriting with None keeps the key present (value None != deleted)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("key", "value")
        pool.put("key", None)

        assert pool.get("key") is None
        assert pool.exists("key")  # the key itself must survive
|
||||
652
test/test_edge_cases.py
Normal file
652
test/test_edge_cases.py
Normal file
@ -0,0 +1,652 @@
|
||||
"""
|
||||
边界与异常测试 - 验证鲁棒性
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
import pickle
|
||||
import pytest
|
||||
from mpsp.mpsp import MultiProcessingSharedPool
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def worker_put_empty_key(result_queue):
    """Subprocess worker: store a value under the empty-string key.

    Reports ("success", stored_value) or ("error", message) on *result_queue*
    so the parent process can assert on the outcome.
    """
    pool = MultiProcessingSharedPool()
    try:
        pool.put("", "empty_value")
        result_queue.put(("success", pool.get("")))
    except Exception as e:
        # Report instead of raising: an exception in a child process would
        # otherwise be invisible to the parent's assertions.
        result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||
def worker_get_nonexistent(result_queue):
    """Subprocess worker: fetch a key that does not exist and report the result."""
    pool = MultiProcessingSharedPool()
    result = pool.get("definitely_nonexistent_key_12345")
    result_queue.put(result)
|
||||
|
||||
|
||||
def worker_put_large_object(key, data, result_queue):
    """Subprocess worker: store *data* under *key*, reporting put()'s result.

    Puts ("success", return_value_of_put) or ("error", message) on
    *result_queue* for the parent to assert on.
    """
    pool = MultiProcessingSharedPool()
    try:
        success = pool.put(key, data)
        result_queue.put(("success", success))
    except Exception as e:
        result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||
# ==================== 测试类 ====================
|
||||
|
||||
|
||||
class TestEmptyAndNoneValues:
    """Round-trip tests for empty containers and None values."""

    def test_put_empty_string_value(self):
        """An empty string value round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("empty_value_key", "")
        assert pool.get("empty_value_key") == ""

    def test_put_none_value(self):
        """None round-trips as an actual None, not a missing key."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("none_value_key", None)
        assert pool.get("none_value_key") is None

    def test_put_empty_list(self):
        """An empty list round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("empty_list", [])
        assert pool.get("empty_list") == []

    def test_put_empty_dict(self):
        """An empty dict round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("empty_dict", {})
        assert pool.get("empty_dict") == {}

    def test_put_empty_tuple(self):
        """An empty tuple round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("empty_tuple", ())
        assert pool.get("empty_tuple") == ()

    def test_put_empty_set(self):
        """An empty set round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("empty_set", set())
        assert pool.get("empty_set") == set()
|
||||
|
||||
|
||||
class TestSpecialLabelNames:
    """Tests for unusual but legal label strings."""

    def test_empty_string_label(self):
        """The empty string is a valid label."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("", "empty_key_value")
        assert pool.get("") == "empty_key_value"
        assert pool.exists("")

    def test_empty_string_label_cross_process(self):
        """A child process can overwrite the empty-string label set by the parent."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("", "parent_empty_value")

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=worker_put_empty_key, args=(result_queue,))
        p.start()
        p.join()

        status, result = result_queue.get()
        # The child should be able to overwrite the empty-string key.
        assert status == "success"
        assert result == "empty_value"

    def test_unicode_label(self):
        """Unicode labels (CJK, emoji, symbols) round-trip correctly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        unicode_keys = [
            "中文键",
            "日本語キー",
            "한국어키",
            "emoji_😀",
            "special_©_®_™",
            "math_∑_∏_√",
            "arrows_→_←_↑_↓",
        ]

        for key in unicode_keys:
            pool.put(key, f"value_for_{key}")

        for key in unicode_keys:
            assert pool.get(key) == f"value_for_{key}"

    def test_whitespace_label(self):
        """Whitespace-only and whitespace-padded labels round-trip correctly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        whitespace_keys = [
            " ",  # single space
            "  ",  # two spaces
            "\t",  # tab
            "\n",  # newline
            "\r\n",  # Windows newline
            " key_with_leading_space",
            "key_with_trailing_space ",
            " key_with_both_spaces ",
            "key\twith\ttabs",
        ]

        # repr() in the value makes mismatches readable in failure output.
        for key in whitespace_keys:
            pool.put(key, f"value_for_repr_{repr(key)}")

        for key in whitespace_keys:
            assert pool.get(key) == f"value_for_repr_{repr(key)}"

    def test_special_chars_label(self):
        """Punctuation-heavy labels round-trip correctly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        special_keys = [
            "key.with.dots",
            "key/with/slashes",
            "key:with:colons",
            "key|with|pipes",
            "key*with*asterisks",
            "key?with?question",
            "key<with>brackets",
            "key[with]square",
            "key{with}curly",
            "key+with+plus",
            "key=with=equals",
            "key!with!exclamation",
            "key@with@at",
            "key#with#hash",
            "key$with$dollar",
            "key%with%percent",
            "key^with^caret",
            "key&with&ersand",
            "key'with'quotes",
            'key"with"double',
            "key`with`backtick",
            "key~with~tilde",
            "key-with-hyphens",
            "key_with_underscores",
        ]

        for key in special_keys:
            pool.put(key, f"value_for_{key}")

        for key in special_keys:
            assert pool.get(key) == f"value_for_{key}"

    def test_very_long_label(self):
        """A 1000-character label round-trips correctly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        long_key = "a" * 1000
        pool.put(long_key, "long_key_value")
        assert pool.get(long_key) == "long_key_value"
|
||||
|
||||
|
||||
class TestNonExistentKeys:
    """Tests for every API's behaviour on keys that do not exist."""

    def test_get_nonexistent(self):
        """get() on a missing key returns None."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        result = pool.get("nonexistent_key_12345")
        assert result is None

    def test_get_nonexistent_with_default(self):
        """get() on a missing key returns the supplied default, falsy or not."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.get("nonexistent", "default") == "default"
        assert pool.get("nonexistent", 0) == 0
        assert pool.get("nonexistent", []) == []
        assert pool.get("nonexistent", {}) == {}

    def test_get_nonexistent_cross_process(self):
        """A child process also sees None for a missing key."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=worker_get_nonexistent, args=(result_queue,))
        p.start()
        p.join()

        result = result_queue.get()
        assert result is None

    def test_remove_nonexistent(self):
        """remove() on a missing key returns False rather than raising."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.remove("nonexistent_key") is False

    def test_pop_nonexistent(self):
        """pop() on a missing key returns None, or the default when given."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.pop("nonexistent_key") is None
        assert pool.pop("nonexistent_key", "default") == "default"

    def test_exists_nonexistent(self):
        """exists() on a missing key is exactly False."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.exists("nonexistent_key") is False
|
||||
|
||||
|
||||
class TestLargeObjects:
    """Serialization tests for large and deeply nested payloads."""

    def test_large_list(self):
        """A 100k-element list round-trips with correct ends."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_list = list(range(100000))
        pool.put("large_list", large_list)

        retrieved = pool.get("large_list")
        assert len(retrieved) == 100000
        assert retrieved[0] == 0
        assert retrieved[99999] == 99999

    def test_large_dict(self):
        """A 10k-entry dict round-trips with correct first/last entries."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_dict = {f"key_{i}": f"value_{i}" for i in range(10000)}
        pool.put("large_dict", large_dict)

        retrieved = pool.get("large_dict")
        assert len(retrieved) == 10000
        assert retrieved["key_0"] == "value_0"
        assert retrieved["key_9999"] == "value_9999"

    def test_large_string(self):
        """A ~1MB string round-trips with correct length and ends."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_string = "x" * 1000000  # 1MB string
        pool.put("large_string", large_string)

        retrieved = pool.get("large_string")
        assert len(retrieved) == 1000000
        assert retrieved[0] == "x"
        assert retrieved[-1] == "x"

    def test_deeply_nested_structure(self):
        """A 50-level nested dict round-trips with every level intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Build a dict nested `depth` levels deep, innermost value "bottom".
        depth = 50
        nested = "bottom"
        for i in range(depth):
            nested = {"level": i, "nested": nested}

        pool.put("deep_nested", nested)

        retrieved = pool.get("deep_nested")
        # Walk back down, checking each level number on the way.
        current = retrieved
        for i in range(depth):
            assert current["level"] == depth - 1 - i
            current = current["nested"]
        assert current == "bottom"

    def test_large_object_cross_process(self):
        """A child process can store a large object successfully."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_data = {"items": list(range(10000)), "name": "large_test"}

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_put_large_object,
            args=("large_cross", large_data, result_queue),
        )
        p.start()
        p.join()

        status, result = result_queue.get()
        assert status == "success"
        assert result is True
|
||||
|
||||
|
||||
class TestCircularReferences:
    """Tests that self-referencing containers serialize without blowing up."""

    def test_circular_list(self):
        """A list that contains itself stores and retrieves cleanly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Build a self-referencing list.
        circular = [1, 2, 3]
        circular.append(circular)  # circular reference

        pool.put("circular_list", circular)

        retrieved = pool.get("circular_list")
        assert retrieved[0] == 1
        assert retrieved[1] == 2
        assert retrieved[2] == 3
        # The cycle slot must come back as something, not be dropped.
        assert retrieved[3] is not None

    def test_circular_dict(self):
        """A dict that contains itself stores and retrieves cleanly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Build a self-referencing dict.
        circular = {"a": 1, "b": 2}
        circular["self"] = circular  # circular reference

        pool.put("circular_dict", circular)

        retrieved = pool.get("circular_dict")
        assert retrieved["a"] == 1
        assert retrieved["b"] == 2
        # The self-reference key must survive the round trip.
        assert "self" in retrieved
|
||||
|
||||
|
||||
class TestBinaryData:
    """Round-trip tests for bytes-like payloads."""

    def test_bytes_data(self):
        """Raw bytes (including NUL and high bytes) round-trip unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        binary_data = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc"
        pool.put("binary_data", binary_data)

        retrieved = pool.get("binary_data")
        assert retrieved == binary_data

    def test_large_binary_data(self):
        """A ~256KB byte string round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        binary_data = bytes(range(256)) * 1000  # 256KB
        pool.put("large_binary", binary_data)

        retrieved = pool.get("large_binary")
        assert retrieved == binary_data

    def test_bytearray_data(self):
        """A bytearray round-trips with equal content."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        ba = bytearray(b"\x00\x01\x02\x03")
        pool.put("bytearray_data", ba)

        retrieved = pool.get("bytearray_data")
        assert retrieved == ba
|
||||
|
||||
|
||||
class TestMixedTypes:
    """Round-trip tests for containers mixing many value types."""

    def test_heterogeneous_list(self):
        """A list holding one value of each common type round-trips element-wise."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        mixed_list = [
            1,  # int
            3.14,  # float
            "string",  # str
            True,  # bool
            None,  # NoneType
            [1, 2, 3],  # list
            {"a": 1},  # dict
            (1, 2),  # tuple
            {1, 2, 3},  # set
            b"binary",  # bytes
        ]

        pool.put("mixed_list", mixed_list)

        retrieved = pool.get("mixed_list")
        assert retrieved[0] == 1
        assert abs(retrieved[1] - 3.14) < 1e-10
        assert retrieved[2] == "string"
        assert retrieved[3] is True
        assert retrieved[4] is None
        assert retrieved[5] == [1, 2, 3]
        assert retrieved[6] == {"a": 1}
        assert retrieved[7] == (1, 2)
        assert retrieved[8] == {1, 2, 3}
        assert retrieved[9] == b"binary"

    def test_heterogeneous_dict(self):
        """A dict holding one value of each common type round-trips key-wise."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        mixed_dict = {
            "int_key": 42,
            "float_key": 3.14,
            "str_key": "hello",
            "bool_key": True,
            "none_key": None,
            "list_key": [1, 2, 3],
            "dict_key": {"nested": "value"},
            "tuple_key": (1, 2, 3),
        }

        pool.put("mixed_dict", mixed_dict)

        retrieved = pool.get("mixed_dict")
        for key, value in mixed_dict.items():
            if isinstance(value, float):
                # Floats compared with a tolerance, not ==.
                assert abs(retrieved[key] - value) < 1e-10
            else:
                assert retrieved[key] == value
|
||||
|
||||
|
||||
def worker_stress_test(key_prefix, iterations, result_queue):
    """Subprocess worker: run put/get/remove cycles and report any errors.

    Args:
        key_prefix: per-process key prefix so workers do not collide.
        iterations: number of put/get/remove cycles to run.
        result_queue: queue on which the (possibly empty) error list is sent.
    """
    pool = MultiProcessingSharedPool()
    errors = []

    for i in range(iterations):
        try:
            key = f"{key_prefix}_{i}"
            pool.put(key, f"value_{i}")
            value = pool.get(key)
            if value != f"value_{i}":
                errors.append(f"Value mismatch at {key}")
            pool.remove(key)
        except Exception as e:
            errors.append(str(e))

    result_queue.put(errors)


class TestConcurrentAccess:
    """Concurrent-access stability tests.

    NOTE(review): the original file declared ``class TestConcurrentAccess``
    twice.  The second definition silently shadowed the first, and the
    stress worker sat inside the shadowed class body where the
    ``multiprocessing.Process`` target reference could not resolve it as a
    module-level name.  Fixed by hoisting ``worker_stress_test`` to module
    level (required anyway for process spawning) and keeping one class.
    """

    def test_stress_concurrent_writes(self):
        """Several processes writing concurrently must complete without errors."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        num_processes = 4
        iterations = 100

        result_queue = multiprocessing.Queue()
        processes = []

        for i in range(num_processes):
            p = multiprocessing.Process(
                target=worker_stress_test,
                args=(f"stress_{i}", iterations, result_queue),
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Gather every worker's error list.
        all_errors = []
        for _ in range(num_processes):
            all_errors.extend(result_queue.get())

        # There should be no errors at all.
        assert len(all_errors) == 0, f"Errors occurred: {all_errors}"

    def test_rapid_put_get_cycle(self):
        """Rapidly overwrite and re-read a single key in a tight loop."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        for i in range(1000):
            pool.put("rapid_key", f"value_{i}")
            value = pool.get("rapid_key")
            assert value == f"value_{i}"

    def test_rapid_key_creation_deletion(self):
        """Rapidly create and delete many distinct keys."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        for i in range(100):
            key = f"temp_key_{i}"
            pool.put(key, f"temp_value_{i}")
            assert pool.exists(key)
            pool.remove(key)
            assert not pool.exists(key)
|
||||
|
||||
|
||||
class TestErrorRecovery:
    """Tests that the pool stays usable after an error."""

    def test_put_after_error(self):
        """A failed put() with an invalid label must not poison later puts."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Provoke the TypeError with an illegal label type.
        try:
            pool.put(123, "value")  # should raise TypeError
        except TypeError:
            pass

        # Normal operation must still work afterwards.
        pool.put("valid_key", "valid_value")
        assert pool.get("valid_key") == "valid_value"

    def test_get_after_nonexistent(self):
        """A miss on get() must not affect subsequent operations."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Fetch a missing key.
        result = pool.get("nonexistent")
        assert result is None

        # Normal operation must still work afterwards.
        pool.put("new_key", "new_value")
        assert pool.get("new_key") == "new_value"

    def test_multiple_singleton_access(self):
        """Writes through different singleton handles land in the same store."""
        pool1 = MultiProcessingSharedPool()
        pool1.put("key1", "value1")

        pool2 = MultiProcessingSharedPool()
        pool2.put("key2", "value2")

        pool3 = MultiProcessingSharedPool.get_instance()
        pool3.put("key3", "value3")

        # Every handle must see all three entries.
        assert pool1.get("key1") == "value1"
        assert pool1.get("key2") == "value2"
        assert pool1.get("key3") == "value3"
|
||||
|
||||
|
||||
class TestCleanup:
    """Tests for bulk and incremental cleanup."""

    def test_clear_after_multiple_puts(self):
        """clear() empties the pool after 100 puts."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        for i in range(100):
            pool.put(f"key_{i}", f"value_{i}")

        assert pool.size() == 100

        pool.clear()

        assert pool.size() == 0
        assert pool.keys() == []

    def test_remove_all_keys_one_by_one(self):
        """Removing every key individually drains the pool to zero."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        keys = [f"key_{i}" for i in range(50)]
        for key in keys:
            pool.put(key, f"value_for_{key}")

        for key in keys:
            assert pool.remove(key) is True

        assert pool.size() == 0
|
||||
523
test/test_functions.py
Normal file
523
test/test_functions.py
Normal file
@ -0,0 +1,523 @@
|
||||
"""
|
||||
函数序列化测试 - 验证 cloudpickle 集成
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
import math
|
||||
import pytest
|
||||
from mpsp.mpsp import MultiProcessingSharedPool
|
||||
|
||||
|
||||
# ==================== 模块级别的普通函数 ====================
|
||||
|
||||
|
||||
def simple_function(x):
    """Return *x* incremented by one."""
    return 1 + x
|
||||
|
||||
|
||||
def multiply_function(a, b):
    """Return the product of *a* and *b*."""
    product = a * b
    return product
|
||||
|
||||
|
||||
def function_with_default(x, y=10):
    """Return ``x + y``; *y* defaults to 10."""
    total = x + y
    return total
|
||||
|
||||
|
||||
def function_with_kwargs(*args, **kwargs):
    """Return the sum of all positional values plus all keyword values."""
    total = sum(args)
    total += sum(kwargs.values())
    return total
|
||||
|
||||
|
||||
def recursive_function(n):
    """Compute n! recursively; any *n* <= 1 yields 1."""
    if n > 1:
        return n * recursive_function(n - 1)
    return 1
|
||||
|
||||
|
||||
def closure_factory(base):
    """Return a closure that adds the captured *base* to its argument."""

    # Name kept as ``inner`` — the function object (and its __name__) is
    # handed to callers, so renaming it would be observable.
    def inner(x):
        return base + x

    return inner
|
||||
|
||||
|
||||
# ==================== 辅助函数(模块级别定义)====================
|
||||
|
||||
|
||||
def worker_execute_function(key, result_queue):
    """Subprocess worker: fetch a stored callable by *key* and execute it.

    The call arguments are chosen per key to match the function each test
    stored under that key; the result (or an "ERROR: ..." string, or None
    when the key is missing) is reported on *result_queue*.
    """
    pool = MultiProcessingSharedPool()
    func = pool.get(key)
    if func is None:
        # Key missing (or deserialization yielded nothing): report None.
        result_queue.put(None)
        return

    # Dispatch: each test key maps to the argument list its function expects.
    try:
        if key == "simple_func":
            result = func(5)
        elif key == "multiply_func":
            result = func(3, 4)
        elif key == "default_func":
            result = func(5)
        elif key == "kwargs_func":
            result = func(1, 2, 3, a=4, b=5)
        elif key == "recursive_func":
            result = func(5)
        elif key == "closure_func":
            result = func(10)
        elif key == "lambda_func":
            result = func(7)
        elif key == "lambda_with_capture":
            result = func()
        elif key == "nested_func":
            result = func(3)
        else:
            # Unknown keys default to a zero-argument call.
            result = func()
        result_queue.put(result)
    except Exception as e:
        # Report instead of raising so the parent's assertion sees the message.
        result_queue.put(f"ERROR: {e}")
|
||||
|
||||
|
||||
def worker_execute_lambda_with_arg(key, arg, result_queue):
    """Subprocess worker: fetch a stored lambda by *key*, call it with *arg*.

    Puts the call result on *result_queue*, or None when the key is missing.
    """
    pool = MultiProcessingSharedPool()
    func = pool.get(key)
    if func is None:
        result_queue.put(None)
        return
    result_queue.put(func(arg))
|
||||
|
||||
|
||||
def get_lambda_description(func):
    """Return a readable identifier for *func*: its ``__name__`` if present,
    otherwise ``str(func)``."""
    _missing = object()
    name = getattr(func, "__name__", _missing)
    return str(func) if name is _missing else name
|
||||
|
||||
|
||||
# ==================== 测试类 ====================
|
||||
|
||||
|
||||
class TestNormalFunctions:
    """Serialization/deserialization tests for plain module-level functions."""

    def test_simple_function(self):
        """A function round-trips within the same process and stays callable."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Store the function object itself.
        result = pool.put("simple_func", simple_function)
        assert result is True

        # Verify in the current process.
        retrieved = pool.get("simple_func")
        assert retrieved is not None
        assert retrieved(5) == 6
        assert retrieved(10) == 11

    def test_simple_function_cross_process(self):
        """A stored function is retrievable and executable in a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent stores the function.
        pool.put("simple_func", simple_function)

        # Child fetches and runs it.
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("simple_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 6  # simple_function(5) = 5 + 1

    def test_function_with_multiple_args(self):
        """A two-argument function survives the cross-process round trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("multiply_func", multiply_function)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("multiply_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 12  # multiply_function(3, 4) = 12

    def test_function_with_default_args(self):
        """Default argument values are preserved across the round trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("default_func", function_with_default)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("default_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 15  # function_with_default(5) = 5 + 10

    def test_function_with_kwargs(self):
        """*args / **kwargs handling is preserved across the round trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("kwargs_func", function_with_kwargs)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("kwargs_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 15  # sum(1,2,3) + sum(4,5) = 6 + 9 = 15

    def test_recursive_function(self):
        """A self-referencing (recursive) function still works after transfer."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("recursive_func", recursive_function)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("recursive_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 120  # 5! = 120
|
||||
|
||||
|
||||
class TestLambdaFunctions:
    """Serialization and deserialization of lambda functions."""

    def test_simple_lambda(self):
        """A simple lambda can be stored and called back in-process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        simple_lambda = lambda x: x * 2
        result = pool.put("lambda_func", simple_lambda)
        assert result is True

        # Verify within the current process.
        retrieved = pool.get("lambda_func")
        assert retrieved(5) == 10
        assert retrieved(7) == 14

    def test_simple_lambda_cross_process(self):
        """A simple lambda survives transfer to a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        simple_lambda = lambda x: x * 3
        pool.put("lambda_func", simple_lambda)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("lambda_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 21  # lambda(7) = 7 * 3 = 21

    def test_lambda_with_capture(self):
        """A lambda that captures an enclosing variable keeps its value."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        captured_value = 100
        capturing_lambda = lambda: captured_value + 1

        pool.put("lambda_with_capture", capturing_lambda)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("lambda_with_capture", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 101  # captured_value + 1 = 101

    def test_lambda_in_list_comprehension(self):
        """Lambdas created in a list comprehension bind their own loop value."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Create several lambdas; the i=i default freezes each loop value.
        lambdas = [(lambda x, i=i: x + i) for i in range(5)]

        for i, lam in enumerate(lambdas):
            pool.put(f"lambda_{i}", lam)

        # Each retrieved lambda must keep its own captured i.
        for i in range(5):
            retrieved = pool.get(f"lambda_{i}")
            assert retrieved(10) == 10 + i

    def test_complex_lambda(self):
        """A more involved lambda expression round-trips across processes."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        complex_lambda = lambda x, y: (x**2 + y**2) ** 0.5
        pool.put("complex_lambda", complex_lambda)

        # Verify in a child process.
        # NOTE: passing a nested function as the Process target relies on
        # the "fork" start method forced in conftest.py; it would not be
        # picklable under "spawn".
        def worker_execute_complex_lambda(key, x, y, result_queue):
            pool = MultiProcessingSharedPool()
            func = pool.get(key)
            result_queue.put(func(x, y))

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_complex_lambda,
            args=("complex_lambda", 3, 4, result_queue),
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert abs(result - 5.0) < 1e-10  # sqrt(3^2 + 4^2) = 5
|
||||
|
||||
|
||||
class TestNestedFunctions:
    """Functions defined inside other functions (nested defs and closures)."""

    def test_nested_function(self):
        """A function containing an inner helper round-trips cross-process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        def outer_function(x):
            def inner_function(y):
                return y * y

            return inner_function(x) + x

        pool.put("nested_func", outer_function)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("nested_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 12  # outer_function(3) = 3*3 + 3 = 12

    def test_closure_function(self):
        """A closure keeps its captured value across processes."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        closure_func = closure_factory(100)
        pool.put("closure_func", closure_func)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_execute_function, args=("closure_func", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 110  # closure_func(10) = 10 + 100

    def test_multiple_closures(self):
        """Several closures from the same factory keep distinct captures."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        closures = [closure_factory(i) for i in range(5)]
        for i, closure in enumerate(closures):
            pool.put(f"closure_{i}", closure)

        # Each closure must have captured a different value.
        for i in range(5):
            retrieved = pool.get(f"closure_{i}")
            assert retrieved(10) == 10 + i
|
||||
|
||||
|
||||
class TestClassMethods:
    """Serialization of class-level callables (static and class methods)."""

    def test_static_method(self):
        """Static methods behave like plain functions and round-trip cleanly."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        class Calculator:
            @staticmethod
            def add(x, y):
                return x + y

            @staticmethod
            def multiply(x, y):
                return x * y

        pool.put("static_add", Calculator.add)
        pool.put("static_multiply", Calculator.multiply)

        # Verify both static methods.
        add_func = pool.get("static_add")
        multiply_func = pool.get("static_multiply")

        assert add_func(2, 3) == 5
        assert multiply_func(2, 3) == 6

    def test_class_method(self):
        """Class methods may or may not survive serialization.

        A classmethod carries a reference to its owning class, so whether
        cloudpickle can round-trip it is implementation-dependent.  The
        test therefore only requires that storing does not crash and that
        a reported success yields something callable back.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        class Counter:
            count = 0

            @classmethod
            def increment(cls):
                cls.count += 1
                return cls.count

        # NOTE: class methods are often not serializable by cloudpickle
        # because they depend on the class definition.
        result = pool.put("class_method", Counter.increment)

        # If the store reported success, retrieval should at least produce
        # a callable (executing it may still fail after deserialization).
        if result:
            try:
                method = pool.get("class_method")
                assert method is None or callable(method)
            except Exception:
                pass  # deserialization of classmethods is allowed to fail
|
||||
|
||||
|
||||
class TestBuiltInFunctions:
    """Serialization of built-in and stdlib functions."""

    def test_builtin_functions(self):
        """Python built-ins round-trip through the pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Most built-ins are serializable with standard pickle.
        builtins_under_test = {
            "builtin_sum": sum,
            "builtin_max": max,
            "builtin_min": min,
            "builtin_len": len,
        }
        for key, fn in builtins_under_test.items():
            pool.put(key, fn)

        # Each retrieved built-in must still compute correctly.
        assert pool.get("builtin_sum")([1, 2, 3]) == 6
        assert pool.get("builtin_max")([1, 2, 3]) == 3
        assert pool.get("builtin_min")([1, 2, 3]) == 1
        assert pool.get("builtin_len")([1, 2, 3]) == 3

    def test_math_functions(self):
        """Functions from the math module round-trip through the pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        for key, fn in (
            ("math_sqrt", math.sqrt),
            ("math_sin", math.sin),
            ("math_cos", math.cos),
        ):
            pool.put(key, fn)

        # Spot-check each retrieved function at a known point.
        assert abs(pool.get("math_sqrt")(16) - 4.0) < 1e-10
        assert abs(pool.get("math_sin")(0) - 0.0) < 1e-10
        assert abs(pool.get("math_cos")(0) - 1.0) < 1e-10
|
||||
|
||||
|
||||
class TestFunctionReturnValues:
    """Functions that return other functions (factories)."""

    def test_function_returned_from_function(self):
        """A factory function can be shared and invoked in a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        def create_multiplier(factor):
            return lambda x: x * factor

        pool.put("create_multiplier", create_multiplier)

        # Fetch and use the factory in a child process.
        # NOTE: a nested function as Process target relies on the "fork"
        # start method forced in conftest.py.
        def worker_get_multiplier(result_queue):
            pool = MultiProcessingSharedPool()
            factory = pool.get("create_multiplier")
            multiplier_func = factory(5)  # build a multiply-by-5 function
            result_queue.put(multiplier_func(10))

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(target=worker_get_multiplier, args=(result_queue,))
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 50  # 10 * 5 = 50
|
||||
|
||||
|
||||
class TestErrorHandling:
    """Error handling around function serialization."""

    def test_unpicklable_function_fallback(self):
        """A function that plain pickle rejects falls back to cloudpickle."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # A locally defined function cannot be serialized by standard pickle.
        def local_function(x):
            return x**2

        # Storing should still succeed via the cloudpickle fallback.
        result = pool.put("local_func", local_function)
        assert result is True

        # And the retrieved function must execute correctly.
        retrieved = pool.get("local_func")
        assert retrieved(5) == 25

    def test_function_with_unpicklable_capture(self):
        """A lambda capturing a non-serializable object must not crash the pool.

        NOTE(review): deliberately best-effort -- the test only checks that
        nothing propagates uncaught; it makes no assertion on the put()
        result, which may succeed or fail depending on cloudpickle support.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Capture an open file object (not serializable).
        try:
            with open(__file__, "r") as f:
                file_capturing_lambda = lambda: f.read()

                # Attempting to store is expected to fail (or be rejected).
                result = pool.put("file_lambda", file_capturing_lambda)
                # If cloudpickle supports it, verify it fails gracefully.
        except Exception:
            pass  # failure here is acceptable and expected
|
||||
446
test/test_multiprocess.py
Normal file
446
test/test_multiprocess.py
Normal file
@ -0,0 +1,446 @@
|
||||
"""
|
||||
多进程数据共享测试 - 验证跨进程数据同步能力
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
import pytest
|
||||
from mpsp.mpsp import MultiProcessingSharedPool
|
||||
|
||||
|
||||
# ==================== 辅助函数(需要在模块级别定义以便多进程使用)====================
|
||||
|
||||
|
||||
def worker_put_data(key, value):
    """Child process: write *value* under *key* in the shared pool."""
    MultiProcessingSharedPool().put(key, value)
    return True
|
||||
|
||||
|
||||
def worker_get_data(key, result_queue):
    """Child process: read *key* from the shared pool and report the value."""
    shared = MultiProcessingSharedPool()
    result_queue.put(shared.get(key))
|
||||
|
||||
|
||||
def worker_check_exists(key, result_queue):
    """Child process: report whether *key* exists in the shared pool."""
    key_present = MultiProcessingSharedPool().exists(key)
    result_queue.put(key_present)
|
||||
|
||||
|
||||
def worker_modify_data(key, new_value, result_queue):
    """Child process: overwrite *key* with *new_value*, reporting the old value."""
    shared = MultiProcessingSharedPool()
    previous = shared.get(key)
    shared.put(key, new_value)
    result_queue.put(previous)
|
||||
|
||||
|
||||
def worker_wait_and_get(key, wait_time, result_queue):
    """Child process: sleep *wait_time* seconds, then read *key* and report it."""
    time.sleep(wait_time)
    shared = MultiProcessingSharedPool()
    result_queue.put(shared.get(key))
|
||||
|
||||
|
||||
def worker_increment_counter(key, iterations):
    """Child process: bump the counter stored at *key*, *iterations* times.

    The read-modify-write below is deliberately NOT atomic; the concurrency
    tests use it to exercise racy access without crashing.
    """
    shared = MultiProcessingSharedPool()
    for _ in range(iterations):
        shared.put(key, shared.get(key, 0) + 1)
|
||||
|
||||
|
||||
def worker_pop_data(key, result_queue):
    """Child process: pop *key* out of the shared pool and report its value."""
    popped = MultiProcessingSharedPool().pop(key)
    result_queue.put(popped)
|
||||
|
||||
|
||||
# ==================== 测试类 ====================
|
||||
|
||||
|
||||
class TestParentChildProcess:
    """Data exchange between a parent process and a single child."""

    def test_parent_write_child_read(self):
        """Parent writes, child reads the same key."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent writes the data.
        pool.put("shared_key", "shared_value")

        # Child reads it back.
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("shared_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == "shared_value"

    def test_child_write_parent_read(self):
        """Child writes, parent reads the same key."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Child writes the data.
        p = multiprocessing.Process(
            target=worker_put_data, args=("child_key", "child_value")
        )
        p.start()
        p.join()

        # Parent reads it back (brief wait to let the write settle).
        time.sleep(0.1)
        result = pool.get("child_key")
        assert result == "child_value"

    def test_parent_child_data_isolation(self):
        """Round-trip modification verifies the manager-dict sync both ways."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent writes the initial value.
        pool.put("isolation_test", "parent_value")

        # Child reads the old value and overwrites it.
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_modify_data,
            args=("isolation_test", "child_value", result_queue),
        )
        p.start()
        p.join()

        # The child must have observed the parent's value...
        child_read = result_queue.get()
        assert child_read == "parent_value"

        # ...and the parent must observe the child's modification.
        time.sleep(0.1)
        parent_read = pool.get("isolation_test")
        assert parent_read == "child_value"
|
||||
|
||||
|
||||
class TestMultipleChildProcesses:
|
||||
"""测试多个子进程间的数据共享"""
|
||||
|
||||
def test_multiple_children_write_same_key(self):
|
||||
"""测试多个子进程写入同一个 key(后面的覆盖前面的)"""
|
||||
pool = MultiProcessingSharedPool()
|
||||
pool.clear()
|
||||
|
||||
processes = []
|
||||
for i in range(5):
|
||||
p = multiprocessing.Process(
|
||||
target=worker_put_data, args=("shared_key", f"value_{i}")
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# 短暂等待确保所有写入完成
|
||||
time.sleep(0.1)
|
||||
|
||||
# 验证有一个值被成功写入
|
||||
result = pool.get("shared_key")
|
||||
assert result.startswith("value_")
|
||||
|
||||
def test_multiple_children_write_different_keys(self):
|
||||
"""测试多个子进程写入不同的 key"""
|
||||
pool = MultiProcessingSharedPool()
|
||||
pool.clear()
|
||||
|
||||
num_processes = 5
|
||||
processes = []
|
||||
|
||||
for i in range(num_processes):
|
||||
p = multiprocessing.Process(
|
||||
target=worker_put_data, args=(f"key_{i}", f"value_{i}")
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# 短暂等待确保所有写入完成
|
||||
time.sleep(0.1)
|
||||
|
||||
# 验证所有值都被成功写入
|
||||
for i in range(num_processes):
|
||||
assert pool.get(f"key_{i}") == f"value_{i}"
|
||||
|
||||
def test_multiple_children_read_same_key(self):
|
||||
"""测试多个子进程读取同一个 key"""
|
||||
pool = MultiProcessingSharedPool()
|
||||
pool.clear()
|
||||
|
||||
# 父进程写入数据
|
||||
pool.put("shared_key", "shared_value")
|
||||
|
||||
# 多个子进程读取
|
||||
num_processes = 5
|
||||
result_queue = multiprocessing.Queue()
|
||||
processes = []
|
||||
|
||||
for _ in range(num_processes):
|
||||
p = multiprocessing.Process(
|
||||
target=worker_get_data, args=("shared_key", result_queue)
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# 收集所有结果
|
||||
results = []
|
||||
for _ in range(num_processes):
|
||||
results.append(result_queue.get())
|
||||
|
||||
# 所有子进程都应该读取到相同的值
|
||||
assert all(r == "shared_value" for r in results)
|
||||
|
||||
def test_concurrent_exists_check(self):
|
||||
"""测试并发检查 key 是否存在"""
|
||||
pool = MultiProcessingSharedPool()
|
||||
pool.clear()
|
||||
|
||||
# 写入一些数据
|
||||
for i in range(5):
|
||||
pool.put(f"key_{i}", f"value_{i}")
|
||||
|
||||
# 多个子进程并发检查
|
||||
result_queue = multiprocessing.Queue()
|
||||
processes = []
|
||||
|
||||
for i in range(10):
|
||||
key = f"key_{i % 7}" # 有些 key 存在,有些不存在
|
||||
p = multiprocessing.Process(
|
||||
target=worker_check_exists, args=(key, result_queue)
|
||||
)
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join()
|
||||
|
||||
# 收集结果
|
||||
results = []
|
||||
for _ in range(10):
|
||||
results.append(result_queue.get())
|
||||
|
||||
# 前 5 个应该存在 (key_0 到 key_4),后 5 个不存在 (key_5, key_6 重复检查)
|
||||
assert sum(results) >= 5 # 至少 5 个存在
|
||||
|
||||
|
||||
# 进程池测试需要在模块级别定义 worker 函数
|
||||
def _pool_worker_map(args):
    """Process-pool map worker: fetch a key from the shared pool, tagged by index."""
    idx, key = args
    shared_pool = MultiProcessingSharedPool()
    fetched = shared_pool.get(key)
    return idx, fetched
|
||||
|
||||
|
||||
def _pool_worker_apply(key):
    """Process-pool apply_async worker: fetch *key* from the shared pool."""
    return MultiProcessingSharedPool().get(key)
|
||||
|
||||
|
||||
class TestProcessPool:
    """Using the shared pool from inside multiprocessing.Pool workers."""

    def test_pool_map_with_shared_data(self):
        """Data written by the parent is visible to Pool.map workers."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Seed the shared pool with test inputs.
        for i in range(5):
            pool.put(f"input_{i}", i * 10)

        # Fan the reads out over a small worker pool.
        with multiprocessing.Pool(processes=3) as process_pool:
            results = process_pool.map(
                _pool_worker_map, [(i, f"input_{i}") for i in range(5)]
            )

        # Every worker must have read back exactly what the parent wrote.
        results_dict = dict(results)
        for i in range(5):
            assert results_dict[i] == i * 10

    def test_pool_apply_async_with_shared_data(self):
        """Data written by the parent is visible to an apply_async worker."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("async_key", "async_value")

        with multiprocessing.Pool(processes=2) as process_pool:
            result = process_pool.apply_async(_pool_worker_apply, ("async_key",))
            assert result.get(timeout=5) == "async_value"
|
||||
|
||||
|
||||
class TestDataVisibility:
    """Visibility and timing of writes across processes."""

    def test_immediate_visibility_after_put(self):
        """A write is visible immediately within the same process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("immediate_key", "immediate_value")
        # Same-process reads must see the write with no delay.
        assert pool.exists("immediate_key")
        assert pool.get("immediate_key") == "immediate_value"

    def test_cross_process_visibility_with_delay(self):
        """A parent write is visible to a child that reads after a delay."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Parent writes.
        pool.put("delayed_key", "delayed_value")

        # Child sleeps briefly, then reads.
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_wait_and_get, args=("delayed_key", 0.2, result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == "delayed_value"
|
||||
|
||||
|
||||
class TestConcurrentModifications:
    """Concurrent-modification scenarios."""

    def test_concurrent_counter_increments(self):
        """Racy (non-atomic) concurrent increments must not crash.

        The per-increment read-modify-write in worker_increment_counter is
        not atomic, so lost updates are expected; the assertions only bound
        the result rather than demand an exact count.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Initialize the counter.
        pool.put("counter", 0)

        num_processes = 4
        iterations_per_process = 10

        processes = []
        for _ in range(num_processes):
            p = multiprocessing.Process(
                target=worker_increment_counter,
                args=("counter", iterations_per_process),
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        time.sleep(0.1)

        # Because of the race the final value may be below the maximum;
        # this test mainly proves concurrent access does not crash.
        final_count = pool.get("counter")
        assert isinstance(final_count, int)
        assert 0 <= final_count <= num_processes * iterations_per_process

    def test_concurrent_pop_operations(self):
        """Concurrent pops on distinct keys each remove exactly their key."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Seed several keys.
        num_keys = 5
        for i in range(num_keys):
            pool.put(f"pop_key_{i}", f"pop_value_{i}")

        result_queue = multiprocessing.Queue()
        processes = []

        for i in range(num_keys):
            p = multiprocessing.Process(
                target=worker_pop_data, args=(f"pop_key_{i}", result_queue)
            )
            processes.append(p)
            p.start()

        for p in processes:
            p.join()

        # Collect every popped value.
        popped_values = []
        for _ in range(num_keys):
            popped_values.append(result_queue.get())

        # All values must have been popped exactly once.
        assert len(popped_values) == num_keys
        for i in range(num_keys):
            assert f"pop_value_{i}" in popped_values

        # And every key must be gone from the pool.
        time.sleep(0.1)
        for i in range(num_keys):
            assert not pool.exists(f"pop_key_{i}")
|
||||
|
||||
|
||||
class TestComplexDataTypes:
    """Sharing of complex (nested / large) data types across processes."""

    def test_share_nested_dict(self):
        """A deeply nested dict round-trips unchanged to a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        nested_data = {
            "level1": {"level2": {"level3": [1, 2, 3]}},
            "list_of_dicts": [{"a": 1}, {"b": 2}],
        }

        pool.put("nested_key", nested_data)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("nested_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == nested_data

    def test_share_large_list(self):
        """A 10k-element list round-trips unchanged to a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        large_list = list(range(10000))
        pool.put("large_list_key", large_list)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_data, args=("large_list_key", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == large_list
|
||||
569
test/test_numpy.py
Normal file
569
test/test_numpy.py
Normal file
@ -0,0 +1,569 @@
|
||||
"""
|
||||
NumPy ndarray 支持测试 - 验证数组数据共享
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
import time
|
||||
import pytest
|
||||
import numpy as np
|
||||
from mpsp.mpsp import MultiProcessingSharedPool
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def worker_get_array(key, result_queue):
    """Child process: fetch the array stored under *key* and report it."""
    fetched = MultiProcessingSharedPool().get(key)
    result_queue.put(fetched)
|
||||
|
||||
|
||||
def worker_modify_array(key, index, value, result_queue):
    """Child process: read the array at *key* and report the value at *index*.

    NOTE(review): despite the name, nothing is written back -- per the
    original comment, pool.get() returns a copy of the array, so an
    in-place edit would not propagate.  The *value* parameter is accepted
    but unused; confirm against the callers before removing it.
    """
    pool = MultiProcessingSharedPool()
    arr = pool.get(key)
    # Tuple indices select a sub-array, which must be copied before the
    # local array goes away; scalar indices already yield a value.
    old_value = arr[index].copy() if isinstance(index, tuple) else arr[index]
    # Note: any modification here would touch a copy, since get() returns
    # a copy of the array.
    result_queue.put(old_value)
|
||||
|
||||
|
||||
def worker_sum_array(key, result_queue):
    """Child process: report the element sum of the array stored at *key*."""
    shared = MultiProcessingSharedPool()
    total = np.sum(shared.get(key))
    result_queue.put(total)
|
||||
|
||||
|
||||
def worker_check_array_properties(key, expected_shape, expected_dtype, result_queue):
    """Child process: compare the stored array's shape/dtype to expectations.

    Reports a (shape_matches, dtype_matches, ndim) tuple.
    """
    arr = MultiProcessingSharedPool().get(key)
    shape_ok = arr.shape == expected_shape
    dtype_ok = arr.dtype == expected_dtype
    result_queue.put((shape_ok, dtype_ok, arr.ndim))
|
||||
|
||||
|
||||
# ==================== 测试类 ====================
|
||||
|
||||
|
||||
class TestNDBasicOperations:
    """Basic store/retrieve round-trips for NumPy arrays of various ranks."""

    def test_1d_array(self):
        """A one-dimensional array round-trips with values and dtype intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        original = np.array([1, 2, 3, 4, 5])
        pool.put("1d_array", original)

        restored = pool.get("1d_array")
        np.testing.assert_array_equal(restored, original)
        assert restored.dtype == original.dtype

    def test_2d_array(self):
        """A two-dimensional array round-trips with shape intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        original = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        pool.put("2d_array", original)

        restored = pool.get("2d_array")
        np.testing.assert_array_equal(restored, original)
        assert restored.shape == (3, 3)

    def test_3d_array(self):
        """A three-dimensional array round-trips with shape intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        original = np.arange(24).reshape(2, 3, 4)
        pool.put("3d_array", original)

        restored = pool.get("3d_array")
        np.testing.assert_array_equal(restored, original)
        assert restored.shape == (2, 3, 4)

    def test_multidimensional_array(self):
        """A four-dimensional random array round-trips (value-approximate)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # 4-D array of random floats.
        original = np.random.rand(2, 3, 4, 5)
        pool.put("4d_array", original)

        restored = pool.get("4d_array")
        np.testing.assert_array_almost_equal(restored, original)
        assert restored.shape == (2, 3, 4, 5)
|
||||
|
||||
|
||||
class TestNDDataTypes:
    """Round-trips for NumPy arrays of different dtypes."""

    def test_integer_dtypes(self):
        """Signed and unsigned integer arrays keep their exact dtype."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        dtypes = [
            np.int8,
            np.int16,
            np.int32,
            np.int64,
            np.uint8,
            np.uint16,
            np.uint32,
            np.uint64,
        ]

        for dtype in dtypes:
            arr = np.array([1, 2, 3], dtype=dtype)
            key = f"int_array_{dtype.__name__}"
            pool.put(key, arr)

            retrieved = pool.get(key)
            assert retrieved.dtype == dtype
            np.testing.assert_array_equal(retrieved, arr)

    def test_float_dtypes(self):
        """Float arrays keep their dtype; values compared approximately."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        dtypes = [np.float32, np.float64]

        for dtype in dtypes:
            arr = np.array([1.1, 2.2, 3.3], dtype=dtype)
            key = f"float_array_{dtype.__name__}"
            pool.put(key, arr)

            retrieved = pool.get(key)
            assert retrieved.dtype == dtype
            np.testing.assert_array_almost_equal(retrieved, arr)

    def test_complex_dtypes(self):
        """Complex arrays keep their dtype; values compared approximately."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        dtypes = [np.complex64, np.complex128]

        for dtype in dtypes:
            arr = np.array([1 + 2j, 3 + 4j, 5 + 6j], dtype=dtype)
            key = f"complex_array_{dtype.__name__}"
            pool.put(key, arr)

            retrieved = pool.get(key)
            assert retrieved.dtype == dtype
            np.testing.assert_array_almost_equal(retrieved, arr)

    def test_bool_dtype(self):
        """Boolean arrays round-trip with dtype intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.array([True, False, True, True, False], dtype=np.bool_)
        pool.put("bool_array", arr)

        retrieved = pool.get("bool_array")
        assert retrieved.dtype == np.bool_
        np.testing.assert_array_equal(retrieved, arr)

    def test_string_dtype(self):
        """Unicode string arrays round-trip unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Unicode strings.
        arr = np.array(["hello", "world", "mpsp"])
        pool.put("string_array", arr)

        retrieved = pool.get("string_array")
        np.testing.assert_array_equal(retrieved, arr)

    def test_object_dtype(self):
        """Object-dtype arrays (mixed element types) round-trip unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Object arrays can hold heterogeneous elements.
        arr = np.array([1, "string", 3.14, [1, 2, 3]], dtype=object)
        pool.put("object_array", arr)

        retrieved = pool.get("object_array")
        assert retrieved.dtype == object
        np.testing.assert_array_equal(retrieved, arr)
|
||||
|
||||
|
||||
class TestNDCrossProcess:
    """Sharing NumPy arrays across processes."""

    def test_1d_array_cross_process(self):
        """A one-dimensional array is received intact by a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.array([10, 20, 30, 40, 50])
        pool.put("shared_1d", arr)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_array, args=("shared_1d", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        np.testing.assert_array_equal(result, arr)

    def test_2d_array_cross_process(self):
        """A two-dimensional array is received intact by a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.array([[1, 2], [3, 4], [5, 6]])
        pool.put("shared_2d", arr)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_get_array, args=("shared_2d", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        np.testing.assert_array_equal(result, arr)

    def test_array_properties_cross_process(self):
        """Shape, dtype and ndim are preserved across the process boundary."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.arange(12).reshape(3, 4).astype(np.float32)
        pool.put("property_test", arr)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_check_array_properties,
            args=("property_test", (3, 4), np.float32, result_queue),
        )
        p.start()
        p.join()

        shape_match, dtype_match, ndim = result_queue.get()
        assert shape_match
        assert dtype_match
        assert ndim == 2

    def test_array_operations_cross_process(self):
        """Array computations performed in the child give the right result."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
        pool.put("sum_test", arr)

        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=worker_sum_array, args=("sum_test", result_queue)
        )
        p.start()
        p.join()

        result = result_queue.get()
        assert result == 55  # sum of 1..10
|
||||
|
||||
|
||||
class TestNDSpecialArrays:
    """Round-trips for special-case NumPy arrays."""

    def test_empty_array(self):
        """An empty array round-trips and stays empty."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.array([])
        pool.put("empty_array", arr)

        retrieved = pool.get("empty_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert len(retrieved) == 0

    def test_single_element_array(self):
        """A single-element array round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.array([42])
        pool.put("single_element", arr)

        retrieved = pool.get("single_element")
        np.testing.assert_array_equal(retrieved, arr)
        assert retrieved[0] == 42

    def test_zeros_array(self):
        """An all-zeros array round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.zeros((5, 5))
        pool.put("zeros_array", arr)

        retrieved = pool.get("zeros_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 0)

    def test_ones_array(self):
        """An all-ones array round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.ones((3, 4))
        pool.put("ones_array", arr)

        retrieved = pool.get("ones_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(retrieved == 1)

    def test_eye_array(self):
        """An identity matrix round-trips unchanged."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.eye(5)
        pool.put("eye_array", arr)

        retrieved = pool.get("eye_array")
        np.testing.assert_array_equal(retrieved, arr)
        assert np.all(np.diag(retrieved) == 1)

    def test_nan_and_inf_array(self):
        """NaN and +/-Inf values survive the round-trip."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        arr = np.array([1.0, np.nan, np.inf, -np.inf, 2.0])
        pool.put("special_values", arr)

        retrieved = pool.get("special_values")
        assert np.isnan(retrieved[1])
        assert np.isinf(retrieved[2]) and retrieved[2] > 0
        assert np.isinf(retrieved[3]) and retrieved[3] < 0

    def test_masked_array(self):
        """A masked array round-trips with both data and mask intact."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        data = np.array([1, 2, 3, 4, 5])
        mask = np.array([False, True, False, True, False])
        arr = np.ma.array(data, mask=mask)

        pool.put("masked_array", arr)

        retrieved = pool.get("masked_array")
        np.testing.assert_array_equal(retrieved.data, data)
        np.testing.assert_array_equal(retrieved.mask, mask)
|
||||
|
||||
|
||||
class TestNDLargeArrays:
    """Tests exercising the pool with large NumPy arrays, including one
    cross-process round trip."""

    def test_large_1d_array(self):
        """A 10,000-element 1-D array round-trips intact."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        # 10,000-element vector.
        original = np.arange(10000)
        shared.put("large_1d", original)

        fetched = shared.get("large_1d")
        np.testing.assert_array_equal(fetched, original)

    def test_large_2d_array(self):
        """A 1000x100 random float array round-trips within tolerance."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        # 1000x100 random matrix.
        original = np.random.rand(1000, 100)
        shared.put("large_2d", original)

        fetched = shared.get("large_2d")
        np.testing.assert_array_almost_equal(fetched, original)

    def test_large_array_cross_process(self):
        """A child process reads a large array from the pool and sums it."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        original = np.arange(100000).reshape(1000, 100)
        shared.put("large_cross", original)

        answers = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_sum_array, args=("large_cross", answers)
        )
        child.start()
        child.join()

        # The child's sum must match the sum computed in this process.
        assert answers.get() == np.sum(original)
|
||||
|
||||
|
||||
class TestNDStructuredArrays:
    """Tests for NumPy structured (record) arrays."""

    def test_structured_array(self):
        """A structured array keeps its dtype and contents through the pool."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        record_type = np.dtype([("name", "U10"), ("age", "i4"), ("weight", "f4")])
        people = np.array(
            [("Alice", 25, 55.5), ("Bob", 30, 85.3), ("Charlie", 35, 75.0)],
            dtype=record_type,
        )

        shared.put("structured_array", people)

        fetched = shared.get("structured_array")
        assert fetched.dtype == record_type
        np.testing.assert_array_equal(fetched, people)

    def test_structured_array_cross_process(self):
        """A structured array fetched in a child process matches the original."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        point_type = np.dtype([("x", "f4"), ("y", "f4"), ("z", "f4")])
        points = np.array([(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)], dtype=point_type)

        shared.put("structured_cross", points)

        answers = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_get_array, args=("structured_cross", answers)
        )
        child.start()
        child.join()

        np.testing.assert_array_equal(answers.get(), points)
|
||||
|
||||
|
||||
class TestNDMatrixOperations:
    """Linear-algebra sanity checks on arrays retrieved from the pool."""

    def test_matrix_multiplication(self):
        """Two matrices stored in the pool multiply to the expected product."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        lhs = np.array([[1, 2], [3, 4]])
        rhs = np.array([[5, 6], [7, 8]])

        shared.put("matrix_A", lhs)
        shared.put("matrix_B", rhs)

        # The @ operator is equivalent to np.dot for 2-D arrays.
        product = shared.get("matrix_A") @ shared.get("matrix_B")
        np.testing.assert_array_equal(product, np.array([[19, 22], [43, 50]]))

    def test_eigenvalue_computation(self):
        """Eigenvalues of a retrieved symmetric matrix are correct."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        # Symmetric matrix with known spectrum {2, 6}.
        sym = np.array([[4, 2], [2, 4]])
        shared.put("eigen_matrix", sym)

        values, _vectors = np.linalg.eig(shared.get("eigen_matrix"))

        assert np.allclose(np.sort(values), [2, 6])

    def test_svd_decomposition(self):
        """An SVD of a retrieved matrix reconstructs the original."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        matrix = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        shared.put("svd_matrix", matrix)

        u, s, vh = np.linalg.svd(shared.get("svd_matrix"))

        # U * diag(S) * Vh must reproduce the input (up to float error).
        rebuilt = u @ np.diag(s) @ vh
        np.testing.assert_array_almost_equal(rebuilt, matrix)
|
||||
|
||||
|
||||
class TestNDBroadcasting:
    """Tests that NumPy broadcasting works on arrays from the pool."""

    def test_broadcasting_operation(self):
        """Adding a 1-D array to a 2-D array broadcasts row-wise."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        matrix = np.array([[1, 2, 3], [4, 5, 6]])
        row = np.array([10, 20, 30])

        shared.put("array_2d", matrix)
        shared.put("array_1d", row)

        total = shared.get("array_2d") + shared.get("array_1d")
        np.testing.assert_array_equal(
            total, np.array([[11, 22, 33], [14, 25, 36]])
        )
|
||||
|
||||
|
||||
class TestNDMixedTypes:
    """Tests for containers that mix NumPy arrays with plain Python values."""

    def test_array_in_dict(self):
        """A dict holding arrays plus scalars round-trips field by field."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        payload = {
            "matrix": np.array([[1, 2], [3, 4]]),
            "vector": np.array([1, 2, 3]),
            "scalar": 42,
            "name": "test",
        }

        shared.put("dict_with_arrays", payload)

        fetched = shared.get("dict_with_arrays")
        np.testing.assert_array_equal(fetched["matrix"], payload["matrix"])
        np.testing.assert_array_equal(fetched["vector"], payload["vector"])
        assert fetched["scalar"] == 42
        assert fetched["name"] == "test"

    def test_array_in_list(self):
        """A list holding arrays plus scalars round-trips element by element."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        payload = [np.array([1, 2, 3]), np.array([[4, 5], [6, 7]]), "string", 42]

        shared.put("list_with_arrays", payload)

        fetched = shared.get("list_with_arrays")
        np.testing.assert_array_equal(fetched[0], payload[0])
        np.testing.assert_array_equal(fetched[1], payload[1])
        assert fetched[2] == "string"
        assert fetched[3] == 42
|
||||
Reference in New Issue
Block a user