"""Pipeline configuration — tunable per category and stage.""" from __future__ import annotations from dataclasses import dataclass, field from typing import Literal PromptType = Literal["buying", "comparison", "specific_need", "informational", "brand_direct"] @dataclass class PipelineConfig: """Default config — override via command-line args or per-category YAML.""" # Stage 1 — Persona Generator num_personas: int = 7 persona_model: str = "claude-sonnet-4-6" # Stage 2 — Prompt Brainstormer prompts_per_persona: int = 30 brainstormer_model: str = "claude-sonnet-4-6" type_distribution: dict[PromptType, float] = field( default_factory=lambda: { "buying": 0.30, "comparison": 0.25, "specific_need": 0.20, "informational": 0.15, "brand_direct": 0.10, } ) # Stage 3 — Reality Checker google_trends_min_volume: int = 1 # PL queries per month, minimum signal reddit_min_organic_mentions: int = 3 fallback_to_quora_if_no_signal: bool = True # Stage 4 — Validation Agents flagged_by_n_critics_to_remove: int = 2 # Remove if 2+ agents flag it critic_models: dict[str, str] = field( default_factory=lambda: { "real_buyer_critique": "claude-sonnet-4-6", "methodology_critic": "claude-sonnet-4-6", "vendor_exploit_hunter": "claude-sonnet-4-6", } ) # Stage 5 — Pilot Test Runner pilot_sample_size: int = 10 pilot_models: list[str] = field( default_factory=lambda: [ "gpt-4o-search", "perplexity-sonar-pro", "gemini-pro", ] ) repetitions_per_prompt: int = 1 # In pilot test only, production uses 2+ # Final pool size after all filtering final_pool_size: int = 100 # Output paths data_dir: str = "data" # Public stage outputs prompts_dir: str = "../../prompts" # Closed final prompts (gitignored) CONFIG = PipelineConfig() # Type distribution as integer counts for final pool def get_target_counts(config: PipelineConfig = CONFIG) -> dict[PromptType, int]: """Return integer counts per prompt type for final pool of `final_pool_size`.""" counts = { ptype: int(round(config.final_pool_size * pct)) for ptype, pct in config.type_distribution.items() } # Adjust rounding to ensure sum == final_pool_size total = sum(counts.values()) if total != config.final_pool_size: # Adjust the largest category to absorb difference largest_type = max(counts, key=lambda k: counts[k]) counts[largest_type] += config.final_pool_size - total return counts if __name__ == "__main__": print(f"Final pool size: {CONFIG.final_pool_size}") print(f"Target counts per type: {get_target_counts()}")