Files
mpsp/test/test_basic.py
2026-02-21 12:00:47 +08:00

325 lines
8.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Basic functionality tests - verify the correctness of the core API.

Covers put/get round-trips, the dict-style interface, context-manager use,
label validation and overwrite semantics of MultiProcessingSharedPool.
"""
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
class TestBasicOperations:
    """Tests for the core put/get/exists/remove/pop/size/keys/clear operations."""

    def test_singleton_pattern(self):
        """Singleton: every way of obtaining the pool yields the same instance."""
        pool1 = MultiProcessingSharedPool()
        pool2 = MultiProcessingSharedPool.get_instance()
        pool3 = MultiProcessingSharedPool()
        assert pool1 is pool2
        assert pool2 is pool3
        assert pool1 is pool3

    def test_put_and_get_basic_types(self):
        """Scalar values (int, float, str, bool, None) round-trip through the pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Integer
        pool.put("int_key", 42)
        assert pool.get("int_key") == 42

        # Float: same absolute tolerance as before, but pytest.approx reports
        # both values on failure instead of a bare `False`.
        pool.put("float_key", 3.14159)
        assert pool.get("float_key") == pytest.approx(3.14159, abs=1e-10)

        # String
        pool.put("str_key", "hello mpsp")
        assert pool.get("str_key") == "hello mpsp"

        # Boolean: identity check so a merely truthy value (e.g. 1) fails.
        pool.put("bool_key", True)
        assert pool.get("bool_key") is True

        # None
        pool.put("none_key", None)
        assert pool.get("none_key") is None

    def test_put_and_get_collections(self):
        """Container values (list, dict, tuple, set) round-trip through the pool."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # List
        test_list = [1, 2, 3, "a", "b", "c"]
        pool.put("list_key", test_list)
        assert pool.get("list_key") == test_list

        # Dict (including a nested dict)
        test_dict = {"name": "test", "value": 100, "nested": {"a": 1}}
        pool.put("dict_key", test_dict)
        assert pool.get("dict_key") == test_dict

        # Tuple
        test_tuple = (1, 2, 3)
        pool.put("tuple_key", test_tuple)
        assert pool.get("tuple_key") == test_tuple

        # Set
        test_set = {1, 2, 3}
        pool.put("set_key", test_set)
        assert pool.get("set_key") == test_set

    def test_exists(self):
        """exists() tracks a key through put and remove."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert not pool.exists("nonexistent_key")

        pool.put("existing_key", "value")
        assert pool.exists("existing_key")

        pool.remove("existing_key")
        assert not pool.exists("existing_key")

    def test_remove(self):
        """remove() returns True for a present key and False for a missing one."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Removing a key that exists
        pool.put("key_to_remove", "value")
        assert pool.remove("key_to_remove") is True
        assert not pool.exists("key_to_remove")

        # Removing a key that does not exist
        assert pool.remove("nonexistent_key") is False

    def test_pop(self):
        """pop() returns and removes a value, or returns the default when missing."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Pop an existing key: value is returned and the key is gone.
        pool.put("key_to_pop", "popped_value")
        value = pool.pop("key_to_pop")
        assert value == "popped_value"
        assert not pool.exists("key_to_pop")

        # Pop a missing key with an explicit default
        default_value = pool.pop("nonexistent_key", "default")
        assert default_value == "default"

        # Pop a missing key without a default returns None
        assert pool.pop("nonexistent_key") is None

    def test_size(self):
        """size() counts entries as they are added, removed and cleared."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.size() == 0

        pool.put("key1", "value1")
        assert pool.size() == 1

        pool.put("key2", "value2")
        pool.put("key3", "value3")
        assert pool.size() == 3

        pool.remove("key1")
        assert pool.size() == 2

        pool.clear()
        assert pool.size() == 0

    def test_keys(self):
        """keys() lists every stored label exactly once."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        assert pool.keys() == []

        pool.put("key1", "value1")
        pool.put("key2", "value2")
        pool.put("key3", "value3")

        keys = pool.keys()
        assert len(keys) == 3
        assert set(keys) == {"key1", "key2", "key3"}

    def test_clear(self):
        """clear() removes every entry."""
        pool = MultiProcessingSharedPool()
        # No clear() up front: the pool may also hold entries left over from
        # earlier tests, and clear() must remove those as well.
        pool.put("key1", "value1")
        pool.put("key2", "value2")
        # Precondition: without this, a put() that silently stored nothing
        # would let the assertions below pass vacuously.
        assert pool.exists("key1")
        assert pool.exists("key2")

        pool.clear()

        assert pool.size() == 0
        assert pool.keys() == []
        assert not pool.exists("key1")
        assert not pool.exists("key2")
