Implement OLMoE logits logging.

Update: logging items in OLMoE.

Update: some changes for main script.

Add: logits post-processing notebook.
This commit is contained in:
2025-09-29 19:52:44 +08:00
parent 6738855fb1
commit 5d1d818138
5 changed files with 99 additions and 22 deletions

View File

@ -102,14 +102,18 @@ class OlmoeMoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
dlog.log({
"_time": datetime.now(),
"router_logits": router_logits.cpu().float().numpy(),
"layer": self.layer_idx,
})
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
dlog.log({
"_time": datetime.now(),
"src": "router",
"layer": self.layer_idx,
"router_logits": router_logits.cpu().float().numpy(),
"orig_shape": list(orig_shape),
"hidden_dim": hidden_dim,
"hidden_states_shape": list(hidden_states.shape),
})
return final_hidden_states.view(orig_shape)
def add_logging_metrics(self, layer_idx: int):
@ -501,6 +505,11 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
dlog.log({
"_time": datetime.now(),
"src": "lm_logit",
"logits": logits.cpu().float().numpy(),
})
return logits
def load_weights(self, weights: Iterable[tuple[str,