Basic script for OLMoE inference.

Add: basic script for OLMoE inference and saving expert logits.

Add: OLMoE vLLM module with logits saving.

Add: vLLM model registering.

Fix: update DataLogger module for multiprocessing.
This commit is contained in:
2025-09-28 11:41:30 +08:00
parent 5942bbfd05
commit 6738855fb1
4 changed files with 470 additions and 374 deletions

View File

@ -17,6 +17,7 @@
This file originates from vllm/model_executor/models/olmoe.py
"""
from collections.abc import Iterable
from datetime import datetime
from functools import partial
from typing import Any, Optional, Union
@ -51,7 +52,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, is_pp_missing_p
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from utils import DataLogger
from utils import DataLogger as dlog
logger = init_logger(__name__)
@ -100,10 +101,20 @@ class OlmoeMoE(nn.Module):
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
dlog.log({
"_time": datetime.now(),
"router_logits": router_logits.cpu().float().numpy(),
"layer": self.layer_idx,
})
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
return final_hidden_states.view(orig_shape)
def add_logging_metrics(self, layer_idx: int):
    """Attach the decoder-layer index used to tag logged router logits.

    Args:
        layer_idx: zero-based index of the enclosing decoder layer
            (extracted from the module prefix by the caller).
    """
    # Stored so the forward pass can include "layer" in each log record.
    setattr(self, "layer_idx", layer_idx)
class OlmoeAttention(nn.Module):
@ -234,6 +245,16 @@ class OlmoeDecoderLayer(nn.Module):
prefix=f"{prefix}.self_attn",
)
# Extract layer_idx from prefix
self.layer_idx = None
try:
# Prefix format: "model.layers.7"
parts = prefix.split('.')
if len(parts) >= 3 and parts[2].isdigit():
self.layer_idx = int(parts[2])
except ValueError:
pass
self.mlp = OlmoeMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
@ -245,6 +266,8 @@ class OlmoeDecoderLayer(nn.Module):
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
self.mlp.add_logging_metrics(layer_idx=self.layer_idx)
def forward(
self,
positions: torch.Tensor,

31
models/register.py Normal file
View File

@ -0,0 +1,31 @@
from vllm import ModelRegistry
def register_vllm_logit_logging_models():
    """Register expert-logit-logging model variants with vLLM's ModelRegistry.

    Maps each HuggingFace architecture name to the "module.path:ClassName"
    of a local patched implementation that saves router/expert logits
    during inference. Call this before constructing the vLLM engine so
    the patched class is picked up instead of the stock one.
    """
    # Architecture name -> import path of the logging variant.
    # Add further MoE architectures here (Qwen3MoE, Mixtral, DeepseekV2,
    # Llama4, GPT-OSS, PhiMoE, ...) once a logging implementation exists
    # for them under models/ (previous placeholders removed as dead code).
    models_mapper = {
        "OlmoeForCausalLM": "models.log_expert.olmoe:OlmoeForCausalLM",
    }
    for model_name, model_path in models_mapper.items():
        ModelRegistry.register_model(model_name, model_path)