524 lines
18 KiB
Python
524 lines
18 KiB
Python
"""
|
|
LLM Classifier - Real classification using the OpenAI Chat Completions API.

Uses a JSON Schema (structured outputs, strict mode) to enforce the output format.
Validates primitives against the enabled set.
Stores the raw response for audit.
Supports multilingual reviews with language detection.
|
|
"""
|
|
|
|
import hashlib
|
|
import json
|
|
import os
|
|
import re
|
|
import time
|
|
from typing import Any
|
|
|
|
from openai import OpenAI
|
|
|
|
# Language detection - try langdetect, fall back to heuristics
try:
    from langdetect import DetectorFactory, LangDetectException
    from langdetect import detect as langdetect_detect
except ImportError:
    LANGDETECT_AVAILABLE = False
    LangDetectException = Exception  # Placeholder so the name always exists
else:
    # langdetect is non-deterministic unless seeded: the same text can map to
    # different languages on different calls. A fixed seed makes it repeatable.
    # (This sets langdetect's process-wide seed.)
    DetectorFactory.seed = 0
    LANGDETECT_AVAILABLE = True


def _whole_word_patterns(words: tuple[str, ...]) -> tuple[re.Pattern[str], ...]:
    """Compile one whole-word (\\b-delimited) regex per word; input text is pre-lowercased."""
    return tuple(re.compile(rf"\b{re.escape(w)}\b") for w in words)


# Keyword sets used by the heuristic fallback to separate Latin-script languages.
_ES_PATTERNS = _whole_word_patterns((
    "el", "la", "los", "las", "de", "que", "es", "en", "un", "una",
    "muy", "pero", "con", "está", "están", "para", "por", "como",
    "excelente", "recomendado", "servicio", "bueno", "malo", "bien",
    "todo", "nada", "más", "sin", "nunca", "siempre", "también",
))
_EN_PATTERNS = _whole_word_patterns((
    "the", "and", "is", "are", "was", "were", "this", "that",
    "with", "for", "but", "not", "very", "great", "good",
    "service", "place", "food", "staff", "friendly", "amazing",
    "would", "recommend", "will", "definitely", "really",
))
_DE_PATTERNS = _whole_word_patterns((
    "der", "die", "das", "und", "ist", "sind", "war", "sehr",
    "mit", "für", "aber", "nicht", "ein", "eine", "wir", "ich",
    "auch", "gut", "schlecht", "toll", "super",
))
_FR_PATTERNS = _whole_word_patterns((
    "le", "la", "les", "est", "sont", "très", "mais", "avec",
    "pour", "pas", "un", "une", "et", "nous", "vous", "bien",
    "bon", "mauvais", "excellent", "super", "c'est", "j'ai",
))


def _count_word_hits(text_lower: str, patterns: tuple[re.Pattern[str], ...]) -> int:
    """Return how many of the whole-word *patterns* occur in *text_lower*."""
    return sum(1 for pattern in patterns if pattern.search(text_lower))


