Initial commit - WhyRating Engine (Google Reviews Scraper)
This commit is contained in:
486
packages/reviewiq-pipeline/validate_router.py
Normal file
486
packages/reviewiq-pipeline/validate_router.py
Normal file
@@ -0,0 +1,486 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
Validate router decisions against real reviews with minimal LLM cost.
|
||||
|
||||
This script:
|
||||
1. Loads real reviews from database
|
||||
2. Routes them through the router
|
||||
3. Cherry-picks samples from each tier for validation
|
||||
4. Optionally runs LLM on small samples to validate decisions
|
||||
|
||||
Usage:
|
||||
# Dry run - just show routing decisions, no LLM calls
|
||||
python validate_router.py <job_id> --dry-run
|
||||
|
||||
# Validate with LLM (costs ~$0.05-0.10)
|
||||
python validate_router.py <job_id> --validate
|
||||
|
||||
# Custom sample sizes
|
||||
python validate_router.py <job_id> --validate --skip-samples=3 --cheap-samples=5 --full-samples=3
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger("validate_router")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of validating a single review."""
|
||||
review_id: str
|
||||
text: str
|
||||
rating: int
|
||||
routed_tier: str
|
||||
routing_reason: str
|
||||
routing_signals: dict
|
||||
# LLM results (if validated)
|
||||
llm_urt: str | None = None
|
||||
llm_valence: str | None = None
|
||||
llm_span_count: int | None = None
|
||||
llm_cost: float | None = None
|
||||
# Validation verdict
|
||||
routing_correct: bool | None = None
|
||||
notes: str = ""
|
||||
|
||||
|
||||
async def load_reviews_from_db(job_id: str, database_url: str) -> list[dict]:
|
||||
"""Load reviews from database for a job."""
|
||||
import asyncpg
|
||||
|
||||
conn = await asyncpg.connect(database_url)
|
||||
try:
|
||||
# Get reviews with text from pipeline schema
|
||||
rows = await conn.fetch("""
|
||||
SELECT
|
||||
re.review_id,
|
||||
re.text,
|
||||
re.rating,
|
||||
re.business_id,
|
||||
re.place_id
|
||||
FROM pipeline.reviews_enriched re
|
||||
WHERE re.job_id = $1::uuid
|
||||
AND re.text IS NOT NULL
|
||||
AND re.text != ''
|
||||
ORDER BY re.id
|
||||
""", job_id)
|
||||
|
||||
reviews = []
|
||||
for row in rows:
|
||||
text = row["text"] or ""
|
||||
reviews.append({
|
||||
"review_id": row["review_id"],
|
||||
"text": text,
|
||||
"text_normalized": text.lower().strip(),
|
||||
"rating": row["rating"],
|
||||
"business_id": row["business_id"],
|
||||
"place_id": row["place_id"],
|
||||
"source": "google",
|
||||
"review_version": 1,
|
||||
"review_time": "2024-01-01T00:00:00Z",
|
||||
})
|
||||
|
||||
logger.info(f"Loaded {len(reviews)} reviews from job {job_id}")
|
||||
return reviews
|
||||
|
||||
finally:
|
||||
await conn.close()
|
||||
|
||||
|
||||
def route_reviews(reviews: list[dict]) -> dict[str, list[dict]]:
|
||||
"""Route reviews and return grouped by tier."""
|
||||
from reviewiq_pipeline.services.review_router import (
|
||||
ReviewRouter,
|
||||
RoutingTier,
|
||||
create_router,
|
||||
)
|
||||
|
||||
router = create_router(conservative=True)
|
||||
routed = router.route_batch(reviews)
|
||||
|
||||
return {
|
||||
"skip": routed[RoutingTier.SKIP],
|
||||
"cheap": routed[RoutingTier.CHEAP_MODEL],
|
||||
"full": routed[RoutingTier.FULL_MODEL],
|
||||
}
|
||||
|
||||
|
||||
def select_diverse_samples(
|
||||
reviews: list[dict],
|
||||
tier: str,
|
||||
n_samples: int,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Select diverse samples from a tier for validation.
|
||||
|
||||
Strategy:
|
||||
- For SKIP: Pick different ratings, different lengths
|
||||
- For CHEAP: Pick different word counts, different ratings
|
||||
- For FULL: Pick different routing reasons
|
||||
"""
|
||||
if not reviews or n_samples <= 0:
|
||||
return []
|
||||
|
||||
samples = []
|
||||
seen_reasons = set()
|
||||
seen_ratings = set()
|
||||
|
||||
# First pass: get diversity by reason and rating
|
||||
for review in reviews:
|
||||
routing = review.get("_routing")
|
||||
if not routing:
|
||||
continue
|
||||
|
||||
reason = routing.reason
|
||||
rating = review["rating"]
|
||||
|
||||
# Prioritize diversity
|
||||
key = (reason, rating)
|
||||
if key not in seen_reasons or len(samples) < n_samples:
|
||||
if len(samples) < n_samples:
|
||||
samples.append(review)
|
||||
seen_reasons.add(key)
|
||||
seen_ratings.add(rating)
|
||||
|
||||
# Fill remaining slots if needed
|
||||
for review in reviews:
|
||||
if len(samples) >= n_samples:
|
||||
break
|
||||
if review not in samples:
|
||||
samples.append(review)
|
||||
|
||||
return samples[:n_samples]
|
||||
|
||||
|
||||
def print_routing_summary(routed: dict[str, list[dict]]):
|
||||
"""Print summary of routing decisions."""
|
||||
total = sum(len(v) for v in routed.values())
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("ROUTING SUMMARY")
|
||||
print("=" * 70)
|
||||
|
||||
for tier, reviews in routed.items():
|
||||
pct = len(reviews) / total * 100 if total > 0 else 0
|
||||
print(f"\n{tier.upper()} TIER: {len(reviews)} reviews ({pct:.1f}%)")
|
||||
|
||||
# Group by reason
|
||||
reasons = {}
|
||||
for r in reviews:
|
||||
routing = r.get("_routing")
|
||||
if routing:
|
||||
reason = routing.reason
|
||||
reasons[reason] = reasons.get(reason, 0) + 1
|
||||
|
||||
for reason, count in sorted(reasons.items(), key=lambda x: -x[1]):
|
||||
print(f" - {reason}: {count}")
|
||||
|
||||
|
||||
def print_samples(samples: list[dict], tier: str):
|
||||
"""Print sample reviews for inspection."""
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"{tier.upper()} TIER SAMPLES ({len(samples)} reviews)")
|
||||
print("=" * 70)
|
||||
|
||||
for i, review in enumerate(samples, 1):
|
||||
routing = review.get("_routing")
|
||||
signals = routing.signals if routing else {}
|
||||
|
||||
print(f"\n[{i}] Review ID: {review['review_id']}")
|
||||
print(f" Rating: {'⭐' * review['rating']}")
|
||||
print(f" Text: \"{review['text'][:100]}{'...' if len(review['text']) > 100 else ''}\"")
|
||||
print(f" Routing: {routing.reason if routing else 'N/A'}")
|
||||
print(f" Signals: words={signals.get('word_count', '?')}, "
|
||||
f"chars={signals.get('char_count', '?')}, "
|
||||
f"numbers={signals.get('has_numbers', '?')}, "
|
||||
f"sentences={signals.get('sentence_count', '?')}")
|
||||
|
||||
|
||||
async def validate_with_llm(
|
||||
samples: list[dict],
|
||||
tier: str,
|
||||
config: Any,
|
||||
) -> list[ValidationResult]:
|
||||
"""
|
||||
Run LLM classification on samples to validate routing decisions.
|
||||
|
||||
Returns validation results with verdicts.
|
||||
"""
|
||||
from reviewiq_pipeline.services.llm_client import LLMClient, BatchReviewInput, PartialBatchResult
|
||||
|
||||
results = []
|
||||
|
||||
if not samples:
|
||||
return results
|
||||
|
||||
# Create LLM client
|
||||
client = LLMClient.create(config)
|
||||
|
||||
try:
|
||||
# Prepare batch input
|
||||
batch_input = [
|
||||
BatchReviewInput(
|
||||
review_id=r["review_id"],
|
||||
text=r["text"],
|
||||
rating=r["rating"],
|
||||
)
|
||||
for r in samples
|
||||
]
|
||||
|
||||
# Run classification
|
||||
logger.info(f"Running LLM on {len(samples)} {tier} tier samples...")
|
||||
|
||||
llm_responses = []
|
||||
metadata = {}
|
||||
|
||||
try:
|
||||
llm_responses, metadata = await client.classify_batch(batch_input, "standard")
|
||||
except PartialBatchResult as e:
|
||||
# Handle partial results
|
||||
logger.warning(f"Partial result for {tier} tier: {len(e.partial_results)} recovered")
|
||||
metadata = e.metadata or {}
|
||||
|
||||
# Build responses from partial results
|
||||
for partial in e.partial_results:
|
||||
idx = partial.get("review_index", -1)
|
||||
if 0 <= idx < len(samples):
|
||||
llm_responses.append({
|
||||
"spans": partial.get("spans", []),
|
||||
"review_summary": partial.get("review_summary", {}),
|
||||
"_index": idx,
|
||||
})
|
||||
|
||||
# Pad with empty responses for missing indices
|
||||
processed_indices = {r.get("_index", -1) for r in llm_responses}
|
||||
for i, sample in enumerate(samples):
|
||||
if i not in processed_indices:
|
||||
llm_responses.append({
|
||||
"spans": [],
|
||||
"review_summary": {},
|
||||
"_index": i,
|
||||
"_error": "partial_recovery_failed",
|
||||
})
|
||||
|
||||
# Sort by original index
|
||||
llm_responses.sort(key=lambda x: x.get("_index", 999))
|
||||
|
||||
cost = metadata.get("cost_usd", 0)
|
||||
logger.info(f"LLM cost for {tier} tier: ${cost:.4f}")
|
||||
|
||||
# Process results
|
||||
for review, llm_response in zip(samples, llm_responses):
|
||||
routing = review.get("_routing")
|
||||
signals = routing.signals if routing else {}
|
||||
|
||||
spans = llm_response.get("spans", [])
|
||||
primary_span = next((s for s in spans if s.get("is_primary")), spans[0] if spans else {})
|
||||
|
||||
urt = primary_span.get("urt_primary", "N/A")
|
||||
valence = primary_span.get("valence", "N/A")
|
||||
|
||||
# Determine if routing was correct
|
||||
routing_correct = None
|
||||
notes = ""
|
||||
|
||||
if tier == "skip":
|
||||
# SKIP is correct if LLM gives generic code (V4.03) or single low-info span
|
||||
is_generic = urt in ("V4.03", "V4.01", "V4.02", "O1.01")
|
||||
is_simple = len(spans) == 1 and primary_span.get("intensity") == "I1"
|
||||
routing_correct = is_generic or is_simple
|
||||
if not routing_correct:
|
||||
notes = f"LLM found specific content: {urt}"
|
||||
else:
|
||||
notes = "Correctly skipped (generic/simple)"
|
||||
|
||||
elif tier == "cheap":
|
||||
# CHEAP is correct if classification is straightforward
|
||||
# (single domain, no complex causal chains)
|
||||
is_simple = len(spans) <= 2
|
||||
routing_correct = is_simple
|
||||
if not routing_correct:
|
||||
notes = f"Complex: {len(spans)} spans found"
|
||||
else:
|
||||
notes = "Simple enough for cheap model"
|
||||
|
||||
elif tier == "full":
|
||||
# FULL is correct if there's meaningful content
|
||||
has_content = len(spans) >= 1 and urt not in ("V4.03", "O1.01")
|
||||
routing_correct = has_content
|
||||
if routing_correct:
|
||||
notes = f"Correctly sent to full: {len(spans)} spans, {urt}"
|
||||
else:
|
||||
notes = "Could have been cheaper"
|
||||
|
||||
result = ValidationResult(
|
||||
review_id=review["review_id"],
|
||||
text=review["text"],
|
||||
rating=review["rating"],
|
||||
routed_tier=tier,
|
||||
routing_reason=routing.reason if routing else "N/A",
|
||||
routing_signals=signals,
|
||||
llm_urt=urt,
|
||||
llm_valence=valence,
|
||||
llm_span_count=len(spans),
|
||||
llm_cost=cost / len(samples),
|
||||
routing_correct=routing_correct,
|
||||
notes=notes,
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_validation_results(results: list[ValidationResult], tier: str):
|
||||
"""Print validation results."""
|
||||
if not results:
|
||||
return
|
||||
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"{tier.upper()} TIER VALIDATION RESULTS")
|
||||
print("=" * 70)
|
||||
|
||||
correct = sum(1 for r in results if r.routing_correct)
|
||||
total = len(results)
|
||||
accuracy = correct / total * 100 if total > 0 else 0
|
||||
|
||||
print(f"\nAccuracy: {correct}/{total} ({accuracy:.1f}%)")
|
||||
|
||||
for r in results:
|
||||
status = "✅" if r.routing_correct else "❌"
|
||||
print(f"\n{status} [{r.review_id}] \"{r.text[:60]}...\"")
|
||||
print(f" Rating: {r.rating}, Routed: {r.routed_tier} ({r.routing_reason})")
|
||||
print(f" LLM: URT={r.llm_urt}, Valence={r.llm_valence}, Spans={r.llm_span_count}")
|
||||
print(f" Notes: {r.notes}")
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(description="Validate router decisions")
|
||||
parser.add_argument("job_id", help="Job ID to analyze")
|
||||
parser.add_argument("--dry-run", action="store_true", help="Show routing only, no LLM")
|
||||
parser.add_argument("--validate", action="store_true", help="Run LLM validation")
|
||||
parser.add_argument("--skip-samples", type=int, default=3, help="SKIP tier samples")
|
||||
parser.add_argument("--cheap-samples", type=int, default=5, help="CHEAP tier samples")
|
||||
parser.add_argument("--full-samples", type=int, default=3, help="FULL tier samples")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Database URL
|
||||
database_url = os.environ.get(
|
||||
"DATABASE_URL",
|
||||
"postgresql://scraper:scraper123@localhost:5437/scraper"
|
||||
)
|
||||
|
||||
# Load reviews
|
||||
reviews = await load_reviews_from_db(args.job_id, database_url)
|
||||
if not reviews:
|
||||
print("No reviews found for job")
|
||||
return
|
||||
|
||||
# Route reviews
|
||||
routed = route_reviews(reviews)
|
||||
|
||||
# Print summary
|
||||
print_routing_summary(routed)
|
||||
|
||||
# Select samples
|
||||
skip_samples = select_diverse_samples(routed["skip"], "skip", args.skip_samples)
|
||||
cheap_samples = select_diverse_samples(routed["cheap"], "cheap", args.cheap_samples)
|
||||
full_samples = select_diverse_samples(routed["full"], "full", args.full_samples)
|
||||
|
||||
# Print samples
|
||||
print_samples(skip_samples, "skip")
|
||||
print_samples(cheap_samples, "cheap")
|
||||
print_samples(full_samples, "full")
|
||||
|
||||
# Estimate cost
|
||||
total_samples = len(skip_samples) + len(cheap_samples) + len(full_samples)
|
||||
estimated_cost = total_samples * 0.003 # ~$0.003 per review with Sonnet
|
||||
print(f"\n{'=' * 70}")
|
||||
print(f"VALIDATION COST ESTIMATE: ~${estimated_cost:.3f} for {total_samples} samples")
|
||||
print("=" * 70)
|
||||
|
||||
if args.dry_run:
|
||||
print("\n[DRY RUN] No LLM calls made. Use --validate to run validation.")
|
||||
return
|
||||
|
||||
if not args.validate:
|
||||
print("\nUse --validate to run LLM validation on these samples.")
|
||||
return
|
||||
|
||||
# Run validation
|
||||
from reviewiq_pipeline.config import Config
|
||||
|
||||
config = Config(
|
||||
database_url=database_url,
|
||||
llm_provider="anthropic",
|
||||
llm_model="claude-sonnet-4-5-20250929",
|
||||
anthropic_api_key=os.environ.get("ANTHROPIC_API_KEY",
|
||||
"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"),
|
||||
)
|
||||
|
||||
all_results = []
|
||||
total_cost = 0
|
||||
|
||||
# Validate each tier
|
||||
for tier, samples in [("skip", skip_samples), ("cheap", cheap_samples), ("full", full_samples)]:
|
||||
if samples:
|
||||
results = await validate_with_llm(samples, tier, config)
|
||||
all_results.extend(results)
|
||||
total_cost += sum(r.llm_cost or 0 for r in results)
|
||||
print_validation_results(results, tier)
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'=' * 70}")
|
||||
print("VALIDATION SUMMARY")
|
||||
print("=" * 70)
|
||||
|
||||
for tier in ["skip", "cheap", "full"]:
|
||||
tier_results = [r for r in all_results if r.routed_tier == tier]
|
||||
if tier_results:
|
||||
correct = sum(1 for r in tier_results if r.routing_correct)
|
||||
total = len(tier_results)
|
||||
print(f"{tier.upper()}: {correct}/{total} correct ({correct/total*100:.0f}%)")
|
||||
|
||||
overall_correct = sum(1 for r in all_results if r.routing_correct)
|
||||
overall_total = len(all_results)
|
||||
print(f"\nOVERALL: {overall_correct}/{overall_total} correct ({overall_correct/overall_total*100:.0f}%)")
|
||||
print(f"TOTAL COST: ${total_cost:.4f}")
|
||||
|
||||
# Recommendations
|
||||
print(f"\n{'=' * 70}")
|
||||
print("RECOMMENDATIONS")
|
||||
print("=" * 70)
|
||||
|
||||
skip_errors = [r for r in all_results if r.routed_tier == "skip" and not r.routing_correct]
|
||||
if skip_errors:
|
||||
print("\n⚠️ SKIP tier false negatives found:")
|
||||
for r in skip_errors:
|
||||
print(f" - \"{r.text[:50]}...\" → {r.llm_urt}")
|
||||
print(" Consider tightening SKIP criteria")
|
||||
else:
|
||||
print("\n✅ SKIP tier looks safe")
|
||||
|
||||
cheap_errors = [r for r in all_results if r.routed_tier == "cheap" and not r.routing_correct]
|
||||
if cheap_errors:
|
||||
print("\n⚠️ CHEAP tier may miss complexity:")
|
||||
for r in cheap_errors:
|
||||
print(f" - \"{r.text[:50]}...\" → {r.llm_span_count} spans")
|
||||
else:
|
||||
print("\n✅ CHEAP tier thresholds look good")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user