524 lines
15 KiB
Python
524 lines
15 KiB
Python
"""
|
|
函数序列化测试 - 验证 cloudpickle 集成
|
|
"""
|
|
|
|
import multiprocessing
|
|
import time
|
|
import math
|
|
import pytest
|
|
from mpsp.mpsp import MultiProcessingSharedPool
|
|
|
|
|
|
# ==================== 模块级别的普通函数 ====================
|
|
|
|
|
|
def simple_function(x):
    """Return ``x`` plus one (simple addition used as a serialization fixture)."""
    incremented = x + 1
    return incremented
|
|
|
|
|
|
def multiply_function(a, b):
    """Return the product of ``a`` and ``b``."""
    product = a * b
    return product
|
|
|
|
|
|
def function_with_default(x, y=10):
    """Return ``x + y``; ``y`` defaults to 10 when omitted."""
    total = x + y
    return total
|
|
|
|
|
|
def function_with_kwargs(*args, **kwargs):
    """Return the sum of all positional arguments plus all keyword values."""
    positional_total = sum(args)
    keyword_total = sum(kwargs.values())
    return positional_total + keyword_total
|
|
|
|
|
|
def recursive_function(n):
    """Return ``n!`` computed recursively; any ``n <= 1`` yields 1.

    Deliberately recursive: the tests use it to check that a
    self-referencing function survives serialization.
    """
    return 1 if n <= 1 else n * recursive_function(n - 1)
|
|
|
|
|
|
def closure_factory(base):
    """Build and return a closure that adds ``base`` to its argument."""

    def inner(x):
        # ``base`` is captured from the enclosing call, not passed in.
        shifted = x + base
        return shifted

    return inner
|
|
|
|
|
|
# ==================== 辅助函数(模块级别定义)====================
|
|
|
|
|
|
def worker_execute_function(key, result_queue):
    """Child-process entry point: fetch the function stored under ``key`` and call it.

    Each known test key is called with its own fixed arguments; any other
    key is called with no arguments.  The return value is put on
    ``result_queue``.  ``None`` is put instead when the pool has no entry for
    ``key``, and the string ``"ERROR: <message>"`` when the call raises.
    """
    # (positional, keyword) arguments used for each known test key.
    call_arguments = {
        "simple_func": ((5,), {}),
        "multiply_func": ((3, 4), {}),
        "default_func": ((5,), {}),
        "kwargs_func": ((1, 2, 3), {"a": 4, "b": 5}),
        "recursive_func": ((5,), {}),
        "closure_func": ((10,), {}),
        "lambda_func": ((7,), {}),
        "lambda_with_capture": ((), {}),
        "nested_func": ((3,), {}),
    }

    shared_pool = MultiProcessingSharedPool()
    stored = shared_pool.get(key)
    if stored is None:
        result_queue.put(None)
        return

    try:
        # Unknown keys fall back to a zero-argument call.
        positional, keyword = call_arguments.get(key, ((), {}))
        outcome = stored(*positional, **keyword)
        result_queue.put(outcome)
    except Exception as exc:
        # Report the failure to the parent instead of dying silently.
        result_queue.put(f"ERROR: {exc}")
|
|
|
|
|
|
def worker_execute_lambda_with_arg(key, arg, result_queue):
    """Child-process entry point: fetch the callable under ``key`` and call it with ``arg``.

    Exactly one item is always put on ``result_queue``:

    * the callable's return value on success,
    * ``None`` when the pool has no entry for ``key``,
    * the string ``"ERROR: <message>"`` when the call raises (same
      convention as ``worker_execute_function``).

    Reporting errors through the queue keeps a parent that is blocked on
    ``result_queue.get()`` from waiting forever on a crashed child.
    """
    pool = MultiProcessingSharedPool()
    func = pool.get(key)
    if func is None:
        result_queue.put(None)
        return

    try:
        result_queue.put(func(arg))
    except Exception as e:
        result_queue.put(f"ERROR: {e}")
|
|
|
|
|
|
def get_lambda_description(func):
    """Return ``func.__name__`` if it exists, otherwise ``str(func)``."""
    missing = object()  # sentinel: distinguishes "no attribute" from any real value
    name = getattr(func, "__name__", missing)
    if name is missing:
        return str(func)
    return name
