Two Python implementations show the core mechanics of rStar-Math. Both run for free on Google Colab.
Block 1: MCTS UCB Selection
Implement the Upper Confidence Bound formula and simulate one round of MCTS node selection.
import math


class MCTSNode:
    """Represents one node in the MCTS search tree (a partial solution)."""

    def __init__(self, step_description, parent=None):
        self.step = step_description  # what solution step this node represents
        self.parent = parent          # parent node in tree
        self.children = []            # child nodes (next steps)
        self.visits = 0               # N(v): times this node has been visited
        self.total_reward = 0.0       # Q(v): sum of rewards from rollouts

    def ucb_score(self, exploration_const=math.sqrt(2)):
        """
        Compute UCB = Q(v)/N(v) + C * sqrt(ln(N(parent)) / N(v))
        Returns inf for unvisited nodes (always explore first).
        """
        if self.visits == 0:
            return float('inf')  # unvisited nodes have infinite UCB
        avg_reward = self.total_reward / self.visits
        parent_visits = self.parent.visits if self.parent else self.visits
        exploration_bonus = exploration_const * math.sqrt(
            math.log(parent_visits) / self.visits
        )
        return avg_reward + exploration_bonus


def select_best_child(node):
    """Return the child node with highest UCB score."""
    return max(node.children, key=lambda c: c.ucb_score())

# Simulate: MCTS has explored a problem and has 3 candidate next steps
root = MCTSNode("Problem: count integers 1-99 divisible by 3 but not 5")
root.visits = 20 # root has been visited 20 times in exploration
# Three candidate approaches: careful count, formula, list comprehension
approaches = [
    "Count multiples of 3 using range(3, 100, 3)",
    "Use arithmetic sequence formula",
    "List comprehension: [x for x in range(1,100) if x%3==0]"
]
histories = [
    (16.0, 19),  # (total_reward, visits): cautious, heavily explored
    (4.2, 6),    # moderate, less explored
    (0.8, 2),    # bold, barely tried
]
for approach, (reward, visits) in zip(approaches, histories):
    child = MCTSNode(approach, parent=root)
    child.visits = visits
    child.total_reward = reward
    root.children.append(child)
# Show UCB scores
print("MCTS Node Selection (UCB Scores):")
print("=" * 70)
for i, child in enumerate(root.children, 1):
    ucb = child.ucb_score()
    avg = child.total_reward / child.visits
    print(f"\nApproach {i}: {child.step[:45]}...")
    print(f"  Visits: {child.visits:2d} | Total Reward: {child.total_reward:5.1f}")
    print(f"  Avg Reward: {avg:.3f} | Exploration Bonus: {ucb - avg:.3f}")
    print(f"  UCB Score: {ucb:.3f}")
selected = select_best_child(root)
print(f"\n{'=' * 70}")
print(f"MCTS selects: Approach with UCB = {selected.ucb_score():.3f}")
print(f"Step: {selected.step}")
print("\nWhy? The bold approach hasn't been tried much (2 visits).")
print("Exploration bonus outweighs lower average reward.")
What you see: Despite the “cautious approach” having the highest average reward (16/19 ≈ 0.84), MCTS selects the bold approach (UCB ≈ 2.13 vs. 1.40) because it has been visited only twice, so its exploration bonus (≈ 1.73) outweighs its lower average reward (0.40). This pushes MCTS to sample different strategies rather than only exploit the best-known one.
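The three scores can be checked by hand. The worked arithmetic below plugs the values set in the simulation (parent visits N = 20, C = √2) into the UCB formula from the docstring; the subscripts 1 to 3 are just labels for the cautious, moderate, and bold approaches.

```latex
\begin{align*}
\mathrm{UCB}_1 &= \frac{16.0}{19} + \sqrt{2}\,\sqrt{\frac{\ln 20}{19}} \approx 0.842 + 0.562 = 1.404 \\
\mathrm{UCB}_2 &= \frac{4.2}{6} + \sqrt{2}\,\sqrt{\frac{\ln 20}{6}} \approx 0.700 + 0.999 = 1.699 \\
\mathrm{UCB}_3 &= \frac{0.8}{2} + \sqrt{2}\,\sqrt{\frac{\ln 20}{2}} \approx 0.400 + 1.731 = 2.131
\end{align*}
```

The exploration term shrinks as a node's own visit count grows, which is why the approach with 19 visits earns the smallest bonus.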
Key insight: UCB prevents MCTS from prematurely committing to a suboptimal strategy. It balances learning what works (exploitation) with trying new things (exploration).
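In Block 1 the visit counts and reward totals are set by hand. In a running search they would be accumulated by backpropagation after each rollout. The sketch below shows generic MCTS backpropagation under stated assumptions: the `Node` dataclass and the `backpropagate` helper are hypothetical names introduced for illustration, and this is the textbook update rather than necessarily the exact rule rStar-Math uses.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    """Hypothetical minimal node: only the fields backpropagation touches."""
    name: str
    parent: Optional["Node"] = None
    visits: int = 0
    total_reward: float = 0.0


def backpropagate(leaf: Node, reward: float) -> None:
    """Add one rollout's reward to the leaf and every ancestor up to the root."""
    node = leaf
    while node is not None:
        node.visits += 1
        node.total_reward += reward
        node = node.parent


root = Node("problem")
step1 = Node("count multiples of 3", parent=root)
step2 = Node("subtract multiples of 15", parent=step1)

backpropagate(step2, reward=1.0)  # a rollout through step2 reached the right answer
backpropagate(step1, reward=0.0)  # a second rollout that stopped at step1 failed

for n in (root, step1, step2):
    print(f"{n.name:26s} visits={n.visits} total_reward={n.total_reward}")
```

After the two calls, the root and `step1` each have 2 visits and a total reward of 1.0, while `step2` has 1 visit and a total reward of 1.0. These are the quantities that `ucb_score` reads as N(v) and Q(v).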
Block 2: Program-of-Thought Auto-Verification
Generate Python code for a solution, execute it, and automatically verify correctness.
def execute_and_verify(solution_code: str, expected_answer) -> tuple:
    """
    Execute a Python solution and verify if it produces expected_answer.
    Returns (result, is_correct, error_msg).
    """
    local_namespace = {}
    error_msg = None
    is_correct = False
    result = None
    try:
        # Execute the generated code in a fresh namespace.
        # Note: exec() is not a security sandbox; builtins remain available.
        exec(solution_code, {}, local_namespace)
        # Extract the final 'answer' variable
        result = local_namespace.get('answer', None)
        # Check if it matches expected
        is_correct = (result == expected_answer)
    except Exception as e:
        # Code failed: syntax error, runtime error, etc.
        error_msg = f"Execution error: {str(e)[:50]}"
        result = None
        is_correct = False
    return result, is_correct, error_msg

# Example 1: Correct solution (divisibility problem from Section 5)
correct_code = """
# Count integers 1-99 divisible by 3 but not 5
count_div3 = len([x for x in range(1, 100) if x % 3 == 0])
count_div15 = len([x for x in range(1, 100) if x % 15 == 0])
answer = count_div3 - count_div15
"""
result, is_correct, error = execute_and_verify(correct_code, 27)
print("Correct Solution:")
print(f" Result: {result} | Correct: {is_correct} | Error: {error}")
# Example 2: Incorrect solution (forgot to subtract multiples of 15)
incomplete_code = """
count_div3 = len([x for x in range(1, 100) if x % 3 == 0])
answer = count_div3 # forgot to subtract!
"""
result, is_correct, error = execute_and_verify(incomplete_code, 27)
print("\nIncomplete Solution (forgot to exclude multiples of 5):")
print(f" Result: {result} | Correct: {is_correct} | Error: {error}")
# Example 3: Syntax error (code doesn't run at all)
broken_code = """
count = len([x for x in range(1, 100) if x % 3 == 0) # missing bracket!
answer = count
"""
result, is_correct, error = execute_and_verify(broken_code, 27)
print("\nBroken Syntax:")
print(f" Result: {result} | Correct: {is_correct}")
print(f" Error: {error}")
# Show the filtering in data collection
print("\n" + "=" * 70)
print("Data Collection Simulation:")
print("=" * 70)
solutions = [
    (correct_code, "Round 1 candidate 1 (high-quality)", 0.95),
    (incomplete_code, "Round 1 candidate 2 (low-quality)", 0.60),
    (broken_code, "Round 1 candidate 3 (broken)", 0.10),
]
collected_count = 0
for code, description, prm_score in solutions:
    result, is_correct, _ = execute_and_verify(code, 27)
    status = "KEEP" if (is_correct and prm_score > 0.7) else "DISCARD"
    if status == "KEEP":
        collected_count += 1
    print(f"{description:40s} | Correct: {is_correct} | PRM: {prm_score} | {status}")
print(f"\nAfter filtering: {collected_count} high-quality solutions collected")
print("(Only these are used for supervised fine-tuning)")
What you see:
- The correct solution returns 27 ✓
- The incomplete solution returns 33 (wrong) ✗
- The broken solution fails execution ✗
Data collection logic: Only keep solutions that (a) execute without error, (b) produce the correct answer, and (c) have high PRM scores. This ensures the training data is clean and high-quality.
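Condition (a) is harder to enforce than it looks: `exec` runs the candidate inside the notebook's own process, so a generated solution containing an infinite loop would hang the whole collection run. One possible hardening is to run each candidate in a separate interpreter with a time limit. The sketch below is an assumption-laden illustration, not part of the original blocks: `run_with_timeout` is a hypothetical helper, and it assumes each candidate ends by assigning a variable named `answer`, as the examples above do.

```python
import subprocess
import sys


def run_with_timeout(solution_code: str, timeout_s: float = 5.0):
    """Run a candidate in a fresh interpreter; return (printed_answer, error_msg)."""
    # Assumed convention: the candidate assigns `answer`, so we append a print of it.
    program = solution_code + "\nprint(answer)\n"
    try:
        proc = subprocess.run(
            [sys.executable, "-c", program],
            capture_output=True,
            text=True,
            timeout=timeout_s,
        )
    except subprocess.TimeoutExpired:
        return None, "timeout"
    if proc.returncode != 0:
        return None, "execution error"
    return proc.stdout.strip(), None


good = "answer = len([x for x in range(1, 100) if x % 3 == 0 and x % 5 != 0])"
print(run_with_timeout(good))                                 # answer comes back as text
print(run_with_timeout("while True:\n    pass", timeout_s=1.0))  # stopped by the time limit
```

The answer comes back as a string, so a caller would compare it with `str(expected_answer)` or parse it first. A separate process limits the damage from runaway code but is still not a full sandbox.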
Key Takeaways
MCTS UCB (Block 1)
- Exploitation: Nodes with high average reward get selected often
- Exploration: Rarely-visited nodes get a bonus, forcing MCTS to try different strategies
- Balance: The constant C controls the exploration-exploitation tradeoff (higher C = more exploration)
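The effect of C can be seen by re-scoring the three candidates from Block 1 under different constants. This is a minimal sketch reusing the same hand-set numbers (parent visits = 20); the standalone `ucb` and `pick` helpers and the short labels are introduced here for illustration only.

```python
import math


def ucb(total_reward, visits, parent_visits, c):
    """UCB = average reward + c * sqrt(ln(parent visits) / visits)."""
    return total_reward / visits + c * math.sqrt(math.log(parent_visits) / visits)


PARENT_VISITS = 20
# label -> (total_reward, visits), same values as the Block 1 simulation
children = {
    "cautious": (16.0, 19),
    "moderate": (4.2, 6),
    "bold": (0.8, 2),
}


def pick(c):
    """Return the label of the child with the highest UCB score for a given C."""
    return max(children, key=lambda name: ucb(*children[name], PARENT_VISITS, c))


for c in (0.0, 0.5, math.sqrt(2)):
    print(f"C = {c:.2f} -> selects {pick(c)}")
```

With these numbers, C = 0 reduces to pure exploitation and picks the cautious approach, C = 0.5 tips the choice to the moderate one, and C = √2 selects the bold one.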
Program-of-Thought (Block 2)
- Automatic verification: No humans needed to judge if solutions are correct
- Scalability: Can generate and verify thousands of solutions automatically
- Filtering: Combine execution verification with PRM scores to select high-quality training data
Why this enables rStar-Math:
- MCTS efficiently explores solution spaces using guided search
- Python verification gives automatic, scalable correctness checking
- Together, they generate clean training data without human annotation
- The model trains on this data and improves, enabling better search next round