Implement OlMoE logits logging.
Update: logging items in OlMoE. Update: some changes to the main script. Add: logits post-processing notebook.
This commit is contained in:
@ -102,14 +102,18 @@ class OlmoeMoE(nn.Module):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
dlog.log({
|
||||
"_time": datetime.now(),
|
||||
"router_logits": router_logits.cpu().float().numpy(),
|
||||
"layer": self.layer_idx,
|
||||
})
|
||||
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states,
|
||||
router_logits=router_logits)
|
||||
|
||||
dlog.log({
|
||||
"_time": datetime.now(),
|
||||
"src": "router",
|
||||
"layer": self.layer_idx,
|
||||
"router_logits": router_logits.cpu().float().numpy(),
|
||||
"orig_shape": list(orig_shape),
|
||||
"hidden_dim": hidden_dim,
|
||||
"hidden_states_shape": list(hidden_states.shape),
|
||||
})
|
||||
return final_hidden_states.view(orig_shape)
|
||||
|
||||
def add_logging_metrics(self, layer_idx: int):
|
||||
@ -501,6 +505,11 @@ class OlmoeForCausalLM(nn.Module, SupportsPP):
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata)
|
||||
dlog.log({
|
||||
"_time": datetime.now(),
|
||||
"src": "lm_logit",
|
||||
"logits": logits.cpu().float().numpy(),
|
||||
})
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
|
Reference in New Issue
Block a user