def detect_language(text: str) -> tuple[str, float]:
    """
    Detect the language of a text.

    Returns (language_code, confidence).

    When langdetect is installed, any code it produces may be returned
    (region-qualified codes such as "zh-cn" are possible), with a
    length-based confidence estimate capped at 0.95. Otherwise, or when
    langdetect raises, a heuristic fallback returns one of zh, ja, ko, ru,
    ar, es, en, de, fr, or "unknown".

    Text shorter than 3 characters after stripping yields ("unknown", 0.0).
    """
    if not text or len(text.strip()) < 3:
        return "unknown", 0.0

    text = text.strip()

    # Try langdetect first (most accurate)
    if LANGDETECT_AVAILABLE:
        try:
            lang = langdetect_detect(text)
        except LangDetectException:
            pass
        else:
            # langdetect doesn't provide confidence directly, estimate based on text length
            confidence = min(0.95, 0.5 + len(text) / 200)
            return lang, confidence

    # Fallback: simple heuristic detection based on character ranges.
    # This is less accurate but works without dependencies.
    latin = sum(1 for c in text if '\u0041' <= c <= '\u024F')  # Latin extended
    cyrillic = sum(1 for c in text if '\u0400' <= c <= '\u04FF')  # Cyrillic
    cjk = sum(1 for c in text if '\u4E00' <= c <= '\u9FFF')  # CJK Unified
    japanese = sum(1 for c in text if '\u3040' <= c <= '\u30FF')  # Hiragana + Katakana
    korean = sum(1 for c in text if '\uAC00' <= c <= '\uD7AF')  # Hangul
    arabic = sum(1 for c in text if '\u0600' <= c <= '\u06FF')  # Arabic

    total = len(text)  # > 0: the stripped text has at least 3 characters

    # Kana are absent from Chinese but Japanese routinely mixes kanji with kana,
    # so check kana before CJK; otherwise kanji-heavy Japanese is reported as "zh".
    if japanese / total > 0.2:
        return "ja", 0.6  # Japanese
    if cjk / total > 0.3:
        return "zh", 0.6  # Chinese
    if korean / total > 0.3:
        return "ko", 0.6  # Korean
    if cyrillic / total > 0.3:
        return "ru", 0.5  # Russian (could be other Cyrillic)
    if arabic / total > 0.3:
        return "ar", 0.5  # Arabic

    if latin / total <= 0.5:
        return "unknown", 0.1

    # Latin script - distinguish languages by common words and diacritics.
    text_lower = text.lower()

    es_score = _count_word_hits(text_lower, _ES_PATTERNS)
    # Spanish-specific patterns (ñ, inverted punctuation, accents)
    if 'ñ' in text_lower or '¿' in text or '¡' in text:
        es_score += 3
    if any(c in text_lower for c in 'áéíóúü'):
        es_score += 1

    en_score = _count_word_hits(text_lower, _EN_PATTERNS)

    de_score = _count_word_hits(text_lower, _DE_PATTERNS)
    if any(c in text_lower for c in 'äöüß'):  # German umlauts
        de_score += 2

    fr_score = _count_word_hits(text_lower, _FR_PATTERNS)
    if any(c in text_lower for c in 'àâçèêëîïôùûÿœæ'):  # French accents
        fr_score += 2

    # Insertion order matters: on ties max() keeps the first key.
    scores = {'es': es_score, 'en': en_score, 'de': de_score, 'fr': fr_score}
    best_lang = max(scores, key=scores.get)
    best_score = scores[best_lang]

    if best_score >= 1:
        confidence = min(0.75, 0.3 + best_score * 0.08)
        return best_lang, confidence

    # Default to English for Latin script
    return "en", 0.3
|
|
|
|
# Module-level cache for the OpenAI client; populated on the first get_client() call.
_client = None


def get_client() -> OpenAI:
    """Return the cached OpenAI client, constructing it the first time.

    The API key is read from the OPENAI_API_KEY environment variable only
    when the client has not been created yet.

    Raises:
        RuntimeError: if no client exists yet and OPENAI_API_KEY is unset or empty.
    """
    global _client
    if _client is not None:
        return _client

    key = os.environ.get("OPENAI_API_KEY")
    if not key:
        raise RuntimeError(
            "OPENAI_API_KEY environment variable not set. "
            "Set it or use --dry-run / mock classifier."
        )

    _client = OpenAI(api_key=key)
    return _client
|
|
|
|
# Model used when classify_review() is not given one; override with $OPENAI_MODEL.
DEFAULT_MODEL = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")


# Meta primitives: accepted for every business on top of its enabled_primitives.
META_PRIMITIVES = frozenset({
    "HONESTY",
    "ETHICS",
    "PROMISES",
    "ACKNOWLEDGMENT",
    "RESPONSE_QUALITY",
    "RECOVERY",
    "RETURN_INTENT",
    "RECOMMEND",
    "RECOGNITION",
    "UNMAPPED",
})
|
|
|
|
# JSON Schema for structured output, assembled from named fragments.
# In OpenAI strict mode every property must appear in "required"; optional
# values are expressed as nullable types (e.g. ["integer", "null"]).
_SPAN_ITEM_SCHEMA = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "primitive": {"type": "string"},
        "valence": {"type": "string", "enum": ["positive", "negative", "mixed", "neutral"]},
        "intensity": {"type": "integer", "minimum": 1, "maximum": 5},
        "evidence": {"type": "string"},
        "start_char": {"type": ["integer", "null"]},
        "end_char": {"type": ["integer", "null"]},
        "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0},
        "details": {"type": "null"},
    },
    "required": [
        "primitive", "valence", "intensity", "evidence",
        "confidence", "start_char", "end_char", "details",
    ],
}

_UNMAPPED_ITEM_SCHEMA = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "label": {"type": "string"},
        "evidence": {"type": "string"},
        "confidence": {"type": "number", "minimum": 0.0, "maximum": 1.0},
    },
    "required": ["label", "evidence", "confidence"],
}

