"""
Unit tests for the MultiProcessingSharedPool core functionality.

These tests verify the singleton pattern, basic CRUD operations, error
handling, and dictionary-like interface.
"""

import pytest
import multiprocessing as mp
import time
from typing import Any, Dict

from easy_mp_share import MultiProcessingSharedPool, shared_pool


# Upper bound (seconds) the multiprocess test waits for each worker result and
# for each worker to exit, so a crashed or hung child fails the test instead of
# hanging the whole test run.
_MP_TIMEOUT = 30


def _multiprocess_worker(results_queue: mp.Queue, worker_id: int) -> None:
    """Worker for ``test_multiprocess_access``; runs in a child process.

    Defined at module level rather than nested inside the test because the
    ``spawn`` start method pickles the process target, and locally defined
    functions cannot be pickled.

    Stores ``"value_from_worker_<id>"`` under ``"worker_<id>_data"`` in the
    shared pool, then puts onto *results_queue*: the previous worker's value
    (only when ``worker_id > 1``), the pool's ``"main_data"`` value, and the
    key name this worker wrote.
    """
    pool = MultiProcessingSharedPool()

    # Put some data
    pool.put(f"worker_{worker_id}_data", f"value_from_worker_{worker_id}")

    # Get data from main process
    main_data = pool.get("main_data")

    # Get data from the previous worker. Workers run concurrently, so this
    # value may not have been written yet (None); the test does not assert it.
    if worker_id > 1:
        prev_data = pool.get(f"worker_{worker_id-1}_data")
        results_queue.put(prev_data)

    results_queue.put(main_data)
    results_queue.put(f"worker_{worker_id}_data")


