227 lines
9.3 KiB
Python
227 lines
9.3 KiB
Python
"""
Config Resolver - standalone version for scripts.

Resolves the sector-level (L1) primitive config for a business, applies any
matching category-level (L2) delta, and attaches the sector brief used for
classification.
"""
|
|
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
logger = logging.getLogger(__name__)

# On-disk layout, rooted at <grandparent of this file>/data:
#   primitive_configs/l1/<sector_code lowercased>_config.json  - sector-level (L1) configs
#   primitive_configs/l2/*_config.json                         - category-level (L2) deltas
#   sector_briefs/<sector_code lowercased>_brief.json          - per-sector briefs
DATA_DIR = Path(__file__).parent.parent / "data"
CONFIGS_DIR = DATA_DIR.joinpath("primitive_configs", "l1")
L2_CONFIGS_DIR = DATA_DIR.joinpath("primitive_configs", "l2")
BRIEFS_DIR = DATA_DIR.joinpath("sector_briefs")
|
|
|
|
# Meta primitives: enabled for every business regardless of what the
# sector (L1) and category (L2) configs list.
META_PRIMITIVES = frozenset({
    "HONESTY",
    "ETHICS",
    "PROMISES",
    "ACKNOWLEDGMENT",
    "RESPONSE_QUALITY",
    "RECOVERY",
    "RETURN_INTENT",
    "RECOMMEND",
    "RECOGNITION",
    "UNMAPPED",
})
|
|
|
|
# Core primitive catalogue: KEY -> {"domain": <one-letter code>, "name": <label>,
# "def": <short definition>}. Insertion order is significant: it is the order
# used when a sector has no L1 config and every core primitive is enabled.
# NOTE(review): the domain letters are not expanded anywhere in this file; from
# the groupings they look like O=offering/outcome, P=people, J=journey,
# E=environment, V=value - confirm before relying on that reading.
CORE_PRIMITIVES = {
    key: {"domain": domain, "name": label, "def": definition}
    for key, domain, label, definition in (
        ("TASTE", "O", "Taste/Flavor", "Sensory quality of food/beverage"),
        ("CRAFT", "O", "Craftsmanship", "Skill of execution/preparation"),
        ("FRESHNESS", "O", "Freshness", "Newness, not stale or old"),
        ("TEMPERATURE", "O", "Temperature", "Hot/cold as expected"),
        ("EFFECTIVENESS", "O", "Effectiveness", "Achieves intended purpose"),
        ("ACCURACY", "O", "Accuracy", "Correct, as ordered/specified"),
        ("CONDITION", "O", "Condition", "Physical state, wear, damage"),
        ("CONSISTENCY", "O", "Consistency", "Same quality each time"),
        ("MANNER", "P", "Manner/Attitude", "Friendliness, respect, warmth"),
        ("COMPETENCE", "P", "Competence", "Knowledge and skill of staff"),
        ("ATTENTIVENESS", "P", "Attentiveness", "Being present, responsive"),
        ("COMMUNICATION", "P", "Communication", "Clarity, listening, updates"),
        ("SPEED", "J", "Speed/Wait", "Time to service, waiting"),
        ("FRICTION", "J", "Friction", "Obstacles, hassles, complexity"),
        ("RELIABILITY", "J", "Reliability", "Dependable, keeps promises"),
        ("AVAILABILITY", "J", "Availability", "Open when needed, bookable"),
        ("CLEANLINESS", "E", "Cleanliness", "Hygiene, tidiness"),
        ("COMFORT", "E", "Comfort", "Physical ease, seating"),
        ("SAFETY", "E", "Safety", "Free from harm/danger"),
        ("AMBIANCE", "E", "Ambiance", "Atmosphere, mood, vibe"),
        ("ACCESSIBILITY", "E", "Accessibility", "Easy to reach, navigate"),
        ("DIGITAL_UX", "E", "Digital Experience", "Website, app, online"),
        ("PRICE_LEVEL", "V", "Price Level", "Absolute cost (cheap/expensive)"),
        ("PRICE_FAIRNESS", "V", "Price Fairness", "Reasonable for what you get"),
        ("PRICE_TRANSPARENCY", "V", "Price Transparency", "No hidden fees, clear pricing"),
        ("VALUE_FOR_MONEY", "V", "Value for Money", "Worth what you paid"),
    )
}
|
|
|
|
|
|
class ConfigResolver:
|
|
"""Resolves classification config for a business."""
|
|
|
|
def __init__(self):
|
|
self._l1_cache: dict[str, dict] = {}
|
|
self._l2_cache: dict[str, dict] = {}
|
|
self._brief_cache: dict[str, dict] = {}
|
|
|
|
def _load_l2_configs(self) -> list[dict[str, Any]]:
|
|
"""Load all L2 config files."""
|
|
if not L2_CONFIGS_DIR.exists():
|
|
return []
|
|
|
|
configs = []
|
|
for config_path in L2_CONFIGS_DIR.glob("*_config.json"):
|
|
try:
|
|
with open(config_path) as f:
|
|
config = json.load(f)
|
|
configs.append(config)
|
|
except Exception as e:
|
|
logger.warning(f"Failed to load L2 config {config_path}: {e}")
|
|
return configs
|
|
|
|
def _find_matching_l2(self, gbp_path: str) -> dict[str, Any] | None:
|
|
"""Find L2 config that matches the GBP path (most specific wins)."""
|
|
l2_configs = self._load_l2_configs()
|
|
|
|
# Find all matching configs (path starts with L2 gbp_path)
|
|
matches = []
|
|
for config in l2_configs:
|
|
l2_path = config.get("gbp_path", "")
|
|
if gbp_path.startswith(l2_path) or gbp_path == l2_path:
|
|
matches.append((len(l2_path), config))
|
|
|
|
if not matches:
|
|
return None
|
|
|
|
# Return most specific match (longest path)
|
|
matches.sort(key=lambda x: x[0], reverse=True)
|
|
return matches[0][1]
|
|
|
|
def _apply_l2_delta(self, l1_config: dict, l2_config: dict) -> dict:
|
|
"""Apply L2 delta to L1 config."""
|
|
result = l1_config.copy()
|
|
delta = l2_config.get("delta", {})
|
|
|
|
# Enable additional primitives
|
|
if "enable" in delta:
|
|
enabled = set(result.get("enabled", []))
|
|
enabled.update(delta["enable"])
|
|
result["enabled"] = list(enabled)
|
|
|
|
# Merge weights
|
|
if "weights" in delta:
|
|
weights = dict(result.get("weights", {}))
|
|
weights.update(delta["weights"])
|
|
result["weights"] = weights
|
|
|
|
# Update config version to indicate L2
|
|
result["config_version"] = l2_config.get("config_version", result.get("config_version", "1.0"))
|
|
result["l2_applied"] = l2_config.get("gbp_path")
|
|
|
|
return result
|
|
|
|
def _load_l1_config(self, sector_code: str) -> dict[str, Any] | None:
|
|
if sector_code in self._l1_cache:
|
|
return self._l1_cache[sector_code]
|
|
|
|
config_path = CONFIGS_DIR / f"{sector_code.lower()}_config.json"
|
|
if not config_path.exists():
|
|
return None
|
|
|
|
with open(config_path) as f:
|
|
config = json.load(f)
|
|
|
|
self._l1_cache[sector_code] = config
|
|
return config
|
|
|
|
def _load_sector_brief(self, sector_code: str) -> dict[str, Any] | None:
|
|
if sector_code in self._brief_cache:
|
|
return self._brief_cache[sector_code]
|
|
|
|
brief_path = BRIEFS_DIR / f"{sector_code.lower()}_brief.json"
|
|
if not brief_path.exists():
|
|
return None
|
|
|
|
with open(brief_path) as f:
|
|
brief = json.load(f)
|
|
|
|
self._brief_cache[sector_code] = brief
|
|
return brief
|
|
|
|
async def get_business_mapping(self, pool, business_id: str) -> dict[str, Any] | None:
|
|
query = """
|
|
SELECT business_id, gbp_path::text, sector_code
|
|
FROM pipeline.business_taxonomy_map
|
|
WHERE business_id = $1
|
|
"""
|
|
row = await pool.fetchrow(query, business_id)
|
|
return dict(row) if row else None
|
|
|
|
def resolve_enabled_set(self, l1_config: dict) -> set[str]:
|
|
enabled = set(l1_config.get("enabled", []))
|
|
enabled.update(META_PRIMITIVES)
|
|
return enabled
|
|
|
|
def build_primitives_for_prompt(self, enabled: set[str], weights: dict[str, float]) -> dict[str, dict]:
|
|
result = {}
|
|
for prim in enabled:
|
|
if prim in CORE_PRIMITIVES:
|
|
entry = CORE_PRIMITIVES[prim].copy()
|
|
if prim in weights:
|
|
entry["weight"] = weights[prim]
|
|
result[prim] = entry
|
|
elif prim in META_PRIMITIVES:
|
|
result[prim] = {"domain": "M", "name": prim.replace("_", " ").title(), "meta": True}
|
|
return result
|
|
|
|
def extract_brief_signals(self, brief: dict) -> dict[str, Any]:
|
|
if not brief:
|
|
return {}
|
|
return {
|
|
"sector": brief.get("sector_code"),
|
|
"what_customers_judge": brief.get("what_customers_judge"),
|
|
"critical_pain_points": brief.get("critical_pain_points"),
|
|
"industry_terminology": brief.get("industry_terminology"),
|
|
}
|
|
|
|
async def resolve(self, business_id: str, pool, mode: str | None = None) -> dict[str, Any] | None:
|
|
mapping = await self.get_business_mapping(pool, business_id)
|
|
if not mapping:
|
|
return None
|
|
|
|
sector_code = mapping["sector_code"]
|
|
gbp_path = mapping["gbp_path"]
|
|
|
|
# Load L1 config (sector-level)
|
|
l1_config = self._load_l1_config(sector_code)
|
|
if not l1_config:
|
|
l1_config = {"enabled": list(CORE_PRIMITIVES.keys()), "weights": {}}
|
|
|
|
# Check for L2 config (category-level delta)
|
|
l2_config = self._find_matching_l2(gbp_path)
|
|
if l2_config:
|
|
logger.info(f"Applying L2 delta for {gbp_path}: {l2_config.get('gbp_path')}")
|
|
l1_config = self._apply_l2_delta(l1_config, l2_config)
|
|
|
|
brief = self._load_sector_brief(sector_code)
|
|
|
|
enabled = self.resolve_enabled_set(l1_config)
|
|
weights = dict(l1_config.get("weights", {}))
|
|
primitives = self.build_primitives_for_prompt(enabled, weights)
|
|
brief_signals = self.extract_brief_signals(brief)
|
|
|
|
return {
|
|
"business_id": business_id,
|
|
"gbp_path": gbp_path,
|
|
"sector_code": sector_code,
|
|
"config_version": l1_config.get("config_version", "1.0"),
|
|
"l2_applied": l1_config.get("l2_applied"),
|
|
"modes": [mode] if mode else ["in_person"],
|
|
"default_mode": mode or "in_person",
|
|
"enabled_primitives": sorted(enabled),
|
|
"disabled_primitives": sorted(l1_config.get("disabled", [])),
|
|
"weights": weights,
|
|
"brief": brief_signals,
|
|
"primitives": primitives,
|
|
}
|