
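This PRM (process reward model) is trained to predict the correctness of each solution step. Its predictions are read at the positions of the "\n\n" separators that end the steps and of the final "<eos>" token.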
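The sketch below makes the step structure concrete. It splits a solution on "\n\n" and reports, for each step, the character offset where its score would be read. For most steps this is the separator that ends the step. For the last step it is the end of the text, which stands in for "<eos>". The sketch works on characters rather than tokens, and `split_steps` is a hypothetical helper, not part of the model's code.

```python
def split_steps(solution: str, sep: str = "\n\n") -> list[tuple[str, int]]:
    """Split a solution into (step_text, score_offset) pairs.

    score_offset is the character offset of the separator that ends the step,
    or len(solution) for the final step (standing in for the <eos> position).
    """
    steps = []
    start = 0
    while True:
        end = solution.find(sep, start)
        if end == -1:
            # No separator left: the last step is scored at the end of the text.
            steps.append((solution[start:], len(solution)))
            return steps
        steps.append((solution[start:end], end))
        start = end + len(sep)


example = "Step one.\n\nStep two.\n\nFinal answer."
for text, offset in split_steps(example):
    print(offset, repr(text))
# 9 'Step one.'
# 20 'Step two.'
# 35 'Final answer.'
```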

Usage:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "ScalableMath/llemma-7b-prm-metamath-level-1to3-hf"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/llemma_7b")

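# Use a raw string so LaTeX commands such as \theta, \frac and \boxed are not
# read as Python escape sequences (\t, \f, \b).
qa_example = r"""# Question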

Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

# Solution

To convert from rectangular coordinates to polar coordinates, we use the formulas $r = \sqrt{x^2 + y^2}$ and $\theta = \arctan\left(\frac{y}{x}\right)$.

In this case, $x = 0$ and $y = 3$, so $r = \sqrt{0^2 + 3^2} = 3$ and $\theta = \arctan\left(\frac{3}{0}\right)$.

Since $\frac{3}{0}$ is undefined, we can say that $\theta$ is undefined.
However, we know that $\theta$ is an angle, and since $r > 0$, we can say that $\theta$ is any angle that satisfies $0 \le \theta < 2 \pi$.

Therefore, the polar coordinates of the point $(0,3)$ are $\boxed{(3,\theta)}$, where $0 \le \theta < 2 \pi$.

# Answer

(3,\theta)"""

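# Token IDs for the "\n\n# Solution" header and for the "\n\n" step separator.
# [1:] drops the first token, which for this Llama-style tokenizer is the
# SentencePiece prefix piece rather than part of the text.
begin_solution_tokens = tokenizer.encode("\n\n# Solution", add_special_tokens=False)[1:]
scoring_tokens = tokenizer.encode("\n\n", add_special_tokens=False)[1:]
eos_token = tokenizer.eos_token_id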

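# encode() prepends BOS. The loop below records an <eos> position only if the
# tokenizer also appends an EOS token to input_ids.
input_ids = tokenizer.encode(qa_example)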

begin_solution_flag = False

candidate_positions = []

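# Scan the tokens. Once the "\n\n# Solution" pattern has been seen, record the
# start index of every "\n\n" separator, then record the first EOS and stop.
for start_idx in range(len(input_ids)):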
    if tuple(input_ids[start_idx:start_idx+len(begin_solution_tokens)]) == tuple(begin_solution_tokens):
        begin_solution_flag = True

    if begin_solution_flag and tuple(input_ids[start_idx:start_idx+len(scoring_tokens)]) == tuple(scoring_tokens):
        candidate_positions.append(start_idx)

    if input_ids[start_idx] == eos_token:
        candidate_positions.append(start_idx)
        break

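# Drop the first and the second-to-last candidate positions: these are meant to
# be the "\n\n" separators around the "# Solution" and "# Answer" headers, which
# do not end a solution step. If you change the prompt layout, decode the tokens
# at candidate_positions to check which separators remain.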
del candidate_positions[0]
del candidate_positions[-2]

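# Batch of one sequence, moved to the same device as the model.
input_tensor = torch.tensor([input_ids]).to(model.device)
candidate_positions = torch.tensor(candidate_positions)

with torch.no_grad():
    logits = model(input_tensor).logits  # shape: (1, seq_len, vocab_size)
    # Per-position score: mean of the logits over the vocabulary dimension.
    scores = logits.mean(dim=-1)
    # Keep the scores at the step boundaries and map them to probabilities.
    step_scores = scores[0][candidate_positions]
    step_probs = torch.sigmoid(step_scores)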

print(step_probs)

# Example output (exact values may vary):
# tensor([0.7264, 0.8152, 0.7827, 0.4709, 0.5181])
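The position search in the script above can be factored into a function and checked on toy token IDs without downloading the model. The sketch below mirrors the scanning loop only and leaves out the two `del` statements, which depend on the prompt layout. `find_candidate_positions` and the token IDs are hypothetical, chosen for illustration.

```python
from typing import Sequence


def find_candidate_positions(
    input_ids: Sequence[int],
    begin_solution_tokens: Sequence[int],
    scoring_tokens: Sequence[int],
    eos_token: int,
) -> list[int]:
    """Return the indices at which step scores are read.

    Mirrors the scanning loop in the usage script: once the begin-solution
    pattern has been seen, every index where the scoring pattern starts is
    recorded, and the scan stops at the first EOS token (which is recorded too).
    """
    ids = list(input_ids)
    begin = list(begin_solution_tokens)
    sep = list(scoring_tokens)
    begin_solution_flag = False
    positions: list[int] = []
    for start_idx in range(len(ids)):
        if ids[start_idx:start_idx + len(begin)] == begin:
            begin_solution_flag = True
        if begin_solution_flag and ids[start_idx:start_idx + len(sep)] == sep:
            positions.append(start_idx)
        if ids[start_idx] == eos_token:
            positions.append(start_idx)
            break
    return positions


# Hypothetical token IDs: NL = newline, HASH = "#", SOLUTION = "Solution".
NL, HASH, SOLUTION, EOS = 10, 20, 21, 2
toy_ids = [1, 100, NL, NL, HASH, SOLUTION, NL, NL, 101, NL, NL, 102, EOS]
positions = find_candidate_positions(toy_ids, [NL, NL, HASH, SOLUTION], [NL, NL], EOS)
print(positions)  # [2, 6, 9, 12]
```

Index 2 is the first separator of the header pattern itself. It is recorded because the flag is set before the separator check in the same iteration. If the IDs end without an EOS token, no EOS position is recorded.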