Basic script for OLMoE inference.

Add: basic script for OLMoE inference and saving expert logits.

Add: OLMoE vLLM module with logits saving.

Add: vLLM model registering.

Fix: update DataLogger module for multiprocessing.
This commit is contained in:
2025-09-28 11:41:30 +08:00
parent 5942bbfd05
commit 6738855fb1
4 changed files with 470 additions and 374 deletions

View File

@ -17,6 +17,7 @@
This file is origin from vllm/model_executor/models/olmoe.py
"""
from collections.abc import Iterable
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
@ -51,7 +52,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, is_pp_missing_p
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from utils import DataLogger
from utils import DataLogger as dlog
logger = init_logger(__name__)
@ -100,10 +101,20 @@ class OlmoeMoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
dlog.log({
"_time": datetime.now(),
"router_logits": router_logits.cpu().float().numpy(),
"layer": self.layer_idx,
})
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
return final_hidden_states.view(orig_shape)
def add_logging_metrics(self, layer_idx: int):
self.layer_idx = layer_idx
class OlmoeAttention(nn.Module):
@ -234,6 +245,16 @@ class OlmoeDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
)
# Extract layer_idx from prefix
self.layer_idx = None
try:
# Prefix format: "model.layers.7"
parts = prefix.split('.')
if len(parts) >= 3 and parts[2].isdigit():
self.layer_idx = int(parts[2])
except ValueError:
pass
self.mlp = OlmoeMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
@ -245,6 +266,8 @@ class OlmoeDecoderLayer(nn.Module):
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.mlp.add_logging_metrics(layer_idx=self.layer_idx)
def forward(
self,
positions: torch.Tensor,

31
models/register.py Normal file
View File

@ -0,0 +1,31 @@
from vllm import ModelRegistry
def register_vllm_logit_logging_models():
models_mapper = {"OlmoeForCausalLM": "models.log_expert.olmoe:OlmoeForCausalLM"}
for model_name, model_path in models_mapper.items():
ModelRegistry.register_model(model_name, model_path)
# ModelRegistry.register_model(
# "Qwen3MoeForCausalLM", "src.modeling_vllm_save.qwen3_moe:Qwen3MoeForCausalLM"
# )
# ModelRegistry.register_model(
# "MixtralForCausalLM", "src.modeling_vllm_save.mixtral:MixtralForCausalLM"
# )
# ModelRegistry.register_model(
# "OlmoeForCausalLM", "models.log_expert.olmoe:OlmoeForCausalLM"
# )
# ModelRegistry.register_model(
# "DeepseekV2ForCausalLM",
# "src.modeling_vllm_save.deepseek_v2:DeepseekV2ForCausalLM",
# )
# ModelRegistry.register_model(
# "Llama4ForConditionalGeneration",
# "src.modeling_vllm_save.mllama4:Llama4ForConditionalGeneration",
# )
# ModelRegistry.register_model(
# "GptOssForCausalLM", "src.modeling_vllm_save.gpt_oss:GptOssForCausalLM"
# )
# ModelRegistry.register_model(
# "PhiMoEForCausalLM", "src.modeling_vllm_save.phimoe:PhiMoEForCausalLM"
# )

94
olmoe_log_expert_vllm.py Normal file
View File

@ -0,0 +1,94 @@
# %%
import gc
from datetime import datetime
from pathlib import Path
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed.parallel_state import destroy_model_parallel
from models.register import register_vllm_logit_logging_models
from utils import DataLogger as dlog
# %%
# dlog.get_instance(path=f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
# %%
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
try:
log_file = Path(f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
if log_file.exists():
log_file.unlink()
dlog.initialize(path=log_file)
register_vllm_logit_logging_models()
llm = LLM(
model=model_id,
cpu_offload_gb=4,
# tensor_parallel_size=2,
gpu_memory_utilization=0.95,
max_model_len=4096,
# compilation_config=CompilationConfig(
# level=CompilationLevel.PIECEWISE,
# # By default, it goes up to max_num_seqs
# cudagraph_capture_sizes=[1, 2, 4, 8, 16],
# ),
enforce_eager=True,
)
sampling_params = SamplingParams(
temperature=0.6,
top_p=0.95,
top_k=20,
max_tokens=1024,
)
# Prepare the input to the model
prompt = "Give me a very short introduction to large language models."
messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
# messages = [
# {"role": "system", "content": "你是一位人工智能助手。"},
# {"role": "user", "content": "请简要地介绍什么是大语言模型。"},
# ]
# Generate outputs
outputs = llm.chat(
messages,
sampling_params=sampling_params,
# chat_template_kwargs={"enable_thinking": True}, # Set to False to strictly disable thinking
)
# Print the outputs.
for out in outputs:
# out.prompt is the input prompt; out.outputs is a list of completion choices
# print("=== PROMPT ===")
# print(out.prompt)
# print("=== COMPLETION ===")
print(out.outputs[0].text)
print("\n---\n")
dlog.log({
"_time": datetime.now(),
"output_text": out.outputs[0].text
})
print("Finish completion")
except Exception as e:
print(e)
finally:
if llm := globals().get("llm", None):
if engine := getattr(llm, "llm_engine", None):
# llm.llm_engine
del engine
del llm
destroy_model_parallel()
torch.cuda.empty_cache()
gc.collect()

View File

@ -1,53 +1,69 @@
"""
Asynchronous, batched, and schema-evolving Parquet logger.
Process-safe, asynchronous, batched, and schema-evolving Parquet logger.
This module provides the `DataLogger`, a high-performance logger for structured
data, designed for applications like machine learning experiments, simulations,
or any scenario requiring efficient serialization of row-based data.
data, redesigned with a client-server architecture to be fully compatible with
multi-process applications (e.g., using `multiprocessing` or libraries like vLLM).
Key Features:
- **Unified Interface**: Log data via a simple `DataLogger.log({"key": "value"})` call.
- **Asynchronous & Batched**: A dedicated background thread handles I/O,
batching rows to minimize disk writes and reduce application latency.
- **Schema Evolution**: Automatically adapts the Parquet schema if new data fields
are introduced, rewriting the file to maintain a consistent structure.
- **Singleton Pattern**: A global singleton instance is managed automatically,
providing a convenient, fire-and-forget logging experience.
Key Architectural Features:
- **Client-Server Model**: A single, dedicated server process manages all file I/O,
preventing race conditions and data loss from child processes.
- **Process-Safe**: All processes (main, children) act as clients, sending data
via a managed, shared queue, ensuring centralized and ordered logging.
- **Lazy Initialization & Automatic Management**: The server process is transparently
started on the first log call and is gracefully shut down on program exit via
`atexit`, requiring no manual management from the user.
- **Automatic Server-Side Timestamping**: Optionally, the server can automatically
add a timestamp to each log record, indicating the exact time of its arrival.
- **Unified Interface**: Log data from any process via the simple and consistent
`DataLogger.log({"key": "value"})` call.
- **Asynchronous & Batched**: The server process handles I/O asynchronously from the
clients, batching rows to minimize disk writes and reduce application latency.
- **Schema Evolution**: Reuses the robust logic to automatically adapt the Parquet
schema if new data fields are introduced.
- **Type Handling**: Natively handles Python primitives, NumPy arrays, and PyTorch
tensors, converting them to Parquet-compatible formats.
- **Robust & Thread-Safe**: Designed for use in multi-threaded environments.
tensors.
Basic Usage:
-------------
.. code-block:: python
from logger.data_logger import DataLogger
from data_logger import DataLogger
import multiprocessing
# The first call creates and configures the singleton logger.
# A timestamped filename is generated by default.
DataLogger.log({"step": 0, "loss": 10.5, "accuracy": 0.5})
DataLogger.log({"step": 1, "loss": 9.8, "accuracy": 0.55})
# Optional: Configure the logger at the start of your application.
# If not called, defaults will be used on the first .log() call.
DataLogger.initialize("my_experiment.parquet")
# For the singleton, data is automatically flushed and saved on program exit.
# No explicit `close()` call is required for this simple case.
def worker_function(worker_id):
for i in range(5):
# Log from a child process. It's that simple.
DataLogger.log({"worker_id": worker_id, "step": i, "value": i * 100})
time.sleep(0.1)
Advanced Usage (Instance-Based):
---------------------------------
.. code-block:: python
# All calls to DataLogger.log() from any process will be sent
# to the central logging server.
DataLogger.log({"main_process_event": "starting workers"})
from logger.data_logger import DataLogger, LoggerConfig
processes = [multiprocessing.Process(target=worker_function, args=(i,)) for i in range(3)]
for p in processes:
p.start()
for p in processes:
p.join()
config = LoggerConfig(batch_size=512, flush_interval=5.0)
with DataLogger("my_experiment.parquet", config=config) as logger:
for i in range(1000):
logger.submit({"value": i})
# The `with` statement ensures flush and close on exit.
DataLogger.log({"main_process_event": "workers finished"})
# Data is automatically flushed and the server is shut down when the main
# program exits. For explicit control, you can call:
# DataLogger.close()
"""
from __future__ import annotations
import datetime
import atexit
import datetime
import multiprocessing
import multiprocessing.synchronize
import os
import queue
import threading
@ -57,35 +73,33 @@ import typing as t
from dataclasses import dataclass
from pathlib import Path
# Third-party libraries are imported with runtime checks to provide clear
# error messages if they are not installed.
# Third-party libraries are imported with runtime checks
try:
import numpy as np
except ImportError:
np = None # type: ignore
np = None
try:
import pandas as pd
except ImportError:
raise ImportError(
"pandas is required for DataLogger. Install with `pip install pandas`."
)
raise ImportError("pandas is required. Install with `pip install pandas`.")
try:
import pyarrow as pa
import pyarrow.parquet as pq
except ImportError:
raise ImportError(
"pyarrow is required for DataLogger. Install with `pip install pyarrow`."
)
raise ImportError("pyarrow is required. Install with `pip install pyarrow`.")
try:
import torch
except ImportError:
torch = None # type: ignore
torch = None
# Type alias for a single row of data.
# --- Type Definitions and Constants ---
Row = t.Dict[str, t.Any]
# Special command objects to be sent through the queue, distinct from data rows (dicts)
_FLUSH_COMMAND = "__FLUSH__"
_SHUTDOWN_COMMAND = "__SHUTDOWN__"
@dataclass
@ -93,312 +107,133 @@ class LoggerConfig:
"""Configuration for the DataLogger's writer behavior."""
batch_size: int = 1024
"""Number of rows to accumulate before writing a batch to the Parquet file."""
flush_interval: float = 1.0
"""Maximum time in seconds to wait before flushing the buffer, even if
`batch_size` is not reached."""
parquet_compression: str = "snappy"
"""Compression codec to use for the Parquet file.
Common options: 'snappy', 'gzip', 'brotli', 'none'."""
allow_schema_rewrite: bool = True
"""If True, the logger will automatically rewrite the entire Parquet file to
accommodate new columns. If False, it will raise an error."""
log_server_time: bool = False
class DataLogger:
class DataLoggerServer(multiprocessing.Process):
"""
An asynchronous, batched logger that writes data to a Parquet file.
The server process responsible for all file I/O operations.
This class manages a background thread to handle file I/O, allowing the
calling application to log data with minimal blocking. It supports schema
evolution, making it robust to changes in data structure over time.
This process runs a loop that consumes data from a shared queue, batches it,
and writes to a Parquet file. It is the single source of truth for the log file,
ensuring that writes are serialized and schema evolution is handled correctly.
It is designed to be managed by the `DataLogger` facade and not instantiated directly.
"""
_singleton: t.Optional["DataLogger"] = None
_singleton_lock = threading.Lock()
# --- Public API ---
@classmethod
def get_instance(
cls,
path: t.Optional[t.Union[str, Path]] = None,
config: t.Optional[LoggerConfig] = None,
) -> "DataLogger":
"""
Get or create the global singleton instance of the DataLogger.
The first time this method is called, it creates a new `DataLogger`
instance and registers a cleanup function via `atexit` to ensure
`close()` is called automatically upon program termination.
Subsequent calls will ignore the arguments and return the existing
instance.
Args:
path: The file path for the log file. If None, a timestamped
filename like 'log_YYYYMMDD-HHMMSS.parquet' is created in the
current working directory.
config: A `LoggerConfig` object to configure the writer's behavior.
If None, default settings are used.
Returns:
The singleton `DataLogger` instance.
"""
if cls._singleton is None:
with cls._singleton_lock:
if cls._singleton is None:
# Create the singleton instance.
instance = cls(path, config)
# Register its close method to be called at program exit.
# This ensures data is saved even if the user forgets to call close().
atexit.register(instance.close)
cls._singleton = instance
return cls._singleton
@classmethod
def log(cls, row: Row) -> None:
"""
Log a data row using the singleton instance.
This is a convenience method that lazily initializes the singleton on
its first call. The operation is non-blocking; the data is placed in
an internal queue to be processed by the background writer thread.
Args:
row: A dictionary representing a single row of data, where keys
are column names and values are the data points.
"""
instance = cls.get_instance()
instance.submit(row)
def __init__(
self,
path: t.Optional[t.Union[str, Path]] = None,
config: t.Optional[LoggerConfig] = None,
log_queue: multiprocessing.Queue,
path: Path,
config: LoggerConfig,
flush_event: multiprocessing.synchronize.Event,
):
"""
Initialize a DataLogger instance.
Args:
path: The file path for the log file. If None, a timestamped
filename is automatically generated.
config: A `LoggerConfig` object. If None, default settings are used.
"""
self.path = self._resolve_path(path)
self._config = config or LoggerConfig()
# Internal state for the writer thread
self._queue: queue.Queue[t.Optional[Row]] = queue.Queue()
self._stop_event = threading.Event()
self._flush_event = threading.Event()
self._writer_thread: t.Optional[threading.Thread] = None
self._writer_lock = threading.RLock() # Protects writer and schema
# Parquet-specific state, managed exclusively by the writer thread
self._parquet_writer: t.Optional[pq.ParquetWriter] = None
self._schema: t.Optional[pa.Schema] = None
super().__init__(daemon=True, name=f"DataLoggerServer-{path.name}")
self._queue = log_queue
self.path = path
self._config = config
self._flush_event = flush_event
self._buffer: t.List[Row] = []
self._start_writer_thread()
def submit(self, row: Row) -> None:
"""
Submit a data row to be written asynchronously by the logger instance.
Args:
row: A dictionary representing a single row of data.
Raises:
TypeError: If the provided row is not a dictionary.
RuntimeError: If the logger has already been closed.
"""
if self._stop_event.is_set():
raise RuntimeError("Logger has been closed and cannot accept new data.")
if not isinstance(row, dict):
raise TypeError(f"Expected a dict for a row, but got {type(row)}.")
normalized_row = self._normalize_row(row)
self._queue.put(normalized_row)
def flush(self, timeout: float = 10.0) -> None:
"""
Block until all currently queued and buffered data is written to disk.
Args:
timeout: Maximum time in seconds to wait for the flush to complete.
"""
if self._writer_thread is None or not self._writer_thread.is_alive():
return
self._flush_event.clear()
self._queue.put(None) # Sentinel to trigger a flush
self._flush_event.wait(timeout)
def close(self, timeout: float = 10.0) -> None:
"""
Flush all remaining data and shut down the background writer thread.
This method is idempotent and thread-safe. It is designed to be
called explicitly, via a `with` statement, or automatically at program
exit.
Args:
timeout: Maximum time in seconds to wait for the writer thread
to finish.
"""
if self._stop_event.is_set():
return
self._stop_event.set()
self._queue.put(None) # Wake up the writer thread if it's blocking.
# Do not join the writer thread from itself, which would cause a deadlock.
if self._writer_thread and threading.current_thread() != self._writer_thread:
self._writer_thread.join(timeout)
def __enter__(self) -> "DataLogger":
return self
def __exit__(self, exc_type, exc_val, exc_tb):
"""Ensures the logger is closed upon exiting a `with` block."""
self.close()
def __del__(self):
"""Ensures data is flushed when the logger object is destroyed."""
self.close()
# --- Internal Methods ---
def _resolve_path(self, path: t.Optional[t.Union[str, Path]]) -> Path:
"""Determine the final output path for the log file."""
if path is None:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"log_{timestamp}.parquet"
return Path.cwd() / filename
resolved_path = Path(path)
if resolved_path.suffix == "":
resolved_path = resolved_path.with_suffix(".parquet")
return resolved_path
def _start_writer_thread(self) -> None:
"""Initialize and start the background writer thread."""
if self._writer_thread is not None:
return
thread_name = f"DataLoggerWriter-{self.path.name}"
self._writer_thread = threading.Thread(
target=self._writer_loop, name=thread_name, daemon=True
)
self._writer_thread.start()
def _writer_loop(self) -> None:
"""
The main loop for the background writer thread.
This loop continuously pulls data from the queue, batches it, and
writes it to the Parquet file. It handles flush signals, stop events,
and schema evolution.
"""
def run(self) -> None:
"""The main loop of the server process."""
try:
while not self._stop_event.is_set():
should_stop = False
while not should_stop:
try:
# Block until an item is available or the flush interval times out.
item = self._queue.get(timeout=self._config.flush_interval)
if isinstance(item, dict):
self._process_log_item(item)
elif item == _FLUSH_COMMAND:
self._write_buffer_if_needed(force=True)
self._flush_event.set() # Signal completion
elif item == _SHUTDOWN_COMMAND:
should_stop = True
self._drain_and_write()
except queue.Empty:
# Timeout occurred, treat as a periodic flush signal.
item = None
self._write_buffer_if_needed(force=False)
if item is not None:
self._buffer.append(item)
# Final write upon graceful shutdown
self._write_buffer_if_needed(force=True)
except Exception as e:
print(f"FATAL: DataLogger server process crashed: {e}", flush=True)
traceback.print_exc()
def _process_log_item(self, item: Row) -> None: # <<< NEW: Encapsulated logic
"""Processes a single log item by timestamping, normalizing, and buffering it."""
if self._config.log_server_time and "_time" not in item:
# Add the server-side timestamp when the item is dequeued.
item["_time"] = datetime.datetime.now()
normalized_row = self._normalize_row(item)
self._buffer.append(normalized_row)
def _write_buffer_if_needed(self, force: bool = False) -> None:
"""
Determines if the buffer should be written to disk and triggers the write.
Args:
force: If True, writes the buffer regardless of size.
Used for flush commands, timeouts, and shutdown.
"""
buffer_size = len(self._buffer)
is_flush_signal = item is None
is_batch_full = buffer_size >= self._config.batch_size
is_shutting_down = self._stop_event.is_set()
if self._buffer and (
is_flush_signal or is_batch_full or is_shutting_down
):
if self._buffer and (force or is_batch_full):
self._write_batch(self._buffer)
self._buffer.clear()
if is_flush_signal:
self._flush_event.set() # Signal that a flush completed
# Final drain of the queue and buffer after the stop event is set.
self._drain_remaining()
except Exception as e:
print(f"FATAL: DataLogger writer thread crashed: {e}", flush=True)
traceback.print_exc()
finally:
# This block ensures that the Parquet writer is always closed
# when the writer thread exits, for any reason.
with self._writer_lock:
if self._parquet_writer:
try:
self._parquet_writer.close()
except Exception as e:
print(
f"ERROR: Exception while closing Parquet writer: {e}",
flush=True,
)
self._parquet_writer = None
def _drain_remaining(self) -> None:
def _drain_and_write(self) -> None:
"""Process all remaining items in the queue and buffer during shutdown."""
while True:
try:
item = self._queue.get_nowait()
if item:
self._buffer.append(item)
if isinstance(item, dict):
self._process_log_item(item)
except queue.Empty:
break
if self._buffer:
self._write_batch(self._buffer)
self._buffer.clear()
self._write_buffer_if_needed(force=True)
def _write_batch(self, rows: t.List[Row]) -> None:
"""
Convert a list of rows into a Parquet table and write it to the file.
Converts rows to a Parquet table and writes it, handling schema evolution.
This method handles schema creation, validation, and evolution.
It is always executed within the writer thread.
This method encapsulates the core file I/O logic. It reads the existing
file, merges data, and atomically overwrites it. This strategy is robust
for schema changes, although it has performance implications for very
large files. This logic is preserved from the original implementation.
"""
if not rows:
return
try:
with self._writer_lock:
df = pd.DataFrame(rows)
# Ensure a consistent column order for schema stability.
df = df.reindex(sorted(df.columns), axis=1)
new_table = pa.Table.from_pandas(df, preserve_index=False)
combined_table: pa.Table
if self.path.exists():
# File exists, need to append or evolve schema
existing_table = pq.read_table(self.path)
existing_schema = existing_table.schema
if existing_schema.equals(new_table.schema):
# Schema matches, append the data
if existing_table.schema.equals(new_table.schema):
combined_table = pa.concat_tables([existing_table, new_table])
else:
# Schema evolution needed
# Schema evolution is needed
if not self._config.allow_schema_rewrite:
raise RuntimeError(
"Schema mismatch detected, and rewriting is disabled. "
f"Existing schema: {existing_schema}, New schema: {new_table.schema}"
"Schema mismatch detected, and rewriting is disabled."
)
print(
f"INFO: Schema evolution detected. Rewriting {self.path}...",
flush=True,
)
# Combine with schema evolution
combined_df = pd.concat(
[existing_table.to_pandas(), new_table.to_pandas()],
ignore_index=True,
@ -411,11 +246,11 @@ class DataLogger:
combined_df, preserve_index=False
)
else:
# New file
# This is a new file
self.path.parent.mkdir(parents=True, exist_ok=True)
combined_table = new_table
# Write the combined table atomically
# Write atomically by using a temporary file
temp_path = self.path.with_suffix(f"{self.path.suffix}.tmp")
pq.write_table(
combined_table,
@ -424,70 +259,183 @@ class DataLogger:
)
os.replace(temp_path, self.path)
# Update our schema tracking
self._schema = combined_table.schema
except Exception as e:
print(f"ERROR: Failed to write batch to {self.path}: {e}", flush=True)
traceback.print_exc()
def _rewrite_with_new_schema(self, new_table: pa.Table) -> None:
"""
Rewrite the entire Parquet file to accommodate an evolved schema.
This is a potentially expensive operation as it reads the entire
existing file into memory.
Args:
new_table: The new batch of data with a different schema.
"""
print(f"INFO: Schema evolution detected. Rewriting {self.path}...", flush=True)
# Close the current writer before reading the file.
if self._parquet_writer:
self._parquet_writer.close()
# Read existing data, combine with new data, and create a unified table.
existing_table = pq.read_table(self.path)
combined_df = pd.concat(
[existing_table.to_pandas(), new_table.to_pandas()],
ignore_index=True,
sort=False,
)
# Re-sort columns for the new unified schema.
combined_df = combined_df.reindex(sorted(combined_df.columns), axis=1)
final_table = pa.Table.from_pandas(combined_df, preserve_index=False)
self._schema = final_table.schema
# Atomically replace the old file with the new one.
temp_path = self.path.with_suffix(f"{self.path.suffix}.tmp")
pq.write_table(
final_table, temp_path, compression=self._config.parquet_compression
)
os.replace(temp_path, self.path)
# Re-initialize the writer with the new schema for subsequent writes.
self._parquet_writer = pq.ParquetWriter(
self.path, self._schema, compression=self._config.parquet_compression
)
def _normalize_row(self, row: Row) -> Row:
"""
Sanitize all values in a row for Parquet compatibility.
"""
"""Sanitize all values in a row for Parquet compatibility."""
return {key: self._normalize_value(value) for key, value in row.items()}
def _normalize_value(self, value: t.Any) -> t.Any:
"""
Convert a single value to a Parquet-friendly format.
- NumPy arrays and Torch tensors are converted to nested lists.
- Other types are passed through for pandas to handle.
"""
"""Convert a single value to a Parquet-friendly format."""
if value is None:
return None
if np and isinstance(value, np.ndarray):
return value.tolist()
if torch and isinstance(value, torch.Tensor):
return value.detach().cpu().numpy().tolist()
# return value.detach().cpu().numpy()
return value
class DataLogger:
"""
A process-safe facade for logging structured data to a Parquet file.
This class manages a singleton server process in the background and provides
a simple, unified API (`.log()`) for all processes in an application.
"""
_manager: t.ClassVar[t.Optional[multiprocessing.managers.SyncManager]] = None
_log_queue: t.ClassVar[t.Optional[multiprocessing.Queue]] = None
_server_process: t.ClassVar[t.Optional[DataLoggerServer]] = None
_flush_event: t.ClassVar[t.Optional[multiprocessing.synchronize.Event]] = None
_lock = threading.Lock() # Use a thread-lock for initializing class-level resources
@classmethod
def initialize(
cls,
path: t.Optional[t.Union[str, Path]] = None,
config: t.Optional[LoggerConfig] = None,
) -> None:
"""
Explicitly initialize and start the logging server.
This is optional. If not called, the server will be started automatically
with default settings upon the first call to `log()`.
Args:
path: Path to the log file. If None, a timestamped name is generated.
config: Configuration for the logger's behavior.
"""
with cls._lock:
if cls._server_process is not None:
print(
"WARNING: DataLogger already initialized. Ignoring subsequent call.",
flush=True,
)
return
cls._start_server(path, config)
@classmethod
def log(cls, row: Row) -> None:
"""
Log a data row from any process.
The first call to this method will automatically start the background
logging server if it hasn't been started already. The operation is
non-blocking; the data is placed in a process-safe queue.
Args:
row: A dictionary representing a single row of data.
Raises:
TypeError: If the provided row is not a dictionary.
"""
if not isinstance(row, dict):
raise TypeError(f"Expected a dict for a row, but got {type(row)}.")
with cls._lock:
if cls._server_process is None:
# Lazy initialization with default parameters
cls._start_server()
if cls._log_queue:
try:
cls._log_queue.put(row)
except Exception as e:
# Can happen if the manager process dies unexpectedly
print(f"ERROR: Failed to queue log message: {e}", flush=True)
@classmethod
def flush(cls, timeout: float = 10.0) -> None:
"""
Block until all currently queued data is written to disk.
Args:
timeout: Maximum time in seconds to wait for the flush to complete.
"""
if (
cls._server_process is None
or cls._log_queue is None
or cls._flush_event is None
):
return
cls._flush_event.clear()
cls._log_queue.put(_FLUSH_COMMAND)
cls._flush_event.wait(timeout)
@classmethod
def close(cls, timeout: float = 10.0) -> None:
"""
Flush all data and gracefully shut down the logging server.
This is automatically registered with `atexit` and usually does not
need to be called manually.
Args:
timeout: Maximum time to wait for the server process to join.
"""
with cls._lock:
if cls._server_process is None or not cls._server_process.is_alive():
return
print("INFO: Shutting down DataLogger server...", flush=True)
if cls._log_queue:
cls._log_queue.put(_SHUTDOWN_COMMAND)
cls._server_process.join(timeout)
if cls._server_process.is_alive():
print(
"WARNING: DataLogger server did not shut down cleanly. Terminating.",
flush=True,
)
cls._server_process.terminate()
if cls._manager:
cls._manager.shutdown()
cls._server_process = None
cls._log_queue = None
cls._manager = None
print("INFO: DataLogger shutdown complete.", flush=True)
@classmethod
def _start_server(
cls,
path: t.Optional[t.Union[str, Path]] = None,
config: t.Optional[LoggerConfig] = None,
) -> None:
"""Internal method to create and start the server process. Must be called within a lock."""
resolved_path = cls._resolve_path(path)
print(f"INFO: Starting DataLogger server for -> {resolved_path}", flush=True)
# A manager handles shared state between processes
cls._manager = multiprocessing.Manager()
cls._log_queue = cls._manager.Queue()
cls._flush_event = cls._manager.Event()
cls._server_process = DataLoggerServer(
log_queue=cls._log_queue,
path=resolved_path,
config=config or LoggerConfig(),
flush_event=cls._flush_event,
)
cls._server_process.start()
# Register the cleanup function to be called on program exit
atexit.register(cls.close)
@staticmethod
def _resolve_path(path: t.Optional[t.Union[str, Path]]) -> Path:
"""Determine the final output path for the log file."""
if path is None:
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
filename = f"log_{timestamp}.parquet"
return Path.cwd() / filename
resolved_path = Path(path)
if resolved_path.suffix == "":
resolved_path = resolved_path.with_suffix(".parquet")
return resolved_path