Section 06

The Code: Implementing Best-of-N Selection with ORM vs. PRM

Let's Verify Step by Step: A Process Supervision Approach to Reward Modeling 2023

Below is a complete Python implementation of best-of-N selection using both ORM and PRM scoring. This code is runnable on Google Colab with no special dependencies (only NumPy).


Code: Best-of-N Selection

import numpy as np

# Mock reward data: 5 candidate solutions for a math problem
# Each solution has a list of per-step correctness probabilities

solutions = {
    "Solution 1 (clean and correct)": [0.99, 0.98, 0.99, 0.99],
    "Solution 2 (verbose, correct)": [0.99, 0.98, 0.99, 0.99, 0.95, 0.99],
    "Solution 3 (arithmetic error)": [0.98, 0.02, 0.05, 0.10],
    "Solution 4 (missing step)": [0.99, 0.92],
    "Solution 5 (lucky guess)": [0.05, 0.05, 0.95],  # wrong reasoning, right answer
}

# Compute ORM score for each solution
# ORM only cares about the final answer
# Simulated: every final answer is correct except Solution 3's (Solution 5 is a lucky guess)
final_answers_correct = [True, True, False, True, True]
orm_scores = [float(correct) for correct in final_answers_correct]

# Compute PRM score for each solution (product of per-step probabilities)
prm_scores = []
for sol_name, step_probs in solutions.items():
    product_score = np.prod(step_probs)
    prm_scores.append(product_score)

# Print results
print("=" * 70)
print("BEST-OF-N SELECTION: ORM vs. PRM")
print("=" * 70)
print()

for i, (sol_name, step_probs) in enumerate(solutions.items()):
    orm = orm_scores[i]
    prm = prm_scores[i]
    num_steps = len(step_probs)
    
    print(f"Solution {i+1}: {sol_name}")
    print(f"  Steps: {step_probs}")
    print(f"  Number of steps: {num_steps}")
    print(f"  ORM Score: {orm:.3f}")
    print(f"  PRM Score (product): {prm:.6f}")
    print()

# Find best solution under each criterion
best_orm_idx = np.argmax(orm_scores)
best_prm_idx = np.argmax(prm_scores)

print("=" * 70)
print("SELECTION RESULTS")
print("=" * 70)
print(f"Best solution by ORM: Solution {best_orm_idx + 1}")
print(f"  Name: {list(solutions.keys())[best_orm_idx]}")
print(f"  ORM score: {orm_scores[best_orm_idx]:.3f}")
print()
print(f"Best solution by PRM: Solution {best_prm_idx + 1}")
print(f"  Name: {list(solutions.keys())[best_prm_idx]}")
print(f"  PRM score: {prm_scores[best_prm_idx]:.6f}")
print()

# Show disagreements
if best_orm_idx != best_prm_idx:
    print("ORM and PRM DISAGREE on which solution is best!")
    print(f"  ORM picks solution {best_orm_idx + 1}")
    print(f"  PRM picks solution {best_prm_idx + 1}")
else:
    print("ORM and PRM agree on the best solution.")

Expected Output

======================================================================
BEST-OF-N SELECTION: ORM vs. PRM
======================================================================

Solution 1: Solution 1 (clean and correct)
  Steps: [0.99, 0.98, 0.99, 0.99]
  Number of steps: 4
  ORM Score: 1.000
  PRM Score (product): 0.950893

Solution 2: Solution 2 (verbose, correct)
  Steps: [0.99, 0.98, 0.99, 0.99, 0.95, 0.99]
  Number of steps: 6
  ORM Score: 1.000
  PRM Score (product): 0.894315

Solution 3: Solution 3 (arithmetic error)
  Steps: [0.98, 0.02, 0.05, 0.1]
  Number of steps: 4
  ORM Score: 0.000
  PRM Score (product): 0.000098

Solution 4: Solution 4 (missing step)
  Steps: [0.99, 0.92]
  Number of steps: 2
  ORM Score: 1.000
  PRM Score (product): 0.910800

Solution 5: Solution 5 (lucky guess)
  Steps: [0.05, 0.05, 0.95]
  Number of steps: 3
  ORM Score: 1.000
  PRM Score (product): 0.002375

======================================================================
SELECTION RESULTS
======================================================================
Best solution by ORM: Solution 1
  Name: Solution 1 (clean and correct)
  ORM score: 1.000

Best solution by PRM: Solution 1
  Name: Solution 1 (clean and correct)
  PRM score: 0.950893

ORM and PRM agree on the best solution.
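The agreement above is less informative than it looks. Four solutions tie at ORM score 1.000, and np.argmax returns the first index among tied maxima, so ORM lands on Solution 1 only because it comes first in the dictionary. The minimal sketch below lists every tied candidate and then breaks the tie at random; the random tie-break and the seed are illustrative choices, not part of the original code.

```python
import numpy as np

# Same ORM scores as in the listing above: 1.0 means a correct final answer
names = ["Solution 1 (clean and correct)", "Solution 2 (verbose, correct)",
         "Solution 3 (arithmetic error)", "Solution 4 (missing step)",
         "Solution 5 (lucky guess)"]
orm_scores = np.array([1.0, 1.0, 0.0, 1.0, 1.0])

# np.argmax returns the FIRST of several tied maxima
first_pick = int(np.argmax(orm_scores))

# Every candidate that ORM considers equally good
tied = np.flatnonzero(orm_scores == orm_scores.max())

print(f"np.argmax picks: {names[first_pick]}")
print("ORM ties between:")
for i in tied:
    print(f"  {names[i]}")

# A random tie-break (illustrative assumption) could return any tied candidate
rng = np.random.default_rng(0)
random_pick = int(rng.choice(tied))
print(f"Random tie-break picks: {names[random_pick]}")
```

With a random tie-break, the lucky-guess Solution 5 is as likely to be returned as Solution 1. PRM has no such tie, because its top score (about 0.951) is unique.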

What This Code Shows

  1. ORM assigns binary scores: Solutions 1, 2, 4, and 5 all get ORM score = 1.000 because their final answers are correct. Only Solution 3 gets 0 because its final answer is wrong.

  2. PRM assigns graded scores: Solution 1 gets 0.951 (very good), Solution 2 gets 0.894 (good but verbose), Solution 4 gets 0.911 (decent but missing a step), Solution 5 gets 0.0024 (nearly all reasoning is bad despite the lucky answer), and Solution 3 gets about 0.0001 (the arithmetic error in step 2 drags down every later step).

  3. PRM distinguishes among correct answers: Even though Solutions 1, 2, 4, and 5 all have correct final answers, PRM ranks them differently (0.951 > 0.911 > 0.894 > 0.0024) based on their per-step scores. ORM cannot distinguish them: all four score 1.000.

  4. PRM strongly penalizes bad reasoning: Solution 5 reaches the correct final answer (its last step scores p₃ = 0.95), but its first two steps are terrible (p₁ = p₂ = 0.05). ORM gives it a perfect 1.000 (lucky!), while PRM gives it 0.0024, which effectively rules it out.
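The product rule multiplies in one factor per step, so a longer solution pays for every extra step even when each step is sound: in the output above, Solution 4 (2 steps, labeled as missing a step) outscores Solution 2 (6 steps, all fairly confident). The minimal sketch below makes that effect visible. It recomputes the product in log space (numerically safer for long step chains) and, for contrast, swaps in a different aggregation rule, the minimum step probability, which is not the rule used in the code above.

```python
import numpy as np

# Per-step correctness probabilities copied from the listing above
solutions = {
    "Solution 1 (clean and correct)": [0.99, 0.98, 0.99, 0.99],
    "Solution 2 (verbose, correct)": [0.99, 0.98, 0.99, 0.99, 0.95, 0.99],
    "Solution 3 (arithmetic error)": [0.98, 0.02, 0.05, 0.10],
    "Solution 4 (missing step)": [0.99, 0.92],
    "Solution 5 (lucky guess)": [0.05, 0.05, 0.95],
}

def prm_product(step_probs):
    """Product of step probabilities, computed as exp(sum of logs) to avoid underflow."""
    return float(np.exp(np.sum(np.log(step_probs))))

def prm_min(step_probs):
    """Alternative rule (swapped in for comparison): score = weakest step."""
    return float(np.min(step_probs))

for agg in (prm_product, prm_min):
    scores = {name: agg(p) for name, p in solutions.items()}
    ranking = sorted(scores, key=scores.get, reverse=True)  # best first
    print(f"{agg.__name__}:")
    for name in ranking:
        print(f"  {scores[name]:.4f}  {name}")
```

Under the product rule the order is Solutions 1, 4, 2, 5, 3. Under the min rule, Solution 2 (weakest step 0.95) moves ahead of Solution 4 (weakest step 0.92), because the score no longer depends on how many steps a solution has. Neither rule is free of trade-offs: min ignores how many uncertain steps there are, while the product penalizes length.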


How to Run on Google Colab

  1. Go to Google Colab
  2. Create a new notebook
  3. Paste the code above into a cell
  4. Run the cell
  5. The output shows how ORM and PRM rank the solutions

No installation needed — NumPy is pre-installed on Colab.


Extension: Simulate Multiple Samples

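Here’s a bonus function that repeats best-of-N selection over many random trials. To grade each pick it needs a hidden ground truth, so it makes a few toy assumptions (listed in the docstring) about how steps, final answers, and PRM scores relate:

def simulate_best_of_n(n_trials=100, n_candidates=10, p_step=0.85, p_lucky=0.3, seed=0):
    """Run n_trials best-of-N selections, each over n_candidates random solutions.

    Toy assumptions (illustrative, not from the paper):
      - each solution has 4 steps; each step is truly correct with probability p_step
      - reasoning is valid only if all 4 steps are correct
      - invalid reasoning still reaches the right final answer with probability p_lucky
      - the PRM is informative but noisy: it scores a correct step in [0.7, 0.99]
        and an incorrect step in [0.01, 0.3]
    A method "wins" a trial if the solution it picks has valid reasoning.
    """
    rng = np.random.default_rng(seed)
    orm_wins = 0  # trials where ORM's pick has valid reasoning
    prm_wins = 0  # trials where PRM's pick has valid reasoning

    for _ in range(n_trials):
        # Hidden ground truth: which steps are actually correct
        steps_ok = rng.random((n_candidates, 4)) < p_step
        valid = steps_ok.all(axis=1)
        # Final answer is right if the reasoning is valid, or by luck
        answer_ok = valid | (rng.random(n_candidates) < p_lucky)

        # ORM sees only final-answer correctness (binary)
        orm_scores_trial = answer_ok.astype(float)
        # PRM sees a noisy probability per step; solution score = product
        step_probs = np.where(steps_ok,
                              rng.uniform(0.7, 0.99, (n_candidates, 4)),
                              rng.uniform(0.01, 0.3, (n_candidates, 4)))
        prm_scores_trial = step_probs.prod(axis=1)

        # Pick best by each method. Candidates are i.i.d., so argmax's
        # first-index tie-break among ORM's 1.0 scores is effectively random.
        best_orm = np.argmax(orm_scores_trial)
        best_prm = np.argmax(prm_scores_trial)

        orm_wins += int(valid[best_orm])
        prm_wins += int(valid[best_prm])

    print(f"Over {n_trials} trials (each with {n_candidates} candidates):")
    print(f"  ORM picked a solution with valid reasoning: {orm_wins}/{n_trials} = {100*orm_wins/n_trials:.1f}%")
    print(f"  PRM picked a solution with valid reasoning: {prm_wins}/{n_trials} = {100*prm_wins/n_trials:.1f}%")

# Run simulation
simulate_best_of_n(n_trials=100, n_candidates=10)

Under these toy assumptions, ORM fails whenever the first candidate with a correct final answer is a lucky guess, while the PRM product almost always ranks a valid solution on top, so PRM should win noticeably more often. The size of the gap is set by the assumed parameters (p_step, p_lucky, and the PRM score ranges), not measured from data, so read it as an illustration of the mechanism rather than a benchmark.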