Update: logging items in OLMoE. Update: some changes to the main script. Add: logits post-processing notebook.
# %%
import gc
from datetime import datetime
from pathlib import Path

import torch
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.distributed.parallel_state import destroy_model_parallel

from models.register import register_vllm_logit_logging_models
from utils import DataLogger as dlog
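
# `models.register` and `utils` are project-local modules. Reading from the
# names (an assumption, not verified here): register_vllm_logit_logging_models
# registers OLMoE variants whose forward passes log logits through DataLogger,
# and DataLogger buffers structured records and flushes them to parquet.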

# %%
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
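# Presumably a local snapshot of the allenai/OLMoE-1B-7B-0924-Instruct repo
# from Hugging Face (an assumption based on the directory name).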

try:
    log_file = Path(f"logs/olmoe_{datetime.now().strftime('%Y%m%d-%H%M%S')}.parquet")
    log_file.parent.mkdir(parents=True, exist_ok=True)  # make sure logs/ exists
    if log_file.exists():
        log_file.unlink()

    dlog.initialize(path=log_file)
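
    # Swap vLLM's OLMoE implementation for the logging variant before the LLM
    # is constructed, so the engine instantiates the patched model class
    # (assumed from the registration pattern; see models/register.py).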
    register_vllm_logit_logging_models()

    llm = LLM(
        model=model_id,
        cpu_offload_gb=4,
        # tensor_parallel_size=2,
        gpu_memory_utilization=0.95,
        max_model_len=4096,
        max_num_seqs=1,
        # compilation_config=CompilationConfig(
        #     level=CompilationLevel.PIECEWISE,
        #     # By default, it goes up to max_num_seqs
        #     cudagraph_capture_sizes=[1, 2, 4, 8, 16],
        # ),
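        # enforce_eager=True disables CUDA graph capture, so the Python-level
        # logging hooks run on every decoding step (assumed to be the reason
        # eager mode is forced here).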
        enforce_eager=True,
    )

    sampling_params = SamplingParams(
        temperature=0.6,
        # top_p=0.95,
        # top_k=20,
        top_p=1.0,
        top_k=1,
        max_tokens=1024,
    )
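    # With top_k=1 only the argmax token survives each step, so decoding is
    # effectively greedy and the temperature above does not affect the output.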

    # Prepare the input to the model
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": "What is the name of the tallest building in Paris? Output the final answer ONLY:",
        },
    ]

    # Generate outputs
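    # llm.chat (unlike llm.generate) renders `messages` through the model's
    # chat template before tokenizing.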
    outputs = llm.chat(
        messages,
        sampling_params=sampling_params,
        # chat_template_kwargs={"enable_thinking": True},  # Set to False to strictly disable thinking
    )

    # Print the outputs.
    for out in outputs:
        # out.prompt is the input prompt; out.outputs is a list of completion choices
        # print("=== PROMPT ===")
        # print(out.prompt)
        # print("=== COMPLETION ===")
        print(out.outputs[0].text)
        print("\n---\n")
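        # Record the final text and token ids; the token ids make it possible
        # to align these rows with per-step logit records during
        # post-processing (assumed intent).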
        dlog.log(
            {
                "_time": datetime.now(),
                "src": "main",
                "output_text": out.outputs[0].text,
                "token_ids": out.outputs[0].token_ids,
            }
        )

    print("\n---\nFinish completion")

except Exception as e:
    print(e)

finally:
    # Tear down the engine and release GPU memory even if generation failed.
    if llm := globals().get("llm", None):
        if hasattr(llm, "llm_engine"):
            # Delete the attribute itself; deleting a local alias bound via
            # getattr would only drop an extra reference and keep the engine
            # (and its GPU allocations) alive.
            del llm.llm_engine
        del llm
    destroy_model_parallel()
    torch.cuda.empty_cache()
    gc.collect()
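
# %%
# Minimal post-processing sketch, assuming DataLogger has flushed a flat
# parquet table that pandas can read; the actual columns depend on what the
# logging model variants record in addition to the fields logged above.
import pandas as pd

if log_file.exists():
    df = pd.read_parquet(log_file)
    print(df.columns.tolist())
    print(df.head())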