""" 函数序列化测试 - 验证 cloudpickle 集成 """ import multiprocessing import time import math import pytest from mpsp.mpsp import MultiProcessingSharedPool # ==================== 模块级别的普通函数 ==================== def simple_function(x): """简单的加法函数""" return x + 1 def multiply_function(a, b): """乘法函数""" return a * b def function_with_default(x, y=10): """带默认参数的函数""" return x + y def function_with_kwargs(*args, **kwargs): """带可变参数的函数""" return sum(args) + sum(kwargs.values()) def recursive_function(n): """递归函数""" if n <= 1: return 1 return n * recursive_function(n - 1) def closure_factory(base): """闭包工厂函数""" def inner(x): return x + base return inner # ==================== 辅助函数(模块级别定义)==================== def worker_execute_function(key, result_queue): """子进程:获取函数并执行""" pool = MultiProcessingSharedPool() func = pool.get(key) if func is None: result_queue.put(None) return # 执行函数(根据不同的测试函数传入不同参数) try: if key == "simple_func": result = func(5) elif key == "multiply_func": result = func(3, 4) elif key == "default_func": result = func(5) elif key == "kwargs_func": result = func(1, 2, 3, a=4, b=5) elif key == "recursive_func": result = func(5) elif key == "closure_func": result = func(10) elif key == "lambda_func": result = func(7) elif key == "lambda_with_capture": result = func() elif key == "nested_func": result = func(3) else: result = func() result_queue.put(result) except Exception as e: result_queue.put(f"ERROR: {e}") def worker_execute_lambda_with_arg(key, arg, result_queue): """子进程:获取 lambda 并执行,传入参数""" pool = MultiProcessingSharedPool() func = pool.get(key) if func is None: result_queue.put(None) return result_queue.put(func(arg)) def get_lambda_description(func): """获取 lambda 函数的描述字符串""" try: return func.__name__ except AttributeError: return str(func) # ==================== 测试类 ==================== class TestNormalFunctions: """测试普通函数的序列化和反序列化""" def test_simple_function(self): """测试简单函数的传递""" pool = MultiProcessingSharedPool() pool.clear() # 存储函数 result = 
pool.put("simple_func", simple_function) assert result is True # 当前进程验证 retrieved = pool.get("simple_func") assert retrieved is not None assert retrieved(5) == 6 assert retrieved(10) == 11 def test_simple_function_cross_process(self): """测试简单函数的跨进程传递""" pool = MultiProcessingSharedPool() pool.clear() # 父进程存储函数 pool.put("simple_func", simple_function) # 子进程获取并执行 result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("simple_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 6 # simple_function(5) = 5 + 1 def test_function_with_multiple_args(self): """测试多参数函数""" pool = MultiProcessingSharedPool() pool.clear() pool.put("multiply_func", multiply_function) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("multiply_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 12 # multiply_function(3, 4) = 12 def test_function_with_default_args(self): """测试带默认参数的函数""" pool = MultiProcessingSharedPool() pool.clear() pool.put("default_func", function_with_default) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("default_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 15 # function_with_default(5) = 5 + 10 def test_function_with_kwargs(self): """测试带 **kwargs 的函数""" pool = MultiProcessingSharedPool() pool.clear() pool.put("kwargs_func", function_with_kwargs) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("kwargs_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 15 # sum(1,2,3) + sum(4,5) = 6 + 9 = 15 def test_recursive_function(self): """测试递归函数""" pool = MultiProcessingSharedPool() pool.clear() pool.put("recursive_func", recursive_function) result_queue = multiprocessing.Queue() p = multiprocessing.Process( 
target=worker_execute_function, args=("recursive_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 120 # 5! = 120 class TestLambdaFunctions: """测试 Lambda 函数的序列化和反序列化""" def test_simple_lambda(self): """测试简单 lambda 函数""" pool = MultiProcessingSharedPool() pool.clear() simple_lambda = lambda x: x * 2 result = pool.put("lambda_func", simple_lambda) assert result is True # 当前进程验证 retrieved = pool.get("lambda_func") assert retrieved(5) == 10 assert retrieved(7) == 14 def test_simple_lambda_cross_process(self): """测试简单 lambda 的跨进程传递""" pool = MultiProcessingSharedPool() pool.clear() simple_lambda = lambda x: x * 3 pool.put("lambda_func", simple_lambda) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("lambda_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 21 # lambda(7) = 7 * 3 = 21 def test_lambda_with_capture(self): """测试捕获外部变量的 lambda""" pool = MultiProcessingSharedPool() pool.clear() captured_value = 100 capturing_lambda = lambda: captured_value + 1 pool.put("lambda_with_capture", capturing_lambda) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("lambda_with_capture", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 101 # captured_value + 1 = 101 def test_lambda_in_list_comprehension(self): """测试在列表推导式中创建的 lambda""" pool = MultiProcessingSharedPool() pool.clear() # 创建多个 lambda lambdas = [(lambda x, i=i: x + i) for i in range(5)] for i, lam in enumerate(lambdas): pool.put(f"lambda_{i}", lam) # 验证每个 lambda 都能正确捕获各自的 i for i in range(5): retrieved = pool.get(f"lambda_{i}") assert retrieved(10) == 10 + i def test_complex_lambda(self): """测试复杂的 lambda 表达式""" pool = MultiProcessingSharedPool() pool.clear() complex_lambda = lambda x, y: (x**2 + y**2) ** 0.5 pool.put("complex_lambda", complex_lambda) # 子进程验证 def worker_execute_complex_lambda(key, x, y, 
result_queue): pool = MultiProcessingSharedPool() func = pool.get(key) result_queue.put(func(x, y)) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_complex_lambda, args=("complex_lambda", 3, 4, result_queue), ) p.start() p.join() result = result_queue.get() assert abs(result - 5.0) < 1e-10 # sqrt(3^2 + 4^2) = 5 class TestNestedFunctions: """测试嵌套函数(在函数内部定义的函数)""" def test_nested_function(self): """测试嵌套函数""" pool = MultiProcessingSharedPool() pool.clear() def outer_function(x): def inner_function(y): return y * y return inner_function(x) + x pool.put("nested_func", outer_function) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("nested_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 12 # outer_function(3) = 3*3 + 3 = 12 def test_closure_function(self): """测试闭包函数""" pool = MultiProcessingSharedPool() pool.clear() closure_func = closure_factory(100) pool.put("closure_func", closure_func) result_queue = multiprocessing.Queue() p = multiprocessing.Process( target=worker_execute_function, args=("closure_func", result_queue) ) p.start() p.join() result = result_queue.get() assert result == 110 # closure_func(10) = 10 + 100 def test_multiple_closures(self): """测试多个闭包""" pool = MultiProcessingSharedPool() pool.clear() closures = [closure_factory(i) for i in range(5)] for i, closure in enumerate(closures): pool.put(f"closure_{i}", closure) # 验证每个闭包捕获的值不同 for i in range(5): retrieved = pool.get(f"closure_{i}") assert retrieved(10) == 10 + i class TestClassMethods: """测试类方法的序列化""" def test_static_method(self): """测试静态方法""" pool = MultiProcessingSharedPool() pool.clear() class Calculator: @staticmethod def add(x, y): return x + y @staticmethod def multiply(x, y): return x * y pool.put("static_add", Calculator.add) pool.put("static_multiply", Calculator.multiply) # 验证静态方法 add_func = pool.get("static_add") multiply_func = 
pool.get("static_multiply") assert add_func(2, 3) == 5 assert multiply_func(2, 3) == 6 def test_class_method(self): """测试类方法""" pool = MultiProcessingSharedPool() pool.clear() class Counter: count = 0 @classmethod def increment(cls): cls.count += 1 return cls.count # 注意:类方法通常不能被 cloudpickle 正确序列化 # 因为它依赖于类定义 result = pool.put("class_method", Counter.increment) # 如果成功存储,尝试执行 if result: try: method = pool.get("class_method") # 类方法在反序列化后可能无法正常工作 # 这取决于 cloudpickle 的实现 except Exception: pass # 预期可能失败 class TestBuiltInFunctions: """测试内置函数的序列化""" def test_builtin_functions(self): """测试 Python 内置函数""" pool = MultiProcessingSharedPool() pool.clear() # 大多数内置函数可以用标准 pickle 序列化 pool.put("builtin_sum", sum) pool.put("builtin_max", max) pool.put("builtin_min", min) pool.put("builtin_len", len) # 验证 assert pool.get("builtin_sum")([1, 2, 3]) == 6 assert pool.get("builtin_max")([1, 2, 3]) == 3 assert pool.get("builtin_min")([1, 2, 3]) == 1 assert pool.get("builtin_len")([1, 2, 3]) == 3 def test_math_functions(self): """测试 math 模块函数""" pool = MultiProcessingSharedPool() pool.clear() pool.put("math_sqrt", math.sqrt) pool.put("math_sin", math.sin) pool.put("math_cos", math.cos) # 验证 assert abs(pool.get("math_sqrt")(16) - 4.0) < 1e-10 assert abs(pool.get("math_sin")(0) - 0.0) < 1e-10 assert abs(pool.get("math_cos")(0) - 1.0) < 1e-10 class TestFunctionReturnValues: """测试函数作为返回值""" def test_function_returned_from_function(self): """测试返回函数的函数""" pool = MultiProcessingSharedPool() pool.clear() def create_multiplier(factor): return lambda x: x * factor pool.put("create_multiplier", create_multiplier) # 在子进程中获取并使用 def worker_get_multiplier(result_queue): pool = MultiProcessingSharedPool() factory = pool.get("create_multiplier") multiplier_func = factory(5) # 创建一个乘以 5 的函数 result_queue.put(multiplier_func(10)) result_queue = multiprocessing.Queue() p = multiprocessing.Process(target=worker_get_multiplier, args=(result_queue,)) p.start() p.join() result = result_queue.get() assert result == 50 
class TestErrorHandling:
    """Error handling around function serialization."""

    def test_unpicklable_function_fallback(self):
        """A function the standard pickle rejects is stored via the cloudpickle fallback."""
        pool = MultiProcessingSharedPool()
        pool.clear()

        # A function local to this test cannot be pickled by qualified name.
        def local_function(x):
            return x**2

        # Storing it must still succeed, via cloudpickle.
        result = pool.put("local_func", local_function)
        assert result is True

        # The retrieved copy must behave like the original.
        retrieved = pool.get("local_func")
        assert retrieved(5) == 25

    def test_function_with_unpicklable_capture(self):
        """A function capturing an open file object must not break the pool.

        Whether such a function can be stored depends on the serializer:
        cloudpickle may serialize a readable file, or the pool may refuse it.
        Both outcomes are accepted. What the test asserts is that a failed
        store raises or reports failure cleanly, a successful one yields a
        callable, and the pool keeps working afterwards.
        """
        pool = MultiProcessingSharedPool()
        pool.clear()

        # Explicit encoding: this source file is not guaranteed to be ASCII.
        with open(__file__, "r", encoding="utf-8") as f:
            file_capturing_lambda = lambda: f.read()
            try:
                stored = pool.put("file_lambda", file_capturing_lambda)
            except Exception:
                # Rejecting the capture with an exception is an acceptable outcome.
                stored = False

        if stored:
            assert callable(pool.get("file_lambda"))

        # Whatever happened above, the pool must still store and return values.
        assert pool.put("after_file_lambda", lambda x: x + 1) is True
        assert pool.get("after_file_lambda")(1) == 2