|
|
|
|
|
|
# ==================== 测试类 ====================
|
|
|
|
|
|
class TestNormalFunctions:
    """Test serialization and deserialization of ordinary functions."""

    # Upper bound, in seconds, on how long a test waits for its child process.
    _CHILD_TIMEOUT = 30

    @classmethod
    def _run_in_child(cls, key):
        """Run ``worker_execute_function(key, queue)`` in a child and return its result.

        The result is read from the queue *before* the child is joined:
        joining a process that still has unread items on a
        ``multiprocessing.Queue`` can block.  The read uses a timeout so a
        child that dies without reporting fails the test (``queue.Empty``)
        instead of hanging the test run.
        """
        result_queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_execute_function, args=(key, result_queue)
        )
        child.start()
        try:
            return result_queue.get(timeout=cls._CHILD_TIMEOUT)
        finally:
            child.join(timeout=cls._CHILD_TIMEOUT)
            if child.is_alive():
                # Do not leave a stuck child behind.
                child.terminate()
                child.join()

    def test_simple_function(self):
        """Test storing and retrieving a simple function in the same process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Store the function.
        result = pool.put("simple_func", simple_function)
        assert result is True

        # Verify in the current process.
        retrieved = pool.get("simple_func")
        assert retrieved is not None
        assert retrieved(5) == 6
        assert retrieved(10) == 11

    def test_simple_function_cross_process(self):
        """Test passing a simple function to another process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # The parent stores the function; the child fetches and runs it.
        pool.put("simple_func", simple_function)

        result = self._run_in_child("simple_func")
        assert result == 6  # simple_function(5) = 5 + 1

    def test_function_with_multiple_args(self):
        """Test a function that takes several arguments."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("multiply_func", multiply_function)

        result = self._run_in_child("multiply_func")
        assert result == 12  # multiply_function(3, 4) = 12

    def test_function_with_default_args(self):
        """Test a function with a default argument."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("default_func", function_with_default)

        result = self._run_in_child("default_func")
        assert result == 15  # function_with_default(5) = 5 + 10

    def test_function_with_kwargs(self):
        """Test a function that takes ``*args`` and ``**kwargs``."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("kwargs_func", function_with_kwargs)

        result = self._run_in_child("kwargs_func")
        assert result == 15  # sum(1, 2, 3) + sum(4, 5) = 6 + 9 = 15

    def test_recursive_function(self):
        """Test a recursive function."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        pool.put("recursive_func", recursive_function)

        result = self._run_in_child("recursive_func")
        assert result == 120  # 5! = 120
