# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:percent
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.17.3
#   kernelspec:
#     display_name: venv
#     language: python
#     name: python3
# ---

# %%
import numpy as np
import pandas as pd
# import torch

# %%
# Inputs: the parquet log to inspect and the local model directory whose
# tokenizer is used to decode token ids at the end of the notebook.
log_file = "olmoe_20250929-165403.parquet"
model_id = "./llms/OLMoE-1B-7B-0924-Instruct"

# %%
df = pd.read_parquet(log_file)
df.head()

# %%
# logit = df.loc[1, "router_logits"]
# logit

# %%
# Pull the token ids and output text from the row tagged src == "main".
# `.item()` raises ValueError unless the mask selects exactly one row.
main_mask = df["src"] == "main"
token_ids = df.loc[main_mask, "token_ids"].item()
output_text = df.loc[main_mask, "output_text"].item()

# %%
# For rows tagged src == "lm_logit": collapse each stored logits value into a
# single flat 1-D array, then record the index of its largest entry as
# `token_id`.
# NOTE(review): presumably each "logits" cell is a nested sequence of arrays
# as loaded from parquet -- confirm against the logger that wrote the file.
lm_mask = df["src"] == "lm_logit"
df.loc[lm_mask, "logits"] = df.loc[lm_mask, "logits"].apply(
    lambda arr: np.stack(list(arr)).flatten()
)
df.loc[lm_mask, "token_id"] = df.loc[lm_mask, "logits"].apply(
    lambda logits: np.argmax(logits, axis=-1)
)
df.head()

# %%
# Row-mask + column selection must go through `.loc`; `df[lm_mask, "token_id"]`
# passes a (Series, str) tuple to `DataFrame.__getitem__` and raises.
# The column is only populated on the lm_logit rows, so it is likely stored as
# float (NaN elsewhere); cast the selected values back to integer ids so they
# can be compared directly with `token_ids` below.
df.loc[lm_mask, "token_id"].to_numpy().astype(np.int64)

# %%
token_ids

# %%
import transformers
from transformers import AutoTokenizer, GPTNeoXTokenizerFast

tokenizer: GPTNeoXTokenizerFast = AutoTokenizer.from_pretrained(model_id)

# %%
# Decode the ids from the "main" row back into text.
tokenizer.decode(token_ids)