Initial commit.

This commit is contained in:
2026-02-21 12:00:47 +08:00
commit cd335c1b3f
14 changed files with 3492 additions and 0 deletions

523
test/test_functions.py Normal file
View File

@ -0,0 +1,523 @@
"""
函数序列化测试 - 验证 cloudpickle 集成
"""
import multiprocessing
import time
import math
import pytest
from mpsp.mpsp import MultiProcessingSharedPool
# ==================== 模块级别的普通函数 ====================
def simple_function(x):
"""简单的加法函数"""
return x + 1
def multiply_function(a, b):
"""乘法函数"""
return a * b
def function_with_default(x, y=10):
"""带默认参数的函数"""
return x + y
def function_with_kwargs(*args, **kwargs):
"""带可变参数的函数"""
return sum(args) + sum(kwargs.values())
def recursive_function(n):
"""递归函数"""
if n <= 1:
return 1
return n * recursive_function(n - 1)
def closure_factory(base):
"""闭包工厂函数"""
def inner(x):
return x + base
return inner
# ==================== 辅助函数(模块级别定义)====================
def worker_execute_function(key, result_queue):
"""子进程:获取函数并执行"""
pool = MultiProcessingSharedPool()
func = pool.get(key)
if func is None:
result_queue.put(None)
return
# 执行函数(根据不同的测试函数传入不同参数)
try:
if key == "simple_func":
result = func(5)
elif key == "multiply_func":
result = func(3, 4)
elif key == "default_func":
result = func(5)
elif key == "kwargs_func":
result = func(1, 2, 3, a=4, b=5)
elif key == "recursive_func":
result = func(5)
elif key == "closure_func":
result = func(10)
elif key == "lambda_func":
result = func(7)
elif key == "lambda_with_capture":
result = func()
elif key == "nested_func":
result = func(3)
else:
result = func()
result_queue.put(result)
except Exception as e:
result_queue.put(f"ERROR: {e}")
def worker_execute_lambda_with_arg(key, arg, result_queue):
"""子进程:获取 lambda 并执行,传入参数"""
pool = MultiProcessingSharedPool()
func = pool.get(key)
if func is None:
result_queue.put(None)
return
result_queue.put(func(arg))
def get_lambda_description(func):
"""获取 lambda 函数的描述字符串"""
try:
return func.__name__
except AttributeError:
return str(func)
# ==================== 测试类 ====================
class TestNormalFunctions:
"""测试普通函数的序列化和反序列化"""
def test_simple_function(self):
"""测试简单函数的传递"""
pool = MultiProcessingSharedPool()
pool.clear()
# 存储函数
result = pool.put("simple_func", simple_function)
assert result is True
# 当前进程验证
retrieved = pool.get("simple_func")
assert retrieved is not None
assert retrieved(5) == 6
assert retrieved(10) == 11
def test_simple_function_cross_process(self):
"""测试简单函数的跨进程传递"""
pool = MultiProcessingSharedPool()
pool.clear()
# 父进程存储函数
pool.put("simple_func", simple_function)
# 子进程获取并执行
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("simple_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 6 # simple_function(5) = 5 + 1
def test_function_with_multiple_args(self):
"""测试多参数函数"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("multiply_func", multiply_function)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("multiply_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 12 # multiply_function(3, 4) = 12
def test_function_with_default_args(self):
"""测试带默认参数的函数"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("default_func", function_with_default)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("default_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 15 # function_with_default(5) = 5 + 10
def test_function_with_kwargs(self):
"""测试带 **kwargs 的函数"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("kwargs_func", function_with_kwargs)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("kwargs_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 15 # sum(1,2,3) + sum(4,5) = 6 + 9 = 15
def test_recursive_function(self):
"""测试递归函数"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("recursive_func", recursive_function)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("recursive_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 120 # 5! = 120
class TestLambdaFunctions:
"""测试 Lambda 函数的序列化和反序列化"""
def test_simple_lambda(self):
"""测试简单 lambda 函数"""
pool = MultiProcessingSharedPool()
pool.clear()
simple_lambda = lambda x: x * 2
result = pool.put("lambda_func", simple_lambda)
assert result is True
# 当前进程验证
retrieved = pool.get("lambda_func")
assert retrieved(5) == 10
assert retrieved(7) == 14
def test_simple_lambda_cross_process(self):
"""测试简单 lambda 的跨进程传递"""
pool = MultiProcessingSharedPool()
pool.clear()
simple_lambda = lambda x: x * 3
pool.put("lambda_func", simple_lambda)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("lambda_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 21 # lambda(7) = 7 * 3 = 21
def test_lambda_with_capture(self):
"""测试捕获外部变量的 lambda"""
pool = MultiProcessingSharedPool()
pool.clear()
captured_value = 100
capturing_lambda = lambda: captured_value + 1
pool.put("lambda_with_capture", capturing_lambda)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("lambda_with_capture", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 101 # captured_value + 1 = 101
def test_lambda_in_list_comprehension(self):
"""测试在列表推导式中创建的 lambda"""
pool = MultiProcessingSharedPool()
pool.clear()
# 创建多个 lambda
lambdas = [(lambda x, i=i: x + i) for i in range(5)]
for i, lam in enumerate(lambdas):
pool.put(f"lambda_{i}", lam)
# 验证每个 lambda 都能正确捕获各自的 i
for i in range(5):
retrieved = pool.get(f"lambda_{i}")
assert retrieved(10) == 10 + i
def test_complex_lambda(self):
"""测试复杂的 lambda 表达式"""
pool = MultiProcessingSharedPool()
pool.clear()
complex_lambda = lambda x, y: (x**2 + y**2) ** 0.5
pool.put("complex_lambda", complex_lambda)
# 子进程验证
def worker_execute_complex_lambda(key, x, y, result_queue):
pool = MultiProcessingSharedPool()
func = pool.get(key)
result_queue.put(func(x, y))
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_complex_lambda,
args=("complex_lambda", 3, 4, result_queue),
)
p.start()
p.join()
result = result_queue.get()
assert abs(result - 5.0) < 1e-10 # sqrt(3^2 + 4^2) = 5
class TestNestedFunctions:
"""测试嵌套函数(在函数内部定义的函数)"""
def test_nested_function(self):
"""测试嵌套函数"""
pool = MultiProcessingSharedPool()
pool.clear()
def outer_function(x):
def inner_function(y):
return y * y
return inner_function(x) + x
pool.put("nested_func", outer_function)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("nested_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 12 # outer_function(3) = 3*3 + 3 = 12
def test_closure_function(self):
"""测试闭包函数"""
pool = MultiProcessingSharedPool()
pool.clear()
closure_func = closure_factory(100)
pool.put("closure_func", closure_func)
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(
target=worker_execute_function, args=("closure_func", result_queue)
)
p.start()
p.join()
result = result_queue.get()
assert result == 110 # closure_func(10) = 10 + 100
def test_multiple_closures(self):
"""测试多个闭包"""
pool = MultiProcessingSharedPool()
pool.clear()
closures = [closure_factory(i) for i in range(5)]
for i, closure in enumerate(closures):
pool.put(f"closure_{i}", closure)
# 验证每个闭包捕获的值不同
for i in range(5):
retrieved = pool.get(f"closure_{i}")
assert retrieved(10) == 10 + i
class TestClassMethods:
"""测试类方法的序列化"""
def test_static_method(self):
"""测试静态方法"""
pool = MultiProcessingSharedPool()
pool.clear()
class Calculator:
@staticmethod
def add(x, y):
return x + y
@staticmethod
def multiply(x, y):
return x * y
pool.put("static_add", Calculator.add)
pool.put("static_multiply", Calculator.multiply)
# 验证静态方法
add_func = pool.get("static_add")
multiply_func = pool.get("static_multiply")
assert add_func(2, 3) == 5
assert multiply_func(2, 3) == 6
def test_class_method(self):
"""测试类方法"""
pool = MultiProcessingSharedPool()
pool.clear()
class Counter:
count = 0
@classmethod
def increment(cls):
cls.count += 1
return cls.count
# 注意:类方法通常不能被 cloudpickle 正确序列化
# 因为它依赖于类定义
result = pool.put("class_method", Counter.increment)
# 如果成功存储,尝试执行
if result:
try:
method = pool.get("class_method")
# 类方法在反序列化后可能无法正常工作
# 这取决于 cloudpickle 的实现
except Exception:
pass # 预期可能失败
class TestBuiltInFunctions:
"""测试内置函数的序列化"""
def test_builtin_functions(self):
"""测试 Python 内置函数"""
pool = MultiProcessingSharedPool()
pool.clear()
# 大多数内置函数可以用标准 pickle 序列化
pool.put("builtin_sum", sum)
pool.put("builtin_max", max)
pool.put("builtin_min", min)
pool.put("builtin_len", len)
# 验证
assert pool.get("builtin_sum")([1, 2, 3]) == 6
assert pool.get("builtin_max")([1, 2, 3]) == 3
assert pool.get("builtin_min")([1, 2, 3]) == 1
assert pool.get("builtin_len")([1, 2, 3]) == 3
def test_math_functions(self):
"""测试 math 模块函数"""
pool = MultiProcessingSharedPool()
pool.clear()
pool.put("math_sqrt", math.sqrt)
pool.put("math_sin", math.sin)
pool.put("math_cos", math.cos)
# 验证
assert abs(pool.get("math_sqrt")(16) - 4.0) < 1e-10
assert abs(pool.get("math_sin")(0) - 0.0) < 1e-10
assert abs(pool.get("math_cos")(0) - 1.0) < 1e-10
class TestFunctionReturnValues:
"""测试函数作为返回值"""
def test_function_returned_from_function(self):
"""测试返回函数的函数"""
pool = MultiProcessingSharedPool()
pool.clear()
def create_multiplier(factor):
return lambda x: x * factor
pool.put("create_multiplier", create_multiplier)
# 在子进程中获取并使用
def worker_get_multiplier(result_queue):
pool = MultiProcessingSharedPool()
factory = pool.get("create_multiplier")
multiplier_func = factory(5) # 创建一个乘以 5 的函数
result_queue.put(multiplier_func(10))
result_queue = multiprocessing.Queue()
p = multiprocessing.Process(target=worker_get_multiplier, args=(result_queue,))
p.start()
p.join()
result = result_queue.get()
assert result == 50 # 10 * 5 = 50
class TestErrorHandling:
"""测试函数序列化的错误处理"""
def test_unpicklable_function_fallback(self):
"""测试无法序列化的函数回退到 cloudpickle"""
pool = MultiProcessingSharedPool()
pool.clear()
# 创建局部函数(无法被标准 pickle 序列化)
def local_function(x):
return x**2
# 应该通过 cloudpickle 成功存储
result = pool.put("local_func", local_function)
assert result is True
# 验证可以正确执行
retrieved = pool.get("local_func")
assert retrieved(5) == 25
def test_function_with_unpicklable_capture(self):
"""测试捕获不可序列化对象的函数"""
pool = MultiProcessingSharedPool()
pool.clear()
# 捕获文件对象(不可序列化)
try:
with open(__file__, "r") as f:
file_capturing_lambda = lambda: f.read()
# 尝试存储应该失败
result = pool.put("file_lambda", file_capturing_lambda)
# 如果 cloudpickle 支持,验证是否能正确失败
except Exception:
pass # 预期可能失败