325 lines
8.9 KiB
Python
325 lines
8.9 KiB
Python
"""
Basic functional tests - verify the correctness of the core API.
"""
|
||
|
||
import pytest
|
||
from mpsp.mpsp import MultiProcessingSharedPool
|
||
|
||
|
||
class TestBasicOperations:
    """Tests for the basic pool operations (put/get/exists/remove/pop/size/keys/clear)."""

    def test_singleton_pattern(self):
        """Singleton: obtaining the pool several times must yield the same instance."""
        pool1 = MultiProcessingSharedPool()
        pool2 = MultiProcessingSharedPool.get_instance()
        pool3 = MultiProcessingSharedPool()

        assert pool1 is pool2
        assert pool2 is pool3
        assert pool1 is pool3

    def test_put_and_get_basic_types(self):
        """Store and retrieve scalar values (int, float, str, bool, None)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Integer
        pool.put("int_key", 42)
        assert pool.get("int_key") == 42

        # Float: pytest.approx with an explicit absolute tolerance keeps the
        # same 1e-10 bound as a manual abs() check, but reports both the
        # expected and actual values when the assertion fails.
        pool.put("float_key", 3.14159)
        assert pool.get("float_key") == pytest.approx(3.14159, abs=1e-10)

        # String
        pool.put("str_key", "hello mpsp")
        assert pool.get("str_key") == "hello mpsp"

        # Boolean: identity check, so a value that is merely equal to True
        # (such as 1) would not pass.
        pool.put("bool_key", True)
        assert pool.get("bool_key") is True

        # None
        pool.put("none_key", None)
        assert pool.get("none_key") is None

    def test_put_and_get_collections(self):
        """Store and retrieve container values (list, dict, tuple, set)."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # List with mixed element types
        test_list = [1, 2, 3, "a", "b", "c"]
        pool.put("list_key", test_list)
        assert pool.get("list_key") == test_list

        # Dict, including a nested dict
        test_dict = {"name": "test", "value": 100, "nested": {"a": 1}}
        pool.put("dict_key", test_dict)
        assert pool.get("dict_key") == test_dict

        # Tuple
        test_tuple = (1, 2, 3)
        pool.put("tuple_key", test_tuple)
        assert pool.get("tuple_key") == test_tuple

        # Set
        test_set = {1, 2, 3}
        pool.put("set_key", test_set)
        assert pool.get("set_key") == test_set

    def test_exists(self):
        """exists() is False before put, True after put, and False again after remove."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert not pool.exists("nonexistent_key")

        pool.put("existing_key", "value")
        assert pool.exists("existing_key")

        pool.remove("existing_key")
        assert not pool.exists("existing_key")

    def test_remove(self):
        """remove() returns True for a present key and False for an absent one."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Removing a key that exists
        pool.put("key_to_remove", "value")
        assert pool.remove("key_to_remove") is True
        assert not pool.exists("key_to_remove")

        # Removing a key that does not exist
        assert pool.remove("nonexistent_key") is False

    def test_pop(self):
        """pop() returns and removes a value; a missing key yields the default."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Popping a key that exists
        pool.put("key_to_pop", "popped_value")
        value = pool.pop("key_to_pop")
        assert value == "popped_value"
        assert not pool.exists("key_to_pop")

        # Popping a missing key with an explicit default
        default_value = pool.pop("nonexistent_key", "default")
        assert default_value == "default"

        # Popping a missing key without a default returns None (no exception)
        assert pool.pop("nonexistent_key") is None

    def test_size(self):
        """size() tracks the number of entries across put/remove/clear."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.size() == 0

        pool.put("key1", "value1")
        assert pool.size() == 1

        pool.put("key2", "value2")
        pool.put("key3", "value3")
        assert pool.size() == 3

        pool.remove("key1")
        assert pool.size() == 2

        pool.clear()
        assert pool.size() == 0

    def test_keys(self):
        """keys() lists every stored label; ordering is deliberately not asserted."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.keys() == []

        pool.put("key1", "value1")
        pool.put("key2", "value2")
        pool.put("key3", "value3")

        keys = pool.keys()
        assert len(keys) == 3
        assert set(keys) == {"key1", "key2", "key3"}

    def test_clear(self):
        """clear() removes every entry, including ones written by earlier tests."""
        pool = MultiProcessingSharedPool()

        pool.put("key1", "value1")
        pool.put("key2", "value2")

        pool.clear()

        assert pool.size() == 0
        assert pool.keys() == []
        assert not pool.exists("key1")
|
||
|
||
|
||
class TestDictInterface:
    """Exercise the mapping-style dunder methods of the shared pool."""

    def test_getitem(self):
        """Subscript reads return stored values and raise KeyError for unknown labels."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("test_key", "test_value")
        fetched = shared["test_key"]
        assert fetched == "test_value"

        # Unlike get(), the [] form must not fall back to a default.
        with pytest.raises(KeyError):
            _ = shared["nonexistent_key"]

    def test_setitem(self):
        """Subscript assignment is visible through get()."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared["new_key"] = "new_value"
        stored = shared.get("new_key")
        assert stored == "new_value"

    def test_delitem(self):
        """``del pool[label]`` drops the entry and raises KeyError for unknown labels."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("key_to_delete", "value")
        del shared["key_to_delete"]
        still_there = shared.exists("key_to_delete")
        assert not still_there

        # Unlike remove(), which returns False, del on a missing label must raise.
        with pytest.raises(KeyError):
            del shared["nonexistent_key"]

    def test_contains(self):
        """The ``in`` operator reports whether a label is stored."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("existing_key", "value")

        present = "existing_key" in shared
        absent = "nonexistent_key" not in shared
        assert present
        assert absent
|
||
|
||
|
||
class TestContextManager:
    """Check that the pool can be used in a ``with`` statement."""

    def test_context_manager(self):
        """Values written inside a ``with`` block remain readable after the block exits."""
        outer = MultiProcessingSharedPool()
        outer.clear()

        with MultiProcessingSharedPool() as managed:
            managed.put("ctx_key", "ctx_value")
            assert managed.get("ctx_key") == "ctx_value"

        # Exiting the context must leave the data in place (the manager is not
        # shut down), so the reference obtained beforehand still sees the value.
        after_exit = outer.get("ctx_key")
        assert after_exit == "ctx_value"
|
||
|
||
|
||
class TestErrorHandling:
    """Label-type validation, default values and unusual label strings."""

    def test_invalid_label_type_put(self):
        """put() rejects labels that are not strings."""
        shared = MultiProcessingSharedPool()

        for bad_label in (123, None):
            with pytest.raises(TypeError, match="Label must be a string"):
                shared.put(bad_label, "value")

    def test_invalid_label_type_get(self):
        """get() rejects labels that are not strings."""
        shared = MultiProcessingSharedPool()

        for bad_label in (123,):
            with pytest.raises(TypeError, match="Label must be a string"):
                shared.get(bad_label)

    def test_invalid_label_type_pop(self):
        """pop() rejects labels that are not strings."""
        shared = MultiProcessingSharedPool()

        for bad_label in (123,):
            with pytest.raises(TypeError, match="Label must be a string"):
                shared.pop(bad_label)

    def test_get_with_default(self):
        """get() falls back to the default only when the label is absent."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        # Absent label: the supplied default comes back (None when omitted).
        missing = "nonexistent"
        assert shared.get(missing, "default") == "default"
        assert shared.get(missing, None) is None
        assert shared.get(missing) is None

        # Present label: the stored value takes precedence over the default.
        shared.put("existing", "real_value")
        assert shared.get("existing", "default") == "real_value"

    def test_special_label_names(self):
        """Empty, whitespace-laden, punctuated and non-ASCII labels all round-trip."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        # The empty string is accepted as a label.
        shared.put("", "empty_string_key")
        assert shared.get("") == "empty_string_key"

        # Whitespace, separators, letter case and non-ASCII characters.
        unusual_labels = (
            "key with spaces",
            "key\twith\ttabs",
            "key\nwith\nnewlines",
            "key/with/slashes",
            "key.with.dots",
            "key:with:colons",
            "UPPERCASE_KEY",
            "mixedCase_Key",
            "unicode_中文_key",
            "emoji_😀_key",
        )

        for label in unusual_labels:
            expected = f"value_for_{label}"
            shared.put(label, expected)
            assert shared.get(label) == expected
|
||
|
||
|
||
class TestOverwrite:
    """Writing to an existing label replaces its value."""

    def test_overwrite_value(self):
        """Successive put() calls keep only the latest value, even when its type changes."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        # The final replacement switches from str to int.
        for replacement in ("original_value", "new_value", 12345):
            shared.put("key", replacement)
            assert shared.get("key") == replacement

    def test_overwrite_with_none(self):
        """Overwriting a value with None stores None rather than deleting the label."""
        shared = MultiProcessingSharedPool()
        shared.clear()

        shared.put("key", "value")
        shared.put("key", None)

        current = shared.get("key")
        assert current is None
        # A stored None is still an entry: the label must continue to exist.
        assert shared.exists("key")
|