|
|
|
|
|
|
class TestLambdaFunctions:
    """Test serialization and deserialization of lambda functions."""

    # Upper bound, in seconds, on how long a test waits for its child process.
    _CHILD_TIMEOUT = 30

    @staticmethod
    def _execute_binary_function(key, x, y, result_queue):
        """Child-process entry point: call the function stored under ``key`` with ``(x, y)``.

        Defined on the class rather than inside a test so that it can be
        pickled by qualified name when the ``spawn`` or ``forkserver`` start
        method is in use (a function nested in a test cannot be).  Always
        puts exactly one item on ``result_queue``: the result, ``None`` if
        the key is missing, or ``"ERROR: <message>"`` if the call raises.
        """
        pool = MultiProcessingSharedPool()
        func = pool.get(key)
        if func is None:
            result_queue.put(None)
            return
        try:
            result_queue.put(func(x, y))
        except Exception as e:
            result_queue.put(f"ERROR: {e}")

    @classmethod
    def _run_in_child(cls, target, *args):
        """Run ``target(*args, queue)`` in a child process and return its result.

        The result is read from the queue *before* the child is joined:
        joining a process that still has unread items on a
        ``multiprocessing.Queue`` can block.  The read uses a timeout so a
        child that dies without reporting fails the test (``queue.Empty``)
        instead of hanging the test run.
        """
        result_queue = multiprocessing.Queue()
        child = multiprocessing.Process(target=target, args=(*args, result_queue))
        child.start()
        try:
            return result_queue.get(timeout=cls._CHILD_TIMEOUT)
        finally:
            child.join(timeout=cls._CHILD_TIMEOUT)
            if child.is_alive():
                # Do not leave a stuck child behind.
                child.terminate()
                child.join()

    def test_simple_lambda(self):
        """Test a simple lambda in the same process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        simple_lambda = lambda x: x * 2
        result = pool.put("lambda_func", simple_lambda)
        assert result is True

        # Verify in the current process.
        retrieved = pool.get("lambda_func")
        assert retrieved(5) == 10
        assert retrieved(7) == 14

    def test_simple_lambda_cross_process(self):
        """Test passing a simple lambda to another process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        simple_lambda = lambda x: x * 3
        pool.put("lambda_func", simple_lambda)

        result = self._run_in_child(worker_execute_function, "lambda_func")
        assert result == 21  # lambda(7) = 7 * 3 = 21

    def test_lambda_with_capture(self):
        """Test a lambda that captures an outer variable."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        captured_value = 100
        capturing_lambda = lambda: captured_value + 1

        pool.put("lambda_with_capture", capturing_lambda)

        result = self._run_in_child(worker_execute_function, "lambda_with_capture")
        assert result == 101  # captured_value + 1 = 101

    def test_lambda_in_list_comprehension(self):
        """Test lambdas created inside a list comprehension."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Create several lambdas; ``i=i`` binds each loop value at creation
        # time instead of relying on late binding.
        lambdas = [(lambda x, i=i: x + i) for i in range(5)]

        for i, lam in enumerate(lambdas):
            pool.put(f"lambda_{i}", lam)

        # Every lambda must have kept its own ``i``.
        for i in range(5):
            retrieved = pool.get(f"lambda_{i}")
            assert retrieved(10) == 10 + i

    def test_complex_lambda(self):
        """Test a more complex lambda expression in a child process."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        complex_lambda = lambda x, y: (x**2 + y**2) ** 0.5
        pool.put("complex_lambda", complex_lambda)

        # Verify in a child process.
        result = self._run_in_child(
            self._execute_binary_function, "complex_lambda", 3, 4
        )
        assert abs(result - 5.0) < 1e-10  # sqrt(3^2 + 4^2) = 5
|
|
|
|
|
|
class TestNestedFunctions:
    """Test nested functions (functions defined inside another function)."""

    # Upper bound, in seconds, on how long a test waits for its child process.
    _CHILD_TIMEOUT = 30

    @classmethod
    def _run_in_child(cls, key):
        """Run ``worker_execute_function(key, queue)`` in a child and return its result.

        The result is read from the queue *before* the child is joined:
        joining a process that still has unread items on a
        ``multiprocessing.Queue`` can block.  The read uses a timeout so a
        child that dies without reporting fails the test (``queue.Empty``)
        instead of hanging the test run.
        """
        result_queue = multiprocessing.Queue()
        child = multiprocessing.Process(
            target=worker_execute_function, args=(key, result_queue)
        )
        child.start()
        try:
            return result_queue.get(timeout=cls._CHILD_TIMEOUT)
        finally:
            child.join(timeout=cls._CHILD_TIMEOUT)
            if child.is_alive():
                # Do not leave a stuck child behind.
                child.terminate()
                child.join()

    def test_nested_function(self):
        """Test a function that defines and calls an inner function."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        def outer_function(x):
            def inner_function(y):
                return y * y

            return inner_function(x) + x

        pool.put("nested_func", outer_function)

        result = self._run_in_child("nested_func")
        assert result == 12  # outer_function(3) = 3*3 + 3 = 12

    def test_closure_function(self):
        """Test a closure produced by ``closure_factory``."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        closure_func = closure_factory(100)
        pool.put("closure_func", closure_func)

        result = self._run_in_child("closure_func")
        assert result == 110  # closure_func(10) = 10 + 100

    def test_multiple_closures(self):
        """Test several closures stored side by side."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        closures = [closure_factory(i) for i in range(5)]
        for i, closure in enumerate(closures):
            pool.put(f"closure_{i}", closure)

        # Each closure must have captured a different value.
        for i in range(5):
            retrieved = pool.get(f"closure_{i}")
            assert retrieved(10) == 10 + i
|
|
|
|
|
|
class TestClassMethods:
    """Test serialization of methods defined on classes."""

    def test_static_method(self):
        """Test static methods of a locally defined class."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        class Calculator:
            @staticmethod
            def add(x, y):
                return x + y

            @staticmethod
            def multiply(x, y):
                return x * y

        pool.put("static_add", Calculator.add)
        pool.put("static_multiply", Calculator.multiply)

        # Verify the static methods.
        add_func = pool.get("static_add")
        multiply_func = pool.get("static_multiply")

        assert add_func(2, 3) == 5
        assert multiply_func(2, 3) == 6

    def test_class_method(self):
        """Test a bound class method of a locally defined class (best effort).

        Whether a class method round-trips depends on the serializer, since
        the method carries its class with it.  Refusing to store it, or
        failing to load it back, is therefore tolerated.  If a value does
        come back, it must at least be callable.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        class Counter:
            count = 0

            @classmethod
            def increment(cls):
                cls.count += 1
                return cls.count

        stored = pool.put("class_method", Counter.increment)
        if not stored:
            # The pool declined to store the method; nothing more to check.
            return

        try:
            method = pool.get("class_method")
        except Exception:
            # Deserialization may legitimately fail, depending on the
            # serializer's support for class methods.
            return

        # Checked outside the ``try`` so a failed assertion is not swallowed.
        if method is not None:
            assert callable(method)
