Initial commit.

This commit is contained in:
2026-02-21 12:00:47 +08:00
commit cd335c1b3f
14 changed files with 3492 additions and 0 deletions

652
test/test_edge_cases.py Normal file
View File

@ -0,0 +1,652 @@
"""
边界与异常测试 - 验证鲁棒性
"""
import multiprocessing
import time
import pickle
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
# ==================== 辅助函数 ====================
def worker_put_empty_key(result_queue):
"""子进程:测试空字符串 key"""
pool = MultiProcessingSharedPool()
try:
pool.put("", "empty_value")
result_queue.put(("success", pool.get("")))
except Exception as e:
result_queue.put(("error", str(e)))
def worker_get_nonexistent(result_queue):
"""子进程:获取不存在的 key"""
pool = MultiProcessingSharedPool()
result = pool.get("definitely_nonexistent_key_12345")
result_queue.put(result)
def worker_put_large_object(key, data, result_queue):
"""子进程:存储大对象"""
pool = MultiProcessingSharedPool()
try:
success = pool.put(key, data)
result_queue.put(("success", success))
except Exception as e:
result_queue.put(("error", str(e)))
# ==================== 测试类 ====================
class TestEmptyAndNoneValues:
"""测试空值和 None 处理"""
def test_put_empty_string_value(self):
"""测试存储空字符串值"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("empty_value_key", "")
assert pool.get("empty_value_key") == ""
def test_put_none_value(self):
"""测试存储 None 值"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("none_value_key", None)
assert pool.get("none_value_key") is None
def test_put_empty_list(self):
"""测试存储空列表"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("empty_list", [])
assert pool.get("empty_list") == []
def test_put_empty_dict(self):
"""测试存储空字典"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("empty_dict", {})
assert pool.get("empty_dict") == {}
def test_put_empty_tuple(self):
"""测试存储空元组"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("empty_tuple", ())
assert pool.get("empty_tuple") == ()
def test_put_empty_set(self):
"""测试存储空集合"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("empty_set", set())
assert pool.get("empty_set") == set()
class TestSpecialLabelNames:
"""测试特殊 label 名称"""
def test_empty_string_label(self):
"""测试空字符串作为 label"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("", "empty_key_value")
assert pool.get("") == "empty_key_value"
assert pool.exists("")
def test_empty_string_label_cross_process(self):
"""测试空字符串 label 跨进程"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("", "parent_empty_value")
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(target=worker_put_empty_key, args=(result_queue,))
p.start()
p.join()
status, result = result_queue.get()
# 子进程应该可以覆盖空字符串 key
assert status == "success"
assert result == "empty_value"
def test_unicode_label(self):
"""测试 Unicode 字符作为 label"""
pool = MultiProcessingSharedPool()
pool.clear()
unicode_keys = [
"中文键",
"日本語キー",
"한국어키",
"emoji_😀",
"special_©_®_™",
"math_∑_∏_√",
"arrows_→_←_↑_↓",
]
for key in unicode_keys:
pool.put(key, f"value_for_{key}")
for key in unicode_keys:
assert pool.get(key) == f"value_for_{key}"
def test_whitespace_label(self):
"""测试空白字符作为 label"""
pool = MultiProcessingSharedPool()
pool.clear()
whitespace_keys = [
" ", # 单个空格
" ", # 两个空格
"\t", # Tab
"\n", # 换行
"\r\n", # Windows 换行
" key_with_leading_space",
"key_with_trailing_space ",
" key_with_both_spaces ",
"key\twith\ttabs",
]
for key in whitespace_keys:
pool.put(key, f"value_for_repr_{repr(key)}")
for key in whitespace_keys:
assert pool.get(key) == f"value_for_repr_{repr(key)}"
def test_special_chars_label(self):
"""测试特殊字符作为 label"""
pool = MultiProcessingSharedPool()
pool.clear()
special_keys = [
"key.with.dots",
"key/with/slashes",
"key:with:colons",
"key|with|pipes",
"key*with*asterisks",
"key?with?question",
"key<with>brackets",
"key[with]square",
"key{with}curly",
"key+with+plus",
"key=with=equals",
"key!with!exclamation",
"key@with@at",
"key#with#hash",
"key$with$dollar",
"key%with%percent",
"key^with^caret",
"key&with&ampersand",
"key'with'quotes",
'key"with"double',
"key`with`backtick",
"key~with~tilde",
"key-with-hyphens",
"key_with_underscores",
]
for key in special_keys:
pool.put(key, f"value_for_{key}")
for key in special_keys:
assert pool.get(key) == f"value_for_{key}"
def test_very_long_label(self):
"""测试超长 label"""
pool = MultiProcessingSharedPool()
pool.clear()
# 1000 字符的 label
long_key = "a" * 1000
pool.put(long_key, "long_key_value")
assert pool.get(long_key) == "long_key_value"
class TestNonExistentKeys:
"""测试不存在的 key 处理"""
def test_get_nonexistent(self):
"""测试获取不存在的 key"""
pool = MultiProcessingSharedPool()
pool.clear()
result = pool.get("nonexistent_key_12345")
assert result is None
def test_get_nonexistent_with_default(self):
"""测试获取不存在的 key 带默认值"""
pool = MultiProcessingSharedPool()
pool.clear()
assert pool.get("nonexistent", "default") == "default"
assert pool.get("nonexistent", 0) == 0
assert pool.get("nonexistent", []) == []
assert pool.get("nonexistent", {}) == {}
def test_get_nonexistent_cross_process(self):
"""测试跨进程获取不存在的 key"""
pool = MultiProcessingSharedPool()
pool.clear()
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(target=worker_get_nonexistent, args=(result_queue,))
p.start()
p.join()
result = result_queue.get()
assert result is None
def test_remove_nonexistent(self):
"""测试删除不存在的 key"""
pool = MultiProcessingSharedPool()
pool.clear()
assert pool.remove("nonexistent_key") is False
def test_pop_nonexistent(self):
"""测试弹出不存在的 key"""
pool = MultiProcessingSharedPool()
pool.clear()
assert pool.pop("nonexistent_key") is None
assert pool.pop("nonexistent_key", "default") == "default"
def test_exists_nonexistent(self):
"""测试检查不存在的 key"""
pool = MultiProcessingSharedPool()
pool.clear()
assert pool.exists("nonexistent_key") is False
class TestLargeObjects:
"""测试大对象序列化"""
def test_large_list(self):
"""测试大型列表"""
pool = MultiProcessingSharedPool()
pool.clear()
large_list = list(range(100000))
pool.put("large_list", large_list)
retrieved = pool.get("large_list")
assert len(retrieved) == 100000
assert retrieved[0] == 0
assert retrieved[99999] == 99999
def test_large_dict(self):
"""测试大型字典"""
pool = MultiProcessingSharedPool()
pool.clear()
large_dict = {f"key_{i}": f"value_{i}" for i in range(10000)}
pool.put("large_dict", large_dict)
retrieved = pool.get("large_dict")
assert len(retrieved) == 10000
assert retrieved["key_0"] == "value_0"
assert retrieved["key_9999"] == "value_9999"
def test_large_string(self):
"""测试大型字符串"""
pool = MultiProcessingSharedPool()
pool.clear()
large_string = "x" * 1000000 # 1MB 字符串
pool.put("large_string", large_string)
retrieved = pool.get("large_string")
assert len(retrieved) == 1000000
assert retrieved[0] == "x"
assert retrieved[-1] == "x"
def test_deeply_nested_structure(self):
"""测试深度嵌套结构"""
pool = MultiProcessingSharedPool()
pool.clear()
# 创建深度嵌套的字典
depth = 50
nested = "bottom"
for i in range(depth):
nested = {"level": i, "nested": nested}
pool.put("deep_nested", nested)
retrieved = pool.get("deep_nested")
# 验证嵌套深度
current = retrieved
for i in range(depth):
assert current["level"] == depth - 1 - i
current = current["nested"]
assert current == "bottom"
def test_large_object_cross_process(self):
"""测试跨进程传递大对象"""
pool = MultiProcessingSharedPool()
pool.clear()
large_data = {"items": list(range(10000)), "name": "large_test"}
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_put_large_object,
args=("large_cross", large_data, result_queue),
)
p.start()
p.join()
status, result = result_queue.get()
assert status == "success"
assert result is True
class TestCircularReferences:
"""测试循环引用"""
def test_circular_list(self):
"""测试列表中的循环引用"""
pool = MultiProcessingSharedPool()
pool.clear()
# 创建循环引用列表
circular = [1, 2, 3]
circular.append(circular) # 循环引用
pool.put("circular_list", circular)
retrieved = pool.get("circular_list")
assert retrieved[0] == 1
assert retrieved[1] == 2
assert retrieved[2] == 3
# 循环引用应该被正确处理
assert retrieved[3] is not None
def test_circular_dict(self):
"""测试字典中的循环引用"""
pool = MultiProcessingSharedPool()
pool.clear()
# 创建循环引用字典
circular = {"a": 1, "b": 2}
circular["self"] = circular # 循环引用
pool.put("circular_dict", circular)
retrieved = pool.get("circular_dict")
assert retrieved["a"] == 1
assert retrieved["b"] == 2
# 循环引用应该被正确处理
assert "self" in retrieved
class TestBinaryData:
"""测试二进制数据"""
def test_bytes_data(self):
"""测试字节数据"""
pool = MultiProcessingSharedPool()
pool.clear()
binary_data = b"\x00\x01\x02\x03\xff\xfe\xfd\xfc"
pool.put("binary_data", binary_data)
retrieved = pool.get("binary_data")
assert retrieved == binary_data
def test_large_binary_data(self):
"""测试大型二进制数据"""
pool = MultiProcessingSharedPool()
pool.clear()
binary_data = bytes(range(256)) * 1000 # 256KB
pool.put("large_binary", binary_data)
retrieved = pool.get("large_binary")
assert retrieved == binary_data
def test_bytearray_data(self):
"""测试 bytearray 数据"""
pool = MultiProcessingSharedPool()
pool.clear()
ba = bytearray(b"\x00\x01\x02\x03")
pool.put("bytearray_data", ba)
retrieved = pool.get("bytearray_data")
assert retrieved == ba
class TestMixedTypes:
"""测试混合类型数据"""
def test_heterogeneous_list(self):
"""测试异构列表"""
pool = MultiProcessingSharedPool()
pool.clear()
mixed_list = [
1, # int
3.14, # float
"string", # str
True, # bool
None, # NoneType
[1, 2, 3], # list
{"a": 1}, # dict
(1, 2), # tuple
{1, 2, 3}, # set
b"binary", # bytes
]
pool.put("mixed_list", mixed_list)
retrieved = pool.get("mixed_list")
assert retrieved[0] == 1
assert abs(retrieved[1] - 3.14) < 1e-10
assert retrieved[2] == "string"
assert retrieved[3] is True
assert retrieved[4] is None
assert retrieved[5] == [1, 2, 3]
assert retrieved[6] == {"a": 1}
assert retrieved[7] == (1, 2)
assert retrieved[8] == {1, 2, 3}
assert retrieved[9] == b"binary"
def test_heterogeneous_dict(self):
"""测试异构字典"""
pool = MultiProcessingSharedPool()
pool.clear()
mixed_dict = {
"int_key": 42,
"float_key": 3.14,
"str_key": "hello",
"bool_key": True,
"none_key": None,
"list_key": [1, 2, 3],
"dict_key": {"nested": "value"},
"tuple_key": (1, 2, 3),
}
pool.put("mixed_dict", mixed_dict)
retrieved = pool.get("mixed_dict")
for key, value in mixed_dict.items():
if isinstance(value, float):
assert abs(retrieved[key] - value) < 1e-10
else:
assert retrieved[key] == value
class TestConcurrentAccess:
"""测试并发访问稳定性"""
def worker_stress_test(key_prefix, iterations, result_queue):
"""子进程:压力测试"""
pool = MultiProcessingSharedPool()
errors = []
for i in range(iterations):
try:
key = f"{key_prefix}_{i}"
pool.put(key, f"value_{i}")
value = pool.get(key)
if value != f"value_{i}":
errors.append(f"Value mismatch at {key}")
pool.remove(key)
except Exception as e:
errors.append(str(e))
result_queue.put(errors)
class TestConcurrentAccess:
"""测试并发访问稳定性"""
def test_stress_concurrent_writes(self):
"""压力测试:并发写入"""
pool = MultiProcessingSharedPool()
pool.clear()
num_processes = 4
iterations = 100
result_queue = multiprocessing.Queue()
processes = []
for i in range(num_processes):
p = multiprocessing.Process(
target=worker_stress_test,
args=(f"stress_{i}", iterations, result_queue),
)
processes.append(p)
p.start()
for p in processes:
p.join()
# 收集所有错误
all_errors = []
for _ in range(num_processes):
all_errors.extend(result_queue.get())
# 应该没有错误
assert len(all_errors) == 0, f"Errors occurred: {all_errors}"
def test_rapid_put_get_cycle(self):
"""测试快速 put-get 循环"""
pool = MultiProcessingSharedPool()
pool.clear()
for i in range(1000):
pool.put("rapid_key", f"value_{i}")
value = pool.get("rapid_key")
assert value == f"value_{i}"
def test_rapid_key_creation_deletion(self):
"""测试快速创建和删除 key"""
pool = MultiProcessingSharedPool()
pool.clear()
for i in range(100):
key = f"temp_key_{i}"
pool.put(key, f"temp_value_{i}")
assert pool.exists(key)
pool.remove(key)
assert not pool.exists(key)
class TestErrorRecovery:
"""测试错误恢复能力"""
def test_put_after_error(self):
"""测试错误后可以继续 put"""
pool = MultiProcessingSharedPool()
pool.clear()
# 尝试使用非法 key 类型
try:
pool.put(123, "value") # 应该抛出 TypeError
except TypeError:
pass
# 应该可以继续正常使用
pool.put("valid_key", "valid_value")
assert pool.get("valid_key") == "valid_value"
def test_get_after_nonexistent(self):
"""测试获取不存在的 key 后可以继续使用"""
pool = MultiProcessingSharedPool()
pool.clear()
# 获取不存在的 key
result = pool.get("nonexistent")
assert result is None
# 应该可以继续正常使用
pool.put("new_key", "new_value")
assert pool.get("new_key") == "new_value"
def test_multiple_singleton_access(self):
"""测试多次获取单例后访问"""
pool1 = MultiProcessingSharedPool()
pool1.put("key1", "value1")
pool2 = MultiProcessingSharedPool()
pool2.put("key2", "value2")
pool3 = MultiProcessingSharedPool.get_instance()
pool3.put("key3", "value3")
# 所有实例应该看到相同的数据
assert pool1.get("key1") == "value1"
assert pool1.get("key2") == "value2"
assert pool1.get("key3") == "value3"
class TestCleanup:
"""测试清理功能"""
def test_clear_after_multiple_puts(self):
"""测试多次 put 后 clear"""
pool = MultiProcessingSharedPool()
pool.clear()
for i in range(100):
pool.put(f"key_{i}", f"value_{i}")
assert pool.size() == 100
pool.clear()
assert pool.size() == 0
assert pool.keys() == []
def test_remove_all_keys_one_by_one(self):
"""测试逐个删除所有 key"""
pool = MultiProcessingSharedPool()
pool.clear()
keys = [f"key_{i}" for i in range(50)]
for key in keys:
pool.put(key, f"value_for_{key}")
for key in keys:
assert pool.remove(key) is True
assert pool.size() == 0