Basic script for OLMoE inference.
Add: basic script for OLMoE inference and saving expert logits. Add: OLMoE vLLM module with logits saving. Add: vLLM model registering. Fix: update DataLogger module for multiprocessing.
This commit is contained in:
@ -17,6 +17,7 @@
|
||||
This file originates from vllm/model_executor/models/olmoe.py
|
||||
"""
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@ -51,7 +52,7 @@ from vllm.model_executor.models.utils import (AutoWeightsLoader, is_pp_missing_p
|
||||
make_empty_intermediate_tensors_factory, make_layers,
|
||||
maybe_prefix)
|
||||
|
||||
from utils import DataLogger
|
||||
from utils import DataLogger as dlog
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -100,10 +101,20 @@ class OlmoeMoE(nn.Module):
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
dlog.log({
|
||||
"_time": datetime.now(),
|
||||
"router_logits": router_logits.cpu().float().numpy(),
|
||||
"layer": self.layer_idx,
|
||||
})
|
||||
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
def add_logging_metrics(self, layer_idx: int):
    """Remember which decoder layer owns this MoE module.

    The stored index is attached as the ``"layer"`` field of the
    router-logit records emitted via the data logger during ``forward``.
    """
    self.layer_idx = layer_idx
|
||||
|
||||
|
||||
class OlmoeAttention(nn.Module):
|
||||
|
||||
@ -234,6 +245,16 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
|
||||
# Extract layer_idx from prefix
|
||||
self.layer_idx = None
|
||||
try:
|
||||
# Prefix format: "model.layers.7"
|
||||
parts = prefix.split('.')
|
||||
if len(parts) >= 3 and parts[2].isdigit():
|
||||
self.layer_idx = int(parts[2])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
self.mlp = OlmoeMoE(
|
||||
num_experts=config.num_experts,
|
||||
top_k=config.num_experts_per_tok,
|
||||
@ -245,6 +266,8 @@ class OlmoeDecoderLayer(nn.Module):
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
|
||||
|
||||
self.mlp.add_logging_metrics(layer_idx=self.layer_idx)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
Reference in New Issue
Block a user