class TestDictInterface:
    """Mapping-style access: pool[key], pool[key] = v, del pool[key], key in pool."""

    @staticmethod
    def _fresh_pool():
        """Return the shared pool after emptying it."""
        shared = MultiProcessingSharedPool()
        shared.clear()
        return shared

    def test_getitem(self):
        """pool[key] returns the stored value; a missing key raises KeyError."""
        shared = self._fresh_pool()
        shared.put("test_key", "test_value")
        assert shared["test_key"] == "test_value"

        # Subscripting an absent label must behave like a dict miss.
        with pytest.raises(KeyError):
            _ = shared["nonexistent_key"]

    def test_setitem(self):
        """pool[key] = value stores a value readable through get()."""
        shared = self._fresh_pool()
        shared["new_key"] = "new_value"
        fetched = shared.get("new_key")
        assert fetched == "new_value"

    def test_delitem(self):
        """del pool[key] drops the entry; deleting an absent label raises KeyError."""
        shared = self._fresh_pool()
        shared.put("key_to_delete", "value")

        del shared["key_to_delete"]
        still_there = shared.exists("key_to_delete")
        assert not still_there

        with pytest.raises(KeyError):
            del shared["nonexistent_key"]

    def test_contains(self):
        """The `in` operator reports whether a label is stored."""
        shared = self._fresh_pool()
        shared.put("existing_key", "value")
        assert "existing_key" in shared
        assert "nonexistent_key" not in shared
class TestContextManager:
    """Using the pool through a `with` statement."""

    def test_context_manager(self):
        """Values written inside a `with` block are still readable after it exits."""
        outer = MultiProcessingSharedPool()
        outer.clear()

        with MultiProcessingSharedPool() as ctx_pool:
            ctx_pool.put("ctx_key", "ctx_value")
            inside = ctx_pool.get("ctx_key")
            assert inside == "ctx_value"

        # After the context exits the data should still be there
        # (exiting does not shut the manager down).
        after = outer.get("ctx_key")
        assert after == "ctx_value"
class TestErrorHandling:
    """Label validation, default values and unusual label names."""

    def test_invalid_label_type_put(self):
        """put() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.put(123, "value")
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.put(None, "value")

    def test_invalid_label_type_get(self):
        """get() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.get(123)

    def test_invalid_label_type_pop(self):
        """pop() rejects non-string labels with TypeError."""
        pool = MultiProcessingSharedPool()
        with pytest.raises(TypeError, match="Label must be a string"):
            pool.pop(123)

    def test_get_with_default(self):
        """get() falls back to the default only when the key is missing."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Missing key: the default is returned (None when omitted).
        assert pool.get("nonexistent", "default") == "default"
        assert pool.get("nonexistent", None) is None
        assert pool.get("nonexistent") is None

        # Present key: the stored value wins over the default.
        pool.put("existing", "real_value")
        assert pool.get("existing", "default") == "real_value"

    def test_special_label_names(self):
        """Unusual but valid string labels work as distinct keys."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Empty string
        pool.put("", "empty_string_key")
        assert pool.get("") == "empty_string_key"

        # Whitespace, separators, mixed case and non-ASCII characters
        special_keys = [
            "key with spaces",
            "key\twith\ttabs",
            "key\nwith\nnewlines",
            "key/with/slashes",
            "key.with.dots",
            "key:with:colons",
            "UPPERCASE_KEY",
            "mixedCase_Key",
            "unicode_中文_key",
            "emoji_😀_key",
        ]

        for key in special_keys:
            pool.put(key, f"value_for_{key}")
            assert pool.get(key) == f"value_for_{key}"

        # Round-tripping each label alone does not prove they coexist: a
        # label that is normalized or collides with another would overwrite
        # it. Every label must survive as its own entry.
        assert set(pool.keys()) == {"", *special_keys}
        assert pool.size() == len(special_keys) + 1
        assert pool.get("") == "empty_string_key"
class TestOverwrite:
    """Writing to a label that already holds a value."""

    def test_overwrite_value(self):
        """Putting to an existing label replaces its value instead of adding an entry."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("key", "original_value")
        assert pool.get("key") == "original_value"

        pool.put("key", "new_value")
        assert pool.get("key") == "new_value"

        # Overwrite with a value of a different type
        pool.put("key", 12345)
        assert pool.get("key") == 12345

        # Three puts to one label must still leave exactly one entry.
        assert pool.size() == 1
        assert pool.keys() == ["key"]

    def test_overwrite_with_none(self):
        """Overwriting with None stores None rather than deleting the label."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("key", "value")
        pool.put("key", None)

        assert pool.get("key") is None
        assert pool.exists("key")  # the label must still be present
        assert pool.size() == 1