SPAN_SCHEMA = {
    "name": "review_classification",
    "strict": True,
    "schema": {
        "type": "object",
        "additionalProperties": False,
        "properties": {
            "spans": {"type": "array", "items": _SPAN_ITEM_SCHEMA},
            "unmapped": {"type": "array", "items": _UNMAPPED_ITEM_SCHEMA},
        },
        "required": ["spans", "unmapped"],
    },
}
|
|
|
|
# System prompt sent with every classification request. Keep it free of
# formatting artifacts: every character here is sent to the model verbatim.
SYSTEM_PROMPT = """You are a review classification system that extracts semantic spans and maps them to primitives.

## RULES (MUST FOLLOW)

1. Use ONLY primitives from the enabled_primitives list provided. Do NOT invent new primitives.

2. Meta primitives are always available: HONESTY, ETHICS, PROMISES, ACKNOWLEDGMENT, RESPONSE_QUALITY, RECOVERY, RETURN_INTENT, RECOMMEND, RECOGNITION, UNMAPPED

3. If content doesn't fit any enabled primitive, use UNMAPPED or put it in the unmapped array with a descriptive label.

4. Output MUST match the JSON schema exactly. No extra keys.

5. Evidence must be a SHORT EXACT QUOTE from the review text (in original language).

6. Extract 1-5 spans per review. Prefer fewer, larger spans over many small ones.

7. If unsure about classification, lower the confidence score.

## VALENCE
- positive: praise, satisfaction, recommendation
- negative: complaint, dissatisfaction, warning
- mixed: both positive and negative in same span
- neutral: factual observation, no sentiment

## INTENSITY (1-5)
- 1: mild ("okay", "fine")
- 2: moderate ("good", "bad")
- 3: strong ("great", "terrible")
- 4: very strong ("amazing", "awful")
- 5: extreme ("best ever", "worst nightmare")

## CONFIDENCE
- 0.9+: Very certain the primitive fits
- 0.7-0.9: Confident
- 0.5-0.7: Moderate confidence
- <0.5: Low confidence (consider UNMAPPED)

Output valid JSON only. No markdown, no explanations."""
|
|
|
|
|
|
def compute_review_hash(text: str, config_version: str) -> str:
    """Return a short cache key for *text* under a given *config_version*.

    The key is the first 16 hex digits of SHA-256 over the UTF-8 bytes of
    "<config_version>:<text>".
    """
    payload = f"{config_version}:{text}".encode("utf-8")
    full_digest = hashlib.sha256(payload).hexdigest()
    return full_digest[:16]
|
|
|
|
|
|
def _brief_items(section: Any) -> list:
    """Normalize a brief section (list, {"items": [...]}, or a single string) to a list."""
    if isinstance(section, dict):
        section = section.get("items", [])
    if isinstance(section, str):
        # A bare string is one item; slicing it would yield single characters.
        return [section]
    return list(section or [])


def build_user_payload(
    review_text: str,
    rating: int | None,
    config: dict[str, Any],
    language: str = "auto",
) -> dict[str, Any]:
    """Build the user message payload for the LLM.

    Args:
        review_text: Review text, passed through unchanged.
        rating: Optional star rating, passed through unchanged.
        config: Resolved config. Reads "enabled_primitives", "primitives",
            "brief", "weights", "business_id", "sector_code" and
            "config_version"; missing or None sections are treated as empty.
        language: Language code (or "auto") to report to the model.

    Returns:
        A JSON-serializable dict. Primitive lists and definitions are emitted
        in sorted order so the same inputs always serialize identically.
    """
    # Extract only what the model needs
    enabled = set(config.get("enabled_primitives") or [])
    enabled.update(META_PRIMITIVES)

    # Build primitive definitions (minimal). Iterate in sorted order: set
    # iteration order of strings changes between interpreter runs (hash
    # randomization), which would make the serialized prompt non-reproducible.
    primitives_dict = config.get("primitives") or {}
    primitive_defs = {}
    for prim in sorted(enabled):
        if prim in primitives_dict:
            info = primitives_dict[prim]
            if isinstance(info, dict):
                primitive_defs[prim] = info.get("def", info.get("name", prim))
            else:
                primitive_defs[prim] = str(info)
        elif prim in META_PRIMITIVES:
            primitive_defs[prim] = f"Meta primitive: {prim.replace('_', ' ').lower()}"

    # Extract brief signals (keep it short)
    brief = config.get("brief") or {}
    brief_summary = {}
    if brief.get("what_customers_judge"):
        brief_summary["key_judgment_areas"] = [
            item.get("aspect", item.get("area", str(item))) if isinstance(item, dict) else str(item)
            for item in _brief_items(brief["what_customers_judge"])[:5]
        ]
    if brief.get("critical_pain_points"):
        brief_summary["critical_pains"] = [
            item.get("pain", str(item)) if isinstance(item, dict) else str(item)
            for item in _brief_items(brief["critical_pain_points"])[:3]
        ]

    return {
        "business": {
            "name": config.get("business_id"),
            "sector": config.get("sector_code"),
            "config_version": config.get("config_version"),
        },
        "enabled_primitives": sorted(enabled),
        "primitive_definitions": primitive_defs,
        "weights": config.get("weights", {}),
        "sector_brief": brief_summary,
        "review": {
            "text": review_text,
            "rating": rating,
            "language": language,
        },
    }
