Implement OlMoE logits logging.
Update: logging items in OlMoE.
Update: some changes to the main script.
Add: logits post-processing notebook.
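Both the model patches and the main script below call `dlog.log(...)` on a `DataLogger` imported from `utils`, whose implementation is not part of this commit. As a point of reference only, a minimal sketch of the pattern the calls imply — a process-wide logger that buffers dict records and writes them to a Parquet file; `log` and `get_instance` appear in the diff, while `flush` and everything else here is an assumption:

# Hypothetical sketch only; not the actual utils.DataLogger in this repo.
from pathlib import Path
import pandas as pd

class DataLogger:
    _records = []   # buffered rows, one dict per dlog.log(...) call
    _path = None

    @classmethod
    def get_instance(cls, path):
        cls._path = Path(path)   # where flush() will write the Parquet file
        return cls

    @classmethod
    def log(cls, record):
        cls._records.append(record)

    @classmethod
    def flush(cls):
        # Arrays/lists in the records become nested Parquet columns,
        # which is what logits_post_process.py reads back with pandas.
        pd.DataFrame(cls._records).to_parquet(cls._path)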
.gitignore
@@ -223,3 +223,6 @@ venv/
 
 # Checkpoints for LLMs
 llms/
+
+# Logging files
+logs/
logits_post_process.py (new file)
@@ -0,0 +1,59 @@
+# ---
+# jupyter:
+#   jupytext:
+#     formats: ipynb,py:percent
+#     text_representation:
+#       extension: .py
+#       format_name: percent
+#       format_version: '1.3'
+#       jupytext_version: 1.17.3
+#   kernelspec:
+#     display_name: venv
+#     language: python
+#     name: python3
+# ---
+
+# %%
+import numpy as np
+import pandas as pd
+# import torch
+
+# %%
+log_file = "olmoe_20250929-165403.parquet"
+model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
+
+# %%
+df = pd.read_parquet(log_file)
+df.head()
+
+# %%
+# logit = df.loc[1, "router_logits"]
+# logit
+
+# %%
+main_mask = (df["src"] == "main")
+
+token_ids = df.loc[main_mask, "token_ids"].item()
+output_text = df.loc[main_mask, "output_text"].item()
+
+# %%
+lm_mask = (df["src"] == "lm_logit")
+
+df.loc[lm_mask, "logits"] = df.loc[lm_mask, "logits"].apply(lambda arr: np.stack([a for a in arr]).flatten())
+df.loc[lm_mask, "token_id"] = df.loc[lm_mask, "logits"].apply(lambda l: np.argmax(l, axis=-1))
+df.head()
+
+# %%
+df.loc[lm_mask, "token_id"].to_numpy()
+
+# %%
+token_ids
+
+# %%
+import transformers
+from transformers import AutoTokenizer, GPTNeoXTokenizerFast
+
+tokenizer: GPTNeoXTokenizerFast = AutoTokenizer.from_pretrained(model_id)
+
+# %%
+tokenizer.decode(token_ids)
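Because the main script (further down) switches to `top_k=1`, decoding is effectively greedy, so the argmax over each logged lm-head logit vector should reproduce the sampled token ids. A hedged check that could follow the cells above — it assumes one logged `lm_logit` row per generated token, which is worth verifying against the log:

# %%
# Greedy decoding (top_k=1) means argmax(logits) should equal the sampled id.
recovered = df.loc[lm_mask, "token_id"].to_numpy()
generated = np.asarray(token_ids)
n = min(len(recovered), len(generated))
match_rate = (recovered[:n] == generated[:n]).mean()
print(f"argmax/token_id agreement over {n} steps: {match_rate:.3f}")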
@@ -102,14 +102,18 @@ class OlmoeMoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
 
-        dlog.log({
-            "_time": datetime.now(),
-            "router_logits": router_logits.cpu().float().numpy(),
-            "layer": self.layer_idx,
-        })
-
         final_hidden_states = self.experts(hidden_states=hidden_states,
                                            router_logits=router_logits)
+
+        dlog.log({
+            "_time": datetime.now(),
+            "src": "router",
+            "layer": self.layer_idx,
+            "router_logits": router_logits.cpu().float().numpy(),
+            "orig_shape": list(orig_shape),
+            "hidden_dim": hidden_dim,
+            "hidden_states_shape": list(hidden_states.shape),
+        })
         return final_hidden_states.view(orig_shape)
 
     def add_logging_metrics(self, layer_idx: int):
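The `"router"` rows log raw gating scores of shape `(num_tokens, n_experts)`. To turn one record into routing decisions, softmax over the expert axis and take the top-k. A hedged numpy sketch — `k=8` matches OLMoE's published top-8-of-64 routing, but treat it as an assumption and check the model config:

import numpy as np

def top_experts(router_logits: np.ndarray, k: int = 8):
    """Return (expert ids, routing probs) of the k best experts per token.

    router_logits: (num_tokens, n_experts) array from one "router" record.
    """
    z = router_logits - router_logits.max(axis=-1, keepdims=True)  # stability
    probs = np.exp(z)
    probs /= probs.sum(axis=-1, keepdims=True)
    idx = np.argsort(-probs, axis=-1)[:, :k]           # best experts first
    return idx, np.take_along_axis(probs, idx, axis=-1)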
@@ -501,6 +505,11 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
                        sampling_metadata: SamplingMetadata) -> torch.Tensor:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
+        dlog.log({
+            "_time": datetime.now(),
+            "src": "lm_logit",
+            "logits": logits.cpu().float().numpy(),
+        })
         return logits
 
     def load_weights(self, weights: Iterable[tuple[str,
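Logging the full logits tensor stores one float per vocabulary entry (tens of thousands) per decode step, which grows quickly. If file size becomes a problem, a hypothetical variant of the call above could keep only the top-k entries; the helper and field names here are illustrative, not part of this commit:

import torch

def topk_logit_record(logits: torch.Tensor, k: int = 32) -> dict:
    # Keep only the k largest logits and their token ids per position.
    vals, ids = torch.topk(logits.float(), k, dim=-1)
    return {
        "topk_token_ids": ids.cpu().numpy(),
        "topk_logits": vals.cpu().numpy(),
    }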
@@ -11,14 +11,12 @@ from vllm.distributed.parallel_state import destroy_model_parallel
 from models.register import register_vllm_logit_logging_models
 from utils import DataLogger as dlog
 
-# %%
-# dlog.get_instance(path=f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
 
 # %%
 model_id = "./llms/OLMoE-1B-7B-0924-Instruct"
 
 try:
-    log_file = Path(f"olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
+    log_file = Path(f"logs/olmoe_{datetime.now().strftime("%Y%m%d-%H%M%S")}.parquet")
     if log_file.exists():
         log_file.unlink()
 
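One portability caveat in the hunk above: the f-string reuses double quotes inside its replacement field, which Python only accepts from 3.12 on (PEP 701). An equivalent spelling that also runs on older interpreters:

from datetime import datetime
from pathlib import Path

# Single quotes inside the replacement field work on Python < 3.12 too.
log_file = Path(f"logs/olmoe_{datetime.now().strftime('%Y%m%d-%H%M%S')}.parquet")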
@@ -32,6 +30,7 @@ try:
         # tensor_parallel_size=2,
         gpu_memory_utilization=0.95,
         max_model_len=4096,
+        max_num_seqs=1,
         # compilation_config=CompilationConfig(
         #     level=CompilationLevel.PIECEWISE,
         #     # By default, it goes up to max_num_seqs
@@ -42,21 +41,21 @@ try:
 
     sampling_params = SamplingParams(
         temperature=0.6,
-        top_p=0.95,
-        top_k=20,
+        # top_p=0.95,
+        # top_k=20,
+        top_p=1.0,
+        top_k=1,
         max_tokens=1024,
     )
 
     # Prepare the input to the model
-    prompt = "Give me a very short introduction to large language models."
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
-        {"role": "user", "content": prompt},
+        {
+            "role": "user",
+            "content": "What is the name of the tallest building in Paris? Output the final answer ONLY:",
+        },
     ]
-    # messages = [
-    #     {"role": "system", "content": "You are an AI assistant."},
-    #     {"role": "user", "content": "Please give a brief introduction to what a large language model is."},
-    # ]
 
     # Generate outputs
     outputs = llm.chat(
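Switching to `top_p=1.0, top_k=1` makes sampling deterministic: keeping only the single most-likely token ignores temperature entirely, since temperature scaling is monotonic and never changes the argmax. That is what lets the notebook recover the generation by argmax over the logged logits. A small numpy illustration of why (the values are made up):

import numpy as np

logits = np.array([2.0, 0.5, -1.0])
probs = np.exp(logits / 0.6)
probs /= probs.sum()                   # temperature=0.6 reshapes the distribution
top1 = int(np.argsort(-probs)[0])      # top_k=1 keeps only the most likely token
assert top1 == int(np.argmax(logits))  # ...which is always the raw argmax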
@@ -73,12 +72,16 @@ try:
         # print("=== COMPLETION ===")
         print(out.outputs[0].text)
         print("\n---\n")
-        dlog.log({
-            "_time": datetime.now(),
-            "output_text": out.outputs[0].text
-        })
+        dlog.log(
+            {
+                "_time": datetime.now(),
+                "src": "main",
+                "output_text": out.outputs[0].text,
+                "token_ids": out.outputs[0].token_ids,
+            }
+        )
 
-    print("Finish completion")
+    print("\n---\nFinish completion")
 
 except Exception as e:
     print(e)