|
|
|
|
|
|
class TestBuiltInFunctions:
    """Test serialization of built-in functions."""

    def test_builtin_functions(self):
        """Test Python built-in functions."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # (key, function, expected result for [1, 2, 3]); per the original
        # author, most built-ins can be serialized by standard pickle.
        cases = (
            ("builtin_sum", sum, 6),
            ("builtin_max", max, 3),
            ("builtin_min", min, 1),
            ("builtin_len", len, 3),
        )

        for key, func, _ in cases:
            pool.put(key, func)

        # Verify.
        for key, _, expected in cases:
            assert pool.get(key)([1, 2, 3]) == expected

    def test_math_functions(self):
        """Test functions from the ``math`` module."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # (key, function, argument, expected result)
        cases = (
            ("math_sqrt", math.sqrt, 16, 4.0),
            ("math_sin", math.sin, 0, 0.0),
            ("math_cos", math.cos, 0, 1.0),
        )

        for key, func, _, _ in cases:
            pool.put(key, func)

        # Verify.
        for key, _, argument, expected in cases:
            assert abs(pool.get(key)(argument) - expected) < 1e-10
|
|
|
|
|
|
class TestFunctionReturnValues:
    """Test functions that are themselves return values."""

    # Upper bound, in seconds, on how long a test waits for its child process.
    _CHILD_TIMEOUT = 30

    @staticmethod
    def _apply_stored_factory(result_queue):
        """Child-process entry point: build a x5 multiplier and apply it to 10.

        Fetches the factory stored under ``"create_multiplier"``.  Defined on
        the class rather than inside the test so that it can be pickled by
        qualified name when the ``spawn`` or ``forkserver`` start method is
        in use.  Always puts exactly one item on ``result_queue``: the
        result, ``None`` if the factory is missing, or ``"ERROR: <message>"``
        if a call raises.
        """
        pool = MultiProcessingSharedPool()
        factory = pool.get("create_multiplier")
        if factory is None:
            result_queue.put(None)
            return
        try:
            multiplier_func = factory(5)  # a function that multiplies by 5
            result_queue.put(multiplier_func(10))
        except Exception as e:
            result_queue.put(f"ERROR: {e}")

    def test_function_returned_from_function(self):
        """Test a function that returns a function."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        def create_multiplier(factor):
            return lambda x: x * factor

        pool.put("create_multiplier", create_multiplier)

        # Fetch and use the factory in a child process.
        result_queue = multiprocessing.Queue()
        p = multiprocessing.Process(
            target=self._apply_stored_factory, args=(result_queue,)
        )
        p.start()
        try:
            # Read before joining: joining a process that still has unread
            # queue items can block.  The timeout turns a dead child into a
            # test failure (``queue.Empty``) instead of a hang.
            result = result_queue.get(timeout=self._CHILD_TIMEOUT)
        finally:
            p.join(timeout=self._CHILD_TIMEOUT)
            if p.is_alive():
                # Do not leave a stuck child behind.
                p.terminate()
                p.join()

        assert result == 50  # 10 * 5 = 50
|
|
|
|
|
|
class TestErrorHandling:
    """Test error handling around function serialization."""

    def test_unpicklable_function_fallback(self):
        """Test that a function standard pickle rejects is still stored."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # A local function: standard pickle cannot serialize it by reference.
        def local_function(x):
            return x**2

        # Storing is expected to succeed (via cloudpickle, per the original
        # author's note).
        result = pool.put("local_func", local_function)
        assert result is True

        # The retrieved function must still work.
        retrieved = pool.get("local_func")
        assert retrieved(5) == 25

    def test_function_with_unpicklable_capture(self):
        """Test storing a lambda that captures an open file object (best effort).

        Serializers differ on whether an open file can be captured, so both
        outcomes of ``pool.put`` are accepted: returning normally or raising.
        Only the ``put`` call is guarded, so problems elsewhere in the test
        (for example failing to open the file) are not silently swallowed.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Explicit encoding: this file contains non-ASCII text, and a
        # serializer may read the file's content while pickling the handle.
        with open(__file__, "r", encoding="utf-8") as f:
            # Capture the file object.
            file_capturing_lambda = lambda: f.read()

            try:
                pool.put("file_lambda", file_capturing_lambda)
            except Exception:
                # Rejecting the capture is an acceptable outcome.
                pass
|