|
|
|
|
|
|
def validate_response(
    response: dict[str, Any],
    enabled_primitives: set[str],
) -> tuple[dict[str, Any], list[str]]:
    """
    Validate LLM response and replace invalid primitives with UNMAPPED.

    Args:
        response: Parsed model output with "spans" and "unmapped" lists.
        enabled_primitives: Primitives allowed for this business (any iterable
            of strings is accepted); META_PRIMITIVES are always allowed.

    Returns (validated_response, warnings). The input spans are not modified:
    spans with an invalid primitive are copied before being relabeled.
    """
    warnings = []
    all_valid = set(enabled_primitives) | META_PRIMITIVES

    validated_spans = []
    for span in response.get("spans") or []:
        prim = span.get("primitive")
        if prim not in all_valid:
            warnings.append(f"Invalid primitive '{prim}' → UNMAPPED (original: {prim})")
            # Copy rather than mutate so the caller's parsed response stays intact.
            span = {**span, "primitive": "UNMAPPED"}
        validated_spans.append(span)

    return {
        "spans": validated_spans,
        "unmapped": response.get("unmapped") or [],
    }, warnings
|
|
|
|
|
|
def _retry_delay(error: Exception, attempt: int) -> float | None:
    """Return seconds to wait before retrying after *error*, or None if it is not retryable.

    Uses the HTTP status code attached to OpenAI SDK errors (``status_code``)
    when present, and falls back to matching the error text otherwise.
    """
    status = getattr(error, "status_code", None)
    if isinstance(status, int):
        if status == 429:
            return float(2 ** attempt)  # Exponential backoff for rate limits
        if status in (500, 502, 503, 504):
            return 1.0  # Transient server error
        return None

    message = str(error)
    if "rate_limit" in message.lower() or "429" in message:
        return float(2 ** attempt)
    if any(code in message for code in ("500", "502", "503", "504")):
        return 1.0
    return None


def _fallback_result(
    review_text: str,
    model: str,
    review_hash: str,
    error: str | None,
    detected_lang: str,
    lang_confidence: float,
) -> dict[str, Any]:
    """Build the single-UNMAPPED-span result returned when classification fails."""
    return {
        "spans": [{
            "primitive": "UNMAPPED",
            "valence": "neutral",
            "intensity": 1,
            "evidence": review_text[:100] if review_text else "",
            "start_char": 0,
            "end_char": min(100, len(review_text)) if review_text else 0,
            "confidence": 0.1,
            "details": {"error": error},
        }],
        "unmapped": [],
        "model": model,
        "raw_response": json.dumps({"error": error}),
        "review_hash": review_hash,
        "warnings": [f"Classification failed: {error}"],
        "tokens": {"prompt": 0, "completion": 0},
        "detected_language": detected_lang,
        "language_confidence": lang_confidence,
    }


