Basic script for OLMoE inference.

Add: basic script for OLMoE inference and saving expert logits.

Add: OLMoE vLLM module with logits saving.

Add: vLLM model registration.

Fix: update DataLogger module for multiprocessing.
2025-09-28 11:41:30 +08:00
parent 5942bbfd05
commit 6738855fb1
4 changed files with 470 additions and 374 deletions
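Only olmoe_log_expert_vllm.py is reproduced below; the DataLogger changes live in the other files of this commit. For context, here is a minimal, hypothetical sketch of the DataLogger interface the script relies on (initialize and log), under the assumption that the multiprocessing fix means each process buffers its own records and writes a separate parquet part file. The real utils.DataLogger may be implemented differently (e.g. with a shared queue).

# Hypothetical sketch only -- not the repository's utils.DataLogger.
import atexit
import os
from pathlib import Path

import pandas as pd


class DataLogger:
    _path = None   # base parquet path passed to initialize()
    _rows = []     # records buffered in the current process

    @classmethod
    def initialize(cls, path):
        cls._path = Path(path)
        # Flush whatever this process has buffered when it exits.
        atexit.register(cls._flush)

    @classmethod
    def log(cls, record):
        # No state is shared across processes, so vLLM workers spawned in
        # separate processes each keep (and later flush) their own buffer.
        cls._rows.append(record)

    @classmethod
    def _flush(cls):
        if cls._path is None or not cls._rows:
            return
        # Suffix the file with the PID so concurrent processes never write
        # to the same parquet file.
        part = cls._path.with_name(f"{cls._path.stem}_pid{os.getpid()}{cls._path.suffix}")
        pd.DataFrame(cls._rows).to_parquet(part)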

olmoe_log_expert_vllm.py Normal file

@@ -0,0 +1,94 @@
# %%
import gc
from datetime import datetime
from pathlib import Path
import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed.parallel_state import destroy_model_parallel
from models.register import register_vllm_logit_logging_models
from utils import DataLogger as dlog
# %%
# dlog.get_instance(path=f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
# %%
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
try:
    # Create a fresh parquet log file for this run; the logger is initialized
    # before the LLM is built so the patched model can record expert logits.
    log_file = Path(f"olmoe_{datetime.now().strftime('%Y%m%d-%H%M%S')}.parquet")
    if log_file.exists():
        log_file.unlink()
    dlog.initialize(path=log_file)

    # Register the logit-logging OLMoE implementation with vLLM before the engine is created.
    register_vllm_logit_logging_models()

    llm = LLM(
        model=model_id,
        cpu_offload_gb=4,
        # tensor_parallel_size=2,
        gpu_memory_utilization=0.95,
        max_model_len=4096,
        # compilation_config=CompilationConfig(
        #     level=CompilationLevel.PIECEWISE,
        #     # By default, it goes up to max_num_seqs
        #     cudagraph_capture_sizes=[1, 2, 4, 8, 16],
        # ),
        enforce_eager=True,
    )

    sampling_params = SamplingParams(
        temperature=0.6,
        top_p=0.95,
        top_k=20,
        max_tokens=1024,
    )

    # Prepare the input to the model
    prompt = "Give me a very short introduction to large language models."
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    # messages = [
    #     {"role": "system", "content": "You are an AI assistant."},
    #     {"role": "user", "content": "Please give a brief introduction to large language models."},
    # ]

    # Generate outputs
    outputs = llm.chat(
        messages,
        sampling_params=sampling_params,
        # chat_template_kwargs={"enable_thinking": True},  # Set to False to strictly disable thinking
    )

    # Print the outputs and log each completion alongside a timestamp.
    for out in outputs:
        # out.prompt is the input prompt; out.outputs is a list of completion choices
        # print("=== PROMPT ===")
        # print(out.prompt)
        # print("=== COMPLETION ===")
        print(out.outputs[0].text)
        print("\n---\n")
        dlog.log({
            "_time": datetime.now(),
            "output_text": out.outputs[0].text,
        })
    print("Finished completion")
except Exception as e:
    print(e)
finally:
    # Tear down the engine and release GPU memory even if generation failed.
    if llm := globals().get("llm", None):
        if engine := getattr(llm, "llm_engine", None):
            del engine
        del llm
    destroy_model_parallel()
    torch.cuda.empty_cache()
    gc.collect()
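The other three changed files (the OLMoE vLLM module, the registration helper, and the DataLogger fix) are not shown in this hunk. As a rough sketch, registration presumably goes through vLLM's ModelRegistry; only ModelRegistry.register_model is a known vLLM API here, while the models.olmoe_vllm module path and the class name OlmoeForCausalLMWithExpertLogits are placeholders, not the commit's actual identifiers.

# Hypothetical sketch of what register_vllm_logit_logging_models() might do.
from vllm import ModelRegistry


def register_vllm_logit_logging_models():
    # Import lazily so that importing this module stays cheap until vLLM
    # actually needs the model class.
    from models.olmoe_vllm import OlmoeForCausalLMWithExpertLogits  # placeholder name

    # Map the architecture name from the checkpoint's config.json
    # ("OlmoeForCausalLM") to the logging implementation so vLLM builds
    # the patched model instead of its built-in OLMoE class.
    ModelRegistry.register_model("OlmoeForCausalLM", OlmoeForCausalLMWithExpertLogits)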