Implement OLMoE logits logging.
Update: logging items in OLMoE. Update: minor changes to the main script. Add: logits post-processing notebook.
.gitignore (vendored, +3)
@@ -223,3 +223,6 @@ venv/
 
 # Checkpoints for LLMs
 llms/
+
+# Logging files
+logs/
logits_post_process.py (new file, 59 lines)
@@ -0,0 +1,59 @@
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#   jupytext_version: 1.17.3
#   kernelspec:
#     display_name: venv
#     language: python
#     name: python3
# ---

# %%
import numpy as np
import pandas as pd
# import torch

# %%
log_file = "olmoe_20250929-165403.parquet"
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"

# %%
df = pd.read_parquet(log_file)
df.head()

# %%
# logit = df.loc[1, "router_logits"]
# logit

# %%
main_mask = (df["src"] == "main")

token_ids = df.loc[main_mask, "token_ids"].item()
output_text = df.loc[main_mask, "output_text"].item()

# %%
lm_mask = (df["src"] == "lm_logit")

df.loc[lm_mask, "logits"] = df.loc[lm_mask, "logits"].apply(lambda arr: np.stack([a for a in arr]).flatten())
df.loc[lm_mask, "token_id"] = df.loc[lm_mask, "logits"].apply(lambda l: np.argmax(l, axis=-1))
df.head()

# %%
df.loc[lm_mask, "token_id"].to_numpy()

# %%
token_ids

# %%
import transformers
from transformers import AutoTokenizer, GPTNeoXTokenizerFast

tokenizer: GPTNeoXTokenizerFast = AutoTokenizer.from_pretrained(model_id)

# %%
tokenizer.decode(token_ids)
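The notebook ends by decoding the sampled ids; the natural follow-up is checking the logged per-step argmax against the ids the engine actually emitted. A minimal sketch of that check, assuming the parquet schema above and a greedy run (top_k=1), not part of the commit:

# Sketch: with greedy decoding, the argmax of each logged "lm_logit" row should
# reproduce the "main" row's token_ids. Assumes the schema written by this commit.
import numpy as np
import pandas as pd

df = pd.read_parquet("olmoe_20250929-165403.parquet")

sampled = np.asarray(df.loc[df["src"] == "main", "token_ids"].item())
argmax_ids = np.array([
    int(np.argmax(np.stack(list(step)).flatten()))
    for step in df.loc[df["src"] == "lm_logit", "logits"]
])

n = min(len(sampled), len(argmax_ids))
print(f"{(sampled[:n] == argmax_ids[:n]).sum()}/{n} steps match")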
@@ -102,14 +102,18 @@ class OlmoeMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)

-        dlog.log({
-            "_time": datetime.now(),
-            "router_logits": router_logits.cpu().float().numpy(),
-            "layer": self.layer_idx,
-        })
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)

+        dlog.log({
+            "_time": datetime.now(),
+            "src": "router",
+            "layer": self.layer_idx,
+            "router_logits": router_logits.cpu().float().numpy(),
+            "orig_shape": list(orig_shape),
+            "hidden_dim": hidden_dim,
+            "hidden_states_shape": list(hidden_states.shape),
+        })
         return final_hidden_states.view(orig_shape)

     def add_logging_metrics(self, layer_idx: int):
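For comparison, the same per-layer router logits could be captured without patching the model source at all, at least for the HF transformers checkpoint rather than the vLLM model class patched here. A hook-based sketch; the model.model.layers[i].mlp.gate module path is an assumption about the HF OLMoE layout and may differ between versions:

# Sketch: capture (num_tokens, n_experts) router logits with forward hooks on
# each layer's gate module, instead of editing the model code as this diff does.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("./llms/OLMoE-1B-7B-0924-Instruct")
captured = []

def make_hook(layer_idx):
    def hook(module, args, output):
        # The gate is a linear layer whose output is the router logits.
        captured.append({"layer": layer_idx,
                         "router_logits": output.detach().float().cpu()})
    return hook

for idx, layer in enumerate(model.model.layers):
    layer.mlp.gate.register_forward_hook(make_hook(idx))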
@@ -501,6 +505,11 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
+        dlog.log({
+            "_time": datetime.now(),
+            "src": "lm_logit",
+            "logits": logits.cpu().float().numpy(),
+        })
         return logits

     def load_weights(self, weights: Iterable[tuple[str,
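This logs the full (num_tokens, vocab_size) tensor on every step, which is what lets the notebook take a plain argmax. If the parquet ever grows too large, a slimmer variant could persist only the top-k entries per step; a hypothetical sketch, not part of this commit (log_topk_logits and the k parameter are made up for illustration):

# Hypothetical slimmer variant: keep only the top-k logits per step. The greedy
# token is then topk_indices[..., 0] instead of an argmax over the full vocab.
from datetime import datetime
import torch

def log_topk_logits(dlog, logits: torch.Tensor, k: int = 32) -> None:
    values, indices = torch.topk(logits, k, dim=-1)
    dlog.log({
        "_time": datetime.now(),
        "src": "lm_logit",
        "topk_values": values.cpu().float().numpy(),
        "topk_indices": indices.cpu().numpy(),
    })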
@@ -11,14 +11,12 @@ from vllm.distributed.parallel_state import destroy_model_parallel
 from models.register import register_vllm_logit_logging_models
 from utils import DataLogger as dlog

-# %%
-# dlog.get_instance(path=f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
-
 # %%
 model_id = "./llms/OLMoE-1B-7B-0924-Instruct"

 try:
-    log_file = Path(f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
+    log_file = Path(f"logs/olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
     if log_file.exists():
         log_file.unlink()
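Moving the parquet under logs/ assumes that directory already exists (it is gitignored by the first hunk, so a fresh checkout won't have it). A one-line guard, hypothetical and not part of the commit, would make the script safe either way:

# Hypothetical guard, not in this commit: create logs/ before the first write.
from pathlib import Path

Path("logs").mkdir(parents=True, exist_ok=True)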
@@ -32,6 +30,7 @@ try:
         # tensor_parallel_size=2,
         gpu_memory_utilization=0.95,
         max_model_len=4096,
+        max_num_seqs=1,
         # compilation_config=CompilationConfig(
         #     level=CompilationLevel.PIECEWISE,
         #     # By default, it goes up to max_num_seqs
@@ -42,21 +41,21 @@ try:

     sampling_params = SamplingParams(
         temperature=0.6,
-        top_p=0.95,
-        top_k=20,
+        # top_p=0.95,
+        # top_k=20,
+        top_p=1.0,
+        top_k=1,
         max_tokens=1024,
     )

     # Prepare the input to the model
-    prompt = "Give me a very short introduction to large language models."
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": prompt},
+        {
+            "role": "user",
+            "content": "What is the name of the tallest building in Paris? Output the final answer ONLY:",
+        },
     ]
-    # messages = [
-    #     {"role": "system", "content": "你是一位人工智能助手。"},  # "You are an AI assistant."
-    #     {"role": "user", "content": "请简要地介绍什么是大语言模型。"},  # "Briefly introduce large language models."
-    # ]

     # Generate outputs
     outputs = llm.chat(
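Switching to top_p=1.0, top_k=1 makes the run effectively greedy: the candidate pool is a single token, so temperature no longer matters and the output is deterministic, which is what the notebook's argmax comparison relies on. A toy numpy illustration of that degenerate case, for reference only:

# Toy illustration: with top_k=1 the candidate pool is just the argmax token,
# so "sampling" always returns the greedy choice regardless of temperature.
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=8)

top1_pool = np.argsort(logits)[-1:]  # top_k=1 candidate set
assert rng.choice(top1_pool) == np.argmax(logits)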
@@ -73,12 +72,16 @@ try:
         # print("=== COMPLETION ===")
         print(out.outputs[0].text)
         print("\n---\n")
-        dlog.log({
-            "_time": datetime.now(),
-            "output_text": out.outputs[0].text
-        })
+        dlog.log(
+            {
+                "_time": datetime.now(),
+                "src": "main",
+                "output_text": out.outputs[0].text,
+                "token_ids": out.outputs[0].token_ids,
+            }
+        )

-    print("Finish completion")
+    print("\n---\nFinish completion")

 except Exception as e:
     print(e)
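The DataLogger imported as dlog throughout lives in the repo's utils module and is not shown in this diff. A minimal sketch of the interface the call sites above assume, namely a get_instance(path=...) constructor and a dict-per-row log(...); the parquet flush is an assumed detail that the commit does not show:

# Minimal sketch of the interface the dlog calls above assume; the real
# implementation is in the repo's utils module and is not part of this diff.
from pathlib import Path
import pandas as pd

class DataLogger:
    _path: Path | None = None
    _rows: list[dict] = []

    @classmethod
    def get_instance(cls, path: str) -> type["DataLogger"]:
        cls._path = Path(path)
        return cls

    @classmethod
    def log(cls, row: dict) -> None:
        cls._rows.append(row)

    @classmethod
    def flush(cls) -> None:
        # Assumed behavior: rows accumulate in memory and are written once.
        pd.DataFrame(cls._rows).to_parquet(cls._path)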