def classify_review(
    review_text: str,
    rating: int | None,
    config: dict[str, Any],
    language: str = "auto",
    model: str | None = None,
    max_retries: int = 3,
) -> dict[str, Any]:
    """
    Classify a single review using the OpenAI Chat Completions API.

    Args:
        review_text: The review text to classify
        rating: Star rating (1-5) if available
        config: Resolved config from ConfigResolver
        language: Language code, or "auto" (default) to detect it
        model: Model to use (default: DEFAULT_MODEL, i.e. $OPENAI_MODEL or gpt-4o-mini)
        max_retries: Total number of request attempts when transient errors
            (rate limits, 5xx) occur; values below 1 still make one attempt

    Returns:
        {
            "spans": [...],
            "unmapped": [...],
            "model": str,
            "raw_response": str,
            "review_hash": str,
            "warnings": [...],
            "tokens": {"prompt": int, "completion": int},
            "detected_language": str,
            "language_confidence": float,
        }
        On failure, a single UNMAPPED fallback span is returned instead, with
        the error recorded in its "details", in "raw_response" and in "warnings".

    Raises:
        RuntimeError: from get_client() when OPENAI_API_KEY is not set.
    """
    model = model or DEFAULT_MODEL

    if language == "auto":
        detected_lang, lang_confidence = detect_language(review_text)
    else:
        detected_lang, lang_confidence = language, 1.0  # User-specified

    payload = build_user_payload(review_text, rating, config, detected_lang)
    user_content = json.dumps(payload, ensure_ascii=False, indent=None)

    # Compute hash for caching
    review_hash = compute_review_hash(review_text, config.get("config_version", "1.0"))
    enabled = set(config.get("enabled_primitives") or [])

    client = get_client()
    attempts = max(1, max_retries)  # Always make at least one request
    last_error: str | None = None

    for attempt in range(attempts):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_content},
                ],
                response_format={
                    "type": "json_schema",
                    "json_schema": SPAN_SCHEMA,
                },
                temperature=0.1,  # Low temperature for consistency
                max_tokens=2000,
            )

            message = response.choices[0].message
            raw_text = message.content
            if raw_text is None:
                # With structured outputs, a refusal arrives as content=None
                # plus a `refusal` message; retrying the same prompt won't help.
                refusal = getattr(message, "refusal", None)
                last_error = f"Model refused: {refusal}" if refusal else "Empty response content"
                break

            parsed = json.loads(raw_text)
            validated, warnings = validate_response(parsed, enabled)
            usage = response.usage

            return {
                "spans": validated["spans"],
                "unmapped": validated["unmapped"],
                "model": model,
                "raw_response": raw_text,
                "review_hash": review_hash,
                "warnings": warnings,
                "tokens": {
                    "prompt": usage.prompt_tokens if usage else 0,
                    "completion": usage.completion_tokens if usage else 0,
                },
                "detected_language": detected_lang,
                "language_confidence": lang_confidence,
            }

        except json.JSONDecodeError as e:
            last_error = f"JSON parse error: {e}"
            # Don't retry parse errors - return the fallback
            break

        except Exception as e:
            last_error = str(e)
            delay = _retry_delay(e, attempt)
            # Stop on non-retryable errors, and don't sleep after the last attempt.
            if delay is None or attempt == attempts - 1:
                break
            time.sleep(delay)

    return _fallback_result(
        review_text, model, review_hash, last_error, detected_lang, lang_confidence
    )
|
|
|
|
|
|
async def classify_review_async(
    review_text: str,
    rating: int | None,
    config: dict[str, Any],
    language: str = "auto",
    model: str | None = None,
    max_retries: int = 3,
) -> dict[str, Any]:
    """Run classify_review in a worker thread so the event loop is not blocked.

    All arguments are forwarded unchanged to classify_review; see it for details.
    """
    import asyncio

    # asyncio.to_thread runs the call on the loop's default executor (as the
    # old run_in_executor(None, ...) did) and also propagates contextvars,
    # without calling get_event_loop() from inside a coroutine.
    return await asyncio.to_thread(
        classify_review, review_text, rating, config, language, model, max_retries
    )
|
|
|
|
|
|
# Batch classification (for later optimization)
|
|
async def classify_batch(
|
|
reviews: list[dict[str, Any]],
|
|
config: dict[str, Any],
|
|
model: str | None = None,
|
|
max_concurrent: int = 5,
|
|
) -> list[dict[str, Any]]:
|
|
"""
|
|
Classify multiple reviews concurrently.
|
|
|
|
Args:
|
|
reviews: List of {"text": str, "rating": int, "language": str}
|
|
config: Resolved config
|
|
model: Model to use
|
|
max_concurrent: Max concurrent requests
|
|
|
|
Returns:
|
|
List of classification results
|
|
"""
|
|
import asyncio
|
|
|
|
semaphore = asyncio.Semaphore(max_concurrent)
|
|
|
|
async def classify_one(review: dict) -> dict:
|
|
async with semaphore:
|
|
return await classify_review_async(
|
|
review.get("text", ""),
|
|
review.get("rating"),
|
|
config,
|
|
review.get("language", "auto"),
|
|
model,
|
|
)
|
|
|
|
tasks = [classify_one(r) for r in reviews]
|
|
return await asyncio.gather(*tasks)
|