class TestMultiProcessingSharedPool:
    """Test cases for MultiProcessingSharedPool class."""

    def setup_method(self):
        """Set up test fixtures before each test method."""
        # MultiProcessingSharedPool is a singleton (see test_singleton_pattern),
        # so this is the same instance every time; clear it so each test
        # starts from an empty pool.
        self.pool = MultiProcessingSharedPool()
        self.pool.clear()

    def teardown_method(self):
        """Clear the shared singleton so data does not leak into later tests."""
        self.pool.clear()

    def test_singleton_pattern(self):
        """Test that MultiProcessingSharedPool follows singleton pattern."""
        pool1 = MultiProcessingSharedPool()
        pool2 = MultiProcessingSharedPool()

        # Should be the same instance
        assert pool1 is pool2

        # get_instance should also return the same instance
        pool3 = MultiProcessingSharedPool.get_instance()
        assert pool1 is pool3

    def test_basic_put_and_get(self):
        """Test basic put and get operations."""
        # Test string data
        assert self.pool.put("test_key", "test_value")
        assert self.pool.get("test_key") == "test_value"

        # Test numeric data
        assert self.pool.put("number", 42)
        assert self.pool.get("number") == 42

        # Test list data
        test_list = [1, 2, 3, "four"]
        assert self.pool.put("my_list", test_list)
        assert self.pool.get("my_list") == test_list

        # Test dict data
        test_dict = {"key1": "value1", "key2": 123}
        assert self.pool.put("my_dict", test_dict)
        assert self.pool.get("my_dict") == test_dict

    def test_get_with_default(self):
        """Test get operation with default values."""
        # Non-existent key should return None by default
        assert self.pool.get("non_existent") is None

        # Non-existent key should return provided default
        assert self.pool.get("non_existent", "default_value") == "default_value"
        assert self.pool.get("non_existent", 42) == 42

        # Existent key should return actual value, not default
        self.pool.put("existing", "actual_value")
        assert self.pool.get("existing", "default") == "actual_value"

    def test_exists_method(self):
        """Test the exists method."""
        assert not self.pool.exists("non_existent")

        self.pool.put("existing_key", "value")
        assert self.pool.exists("existing_key")

        # Test with empty string key
        self.pool.put("", "empty_key_value")
        assert self.pool.exists("")

        # Test with special characters
        self.pool.put("key!@#$%", "special_chars")
        assert self.pool.exists("key!@#$%")

    def test_remove_method(self):
        """Test the remove method."""
        # Remove non-existent key should return False
        assert not self.pool.remove("non_existent")

        # Remove existing key should return True
        self.pool.put("to_remove", "value")
        assert self.pool.remove("to_remove")
        assert not self.pool.exists("to_remove")

        # Should not be able to get removed key
        assert self.pool.get("to_remove") is None

    def test_pop_method(self):
        """Test the pop method."""
        # Pop non-existent key should return default
        assert self.pool.pop("non_existent") is None
        assert self.pool.pop("non_existent", "default") == "default"

        # Pop existing key should return value and remove it
        self.pool.put("to_pop", "popped_value")
        result = self.pool.pop("to_pop")
        assert result == "popped_value"
        assert not self.pool.exists("to_pop")

        # Second pop should return default
        assert self.pool.pop("to_pop", "gone") == "gone"

    def test_size_method(self):
        """Test the size method."""
        assert self.pool.size() == 0

        self.pool.put("key1", "value1")
        assert self.pool.size() == 1

        self.pool.put("key2", "value2")
        assert self.pool.size() == 2

        self.pool.remove("key1")
        assert self.pool.size() == 1

        self.pool.clear()
        assert self.pool.size() == 0

    def test_keys_method(self):
        """Test the keys method."""
        assert self.pool.keys() == []

        # Add some keys
        self.pool.put("first", "value1")
        self.pool.put("second", "value2")
        self.pool.put("third", "value3")

        keys = self.pool.keys()
        assert len(keys) == 3
        assert "first" in keys
        assert "second" in keys
        assert "third" in keys

        # Remove a key
        self.pool.remove("second")
        keys = self.pool.keys()
        assert len(keys) == 2
        assert "second" not in keys

    def test_clear_method(self):
        """Test the clear method."""
        self.pool.put("key1", "value1")
        self.pool.put("key2", "value2")
        assert self.pool.size() == 2

        self.pool.clear()
        assert self.pool.size() == 0
        assert self.pool.keys() == []
        assert self.pool.get("key1") is None
        assert self.pool.get("key2") is None

    def test_dictionary_interface(self):
        """Test dictionary-like interface."""
        # Test __setitem__ and __getitem__
        self.pool["dict_key"] = "dict_value"
        assert self.pool["dict_key"] == "dict_value"

        # Test __contains__
        assert "dict_key" in self.pool
        assert "non_existent" not in self.pool

        # Test __delitem__
        del self.pool["dict_key"]
        assert "dict_key" not in self.pool

        # Test KeyError for non-existent key
        with pytest.raises(KeyError):
            _ = self.pool["non_existent"]

        with pytest.raises(KeyError):
            del self.pool["non_existent"]

    def test_type_validation(self):
        """Test type validation for keys."""
        # Non-string keys should raise TypeError for put, get, and pop
        with pytest.raises(TypeError):
            self.pool.put(123, "value")

        with pytest.raises(TypeError):
            self.pool.get(123)

        with pytest.raises(TypeError):
            self.pool.pop(123)

        # Note: __contains__ does not raise TypeError for non-string keys;
        # per the original author it delegates to exists(), which skips type
        # validation (reportedly for performance). This pins current behavior.
        result = 123 in self.pool  # Should return False, not raise
        assert result is False

    def test_context_manager(self):
        """Test context manager interface."""
        with MultiProcessingSharedPool() as pool:
            pool.put("context_key", "context_value")
            assert pool.get("context_key") == "context_value"

        # Pool should still be usable after context
        pool = MultiProcessingSharedPool()
        assert pool.get("context_key") == "context_value"

    def test_shared_pool_convenience_instance(self):
        """Test the shared_pool convenience instance."""
        # Should be the same as MultiProcessingSharedPool.get_instance()
        pool = MultiProcessingSharedPool.get_instance()
        assert shared_pool is pool

        # Should work like any other instance
        shared_pool.put("convenience_key", "convenience_value")
        assert shared_pool.get("convenience_key") == "convenience_value"

        # Should be accessible from other instances
        pool = MultiProcessingSharedPool()
        assert pool.get("convenience_key") == "convenience_value"

    @pytest.mark.skipif(
        not hasattr(mp, 'Process') or mp.get_start_method() != 'spawn',
        reason="Multiprocessing test requires spawn method"
    )
    def test_multiprocess_access(self):
        """Test that the pool works across multiple processes."""
        # Put initial data from main process
        self.pool.put("main_data", "value_from_main")

        results_queue = mp.Queue()
        worker_ids = range(1, 4)
        processes = []

        # Start multiple worker processes. The target must be a module-level
        # function so it can be pickled under the spawn start method.
        for i in worker_ids:
            p = mp.Process(target=_multiprocess_worker, args=(results_queue, i))
            processes.append(p)
            p.start()

        # Every worker puts main_data and its key name; every worker except
        # the first also puts the previous worker's value.
        expected_count = 2 * len(worker_ids) + (len(worker_ids) - 1)

        try:
            # Drain the queue *before* joining: a child that still has items
            # buffered for a multiprocessing.Queue may not exit until they are
            # consumed, and Queue.empty() is documented as unreliable. The
            # timeout turns a crashed or hung worker into a failure
            # (queue.Empty) instead of a hang.
            results = [
                results_queue.get(timeout=_MP_TIMEOUT)
                for _ in range(expected_count)
            ]
        finally:
            # Wait for all processes to complete; never leave stragglers.
            for p in processes:
                p.join(timeout=_MP_TIMEOUT)
                if p.is_alive():
                    p.terminate()
                    p.join()

        for p in processes:
            assert p.exitcode == 0, f"{p.name} exited with code {p.exitcode}"

        # Verify data sharing worked
        assert "value_from_main" in results
        assert "worker_1_data" in results
        assert "worker_2_data" in results
        assert "worker_3_data" in results

        # Verify data is accessible from main process
        assert self.pool.get("worker_1_data") == "value_from_worker_1"
        assert self.pool.get("worker_2_data") == "value_from_worker_2"
        assert self.pool.get("worker_3_data") == "value_from_worker_3"


if __name__ == "__main__":
    pytest.main([__file__])