Implement OlMoE logits logging.

Update: logging items in OlMoE.

Update: some changes to the main script.

Add: logits post-processing notebook.
2025-09-29 19:52:44 +08:00
parent 6738855fb1
commit 5d1d818138
5 changed files with 99 additions and 22 deletions

.gitignore (vendored, 3 changes)

@@ -223,3 +223,6 @@ venv/
 # Checkpoints for LLMs
 llms/
+# Logging files
+logs/

logits_post_process.py (new file, 59 lines)

@@ -0,0 +1,59 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#     jupytext_version: 1.17.3
#   kernelspec:
#     display_name: venv
#     language: python
#     name: python3
# ---
# %%
import numpy as np
import pandas as pd
# import torch
# %%
log_file = "olmoe_20250929-165403.parquet"
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
# %%
df = pd.read_parquet(log_file)
df.head()
# %%
# logit = df.loc[1, "router_logits"]
# logit
# %%
main_mask = (df["src"] == "main")
token_ids = df.loc[main_mask, "token_ids"].item()
output_text = df.loc[main_mask, "output_text"].item()
# %%
lm_mask = (df["src"] == "lm_logit")
df.loc[lm_mask, "logits"] = df.loc[lm_mask, "logits"].apply(lambda arr: np.stack([a for a in arr]).flatten())
df.loc[lm_mask, "token_id"] = df.loc[lm_mask, "logits"].apply(lambda l: np.argmax(l, axis=-1))
df.head()
# %%
df.loc[lm_mask, "token_id"].to_numpy()
# %%
token_ids
# %%
import transformers
from transformers import AutoTokenizer, GPTNeoXTokenizerFast
tokenizer: GPTNeoXTokenizerFast = AutoTokenizer.from_pretrained(model_id)
# %%
tokenizer.decode(token_ids)
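# %%
# Sanity-check sketch (an addition, not from the original notebook): with
# top_k=1 the sampler is greedy, so the argmax of each logged lm-head logit
# row should reproduce the sampled ids stored in the "main" row. Uses only
# names defined in the cells above.
argmax_ids = df.loc[lm_mask, "token_id"].to_numpy()
sampled_ids = np.asarray(token_ids)
n_steps = min(len(argmax_ids), len(sampled_ids))
print("steps compared:", n_steps)
print("greedy match rate:", (argmax_ids[:n_steps] == sampled_ids[:n_steps]).mean())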

@@ -102,14 +102,18 @@ class OlmoeMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        dlog.log({
-            "_time": datetime.now(),
-            "router_logits": router_logits.cpu().float().numpy(),
-            "layer": self.layer_idx,
-        })
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
+        dlog.log({
+            "_time": datetime.now(),
+            "src": "router",
+            "layer": self.layer_idx,
+            "router_logits": router_logits.cpu().float().numpy(),
+            "orig_shape": list(orig_shape),
+            "hidden_dim": hidden_dim,
+            "hidden_states_shape": list(hidden_states.shape),
+        })
         return final_hidden_states.view(orig_shape)

     def add_logging_metrics(self, layer_idx: int):
@@ -501,6 +505,11 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
+        dlog.log({
+            "_time": datetime.now(),
+            "src": "lm_logit",
+            "logits": logits.cpu().float().numpy(),
+        })
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
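The logged "router" rows can be post-processed into per-token expert assignments offline. Below is a minimal numpy sketch, assuming each logged router_logits array has shape (num_tokens, n_experts) as the comment in the hunk above states, and assuming top-8 routing (OLMoE's default; verify against the model config before relying on it):

import numpy as np

def expert_assignments(router_logits: np.ndarray, top_k: int = 8):
    # Softmax over the expert dimension gives routing probabilities.
    e = np.exp(router_logits - router_logits.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    # Indices and weights of the top_k experts per token (unordered).
    top_idx = np.argpartition(probs, -top_k, axis=-1)[..., -top_k:]
    top_w = np.take_along_axis(probs, top_idx, axis=-1)
    return top_idx, top_w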

@@ -11,14 +11,12 @@ from vllm.distributed.parallel_state import destroy_model_parallel
 from models.register import register_vllm_logit_logging_models
 from utils import DataLogger as dlog
 # %%
-# dlog.get_instance(path=f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
-# %%
 model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
 try:
-    log_file = Path(f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
+    log_file = Path(f"logs/olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
     if log_file.exists():
         log_file.unlink()
@@ -32,6 +30,7 @@ try:
         # tensor_parallel_size=2,
         gpu_memory_utilization=0.95,
         max_model_len=4096,
+        max_num_seqs=1,
         # compilation_config=CompilationConfig(
         #     level=CompilationLevel.PIECEWISE,
         #     # By default, it goes up to max_num_seqs
@@ -42,21 +41,21 @@ try:
     sampling_params = SamplingParams(
         temperature=0.6,
-        top_p=0.95,
-        top_k=20,
+        # top_p=0.95,
+        # top_k=20,
+        top_p=1.0,
+        top_k=1,
         max_tokens=1024,
     )
     # Prepare the input to the model
-    prompt = "Give me a very short introduction to large language models."
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": prompt},
+        {
+            "role": "user",
+            "content": "What is the name of the tallest building in Paris? Output the final answer ONLY:",
+        },
     ]
-    # messages = [
-    #     {"role": "system", "content": "You are an AI assistant."},
-    #     {"role": "user", "content": "Please briefly introduce what large language models are."},
-    # ]
     # Generate outputs
     outputs = llm.chat(
@@ -73,12 +72,16 @@ try:
         # print("=== COMPLETION ===")
         print(out.outputs[0].text)
         print("\n---\n")
-        dlog.log({
-            "_time": datetime.now(),
-            "output_text": out.outputs[0].text
-        })
+        dlog.log(
+            {
+                "_time": datetime.now(),
+                "src": "main",
+                "output_text": out.outputs[0].text,
+                "token_ids": out.outputs[0].token_ids,
+            }
+        )
-    print("Finish completion")
+    print("\n---\nFinish completion")
 except Exception as e:
     print(e)
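For reference, utils.DataLogger is this repo's own helper and its implementation is not part of this commit; the sketch below is only a guess at its shape, inferred from the calls visible here (get_instance(path=...), a class-level log(dict), and a parquet file the notebook reads back with pandas):

from pathlib import Path

import pandas as pd

class DataLogger:
    # Hypothetical reconstruction of utils.DataLogger based on usage in this
    # commit; the real class in utils.py may differ.
    _instance = None

    @classmethod
    def get_instance(cls, path: str) -> "DataLogger":
        if cls._instance is None:
            cls._instance = cls(Path(path))
        return cls._instance

    def __init__(self, path: Path):
        self.path = path
        self.rows: list[dict] = []

    @classmethod
    def log(cls, row: dict) -> None:
        # Rows from different call sites carry different keys ("src", "layer",
        # "logits", ...); pandas fills the gaps with nulls when writing.
        cls._instance.rows.append(row)

    def flush(self) -> None:
        pd.DataFrame(self.rows).to_parquet(self.path)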