
import datetime
import glob
import json
import os
import random
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Root directory for persisted data; override with the SRC_DATA_DIR
# environment variable (defaults to ./data relative to the working directory).
DATA_DIR = os.environ.get("SRC_DATA_DIR", "./data")
# One JSON file per labelled page is stored here.
LABEL_DIR = os.path.join(DATA_DIR, "labels")
# Optional vocabulary file served by the /vocab endpoint.
VOCAB_PATH = os.path.join(DATA_DIR, "vocab", "vocab_ptBR_escola.json")
# Directory holding model_state.json. Configurable through SRC_MODEL_DIR in the
# same way as DATA_DIR; the default is the previously hard-coded "./models".
MODEL_DIR = os.environ.get("SRC_MODEL_DIR", "./models")

# Create the writable directories at import time so the endpoints can open
# files inside them without checking for the directory first.
os.makedirs(LABEL_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

app = FastAPI(title="SRC Backend", version="0.2.0")

# Origins allowed to call this API from a browser. SRC_CORS_ORIGINS takes a
# comma-separated list (e.g. "https://app.example.com,http://localhost:5173").
# When it is unset or empty, the previous allow-everything default ("*") applies.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("SRC_CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# NOTE(review): FastAPI's CORS documentation states that wildcard origins must
# not be combined with allow_credentials=True. The default is left unchanged
# for backward compatibility; set SRC_CORS_ORIGINS to an explicit list in any
# deployment that relies on cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class LineHypothesis(BaseModel):
    """Recognition hypothesis for one text line, as returned by ``/recognize``."""
    # Vertical extent of the line. The mock endpoints emit values between 0.10
    # and 0.57, which look like fractions of the page height — TODO confirm.
    y_top: float
    y_bottom: float
    # Candidate transcriptions for the line.
    suggestions: List[str]
    # Score attached to the hypothesis; /recognize clamps it to [0.15, 0.95].
    confidence: float

class RecognizeResponse(BaseModel):
    """Response body of ``/recognize``: page metadata plus one hypothesis per line."""
    # Page identifier echoed back from the request form field.
    page_id: str
    # ISO-8601 string taken from the server's UTC clock when the response is built.
    timestamp: str
    # Language tag echoed from the request; "pt-BR" when none is supplied.
    language: str = "pt-BR"
    lines: List[LineHypothesis]

@app.get("/health")
def health():
    """Liveness probe.

    Returns:
        dict: ``{"status": "ok", "time": <current UTC time as ISO-8601>}``.
    """
    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # "now" and drop tzinfo so the string keeps the same naive format
    # (no "+00:00" suffix) that existing clients already receive.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return {"status": "ok", "time": now_utc.isoformat()}

@app.get("/vocab")
def vocab():
    """Serve the vocabulary JSON stored at ``VOCAB_PATH``.

    Returns:
        The parsed content of the vocabulary file, or an empty pt-BR
        whitelist when the file does not exist.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    # EAFP: open directly instead of exists() followed by open(), so a file
    # removed between the two calls cannot surface as an unhandled error.
    try:
        with open(VOCAB_PATH, "r", encoding="utf-8") as vocab_file:
            return json.load(vocab_file)
    except (FileNotFoundError, NotADirectoryError):
        return {"language": "pt-BR", "whitelist": []}

@app.post("/detect_lines")
async def detect_lines(file: UploadFile = File(...)):
    """Mock line detector returning a fixed grid of 12 line boxes.

    The uploaded file is accepted but never inspected. Boxes start at
    ``y_top`` 0.10, are 0.03 tall and are spaced 0.04 apart; every value is
    rounded to three decimals.

    Returns:
        dict: ``{"lines": [{"y_top": float, "y_bottom": float}, ...]}``.
    """
    tops = (0.10 + 0.04 * k for k in range(12))
    return {
        "lines": [
            {"y_top": round(top, 3), "y_bottom": round(top + 0.03, 3)}
            for top in tops
        ]
    }

@app.post("/recognize", response_model=RecognizeResponse)
async def recognize(
    file: UploadFile = File(...),
    page_id: str = Form("page-0001"),
    language: str = Form("pt-BR")
):
    """Mock recogniser returning canned suggestions for 12 fixed line boxes.

    The uploaded file is accepted but never inspected.

    Args:
        file: uploaded page image (unused by the mock).
        page_id: identifier echoed back in the response.
        language: language tag echoed back in the response.

    Returns:
        RecognizeResponse: 12 line hypotheses. Line ``i`` spans
        ``0.10 + 0.04 * i`` to 0.03 below that, carries three canned
        suggestions and a synthetic confidence.
    """
    dummy_bank = [
        ["Introdução ao tema", "Introducao ao tema", "Introd. ao tema"],
        ["Desenvolvo o argumento principal", "Desenvolvo argumento", "Desenv. argumento"],
        ["Apresento exemplo pertinente", "Exemplo pertinente", "Apresento um exemplo"],
        ["Contraponho com outro ponto", "Contraponho outro ponto", "Aponto outra visão"],
        ["Concluo com proposta", "Proposta de intervenção", "Conclusão e proposta"],
        ["Retomo a tese", "Retomo a ideia inicial", "Retomo a posição"],
        ["Organizo coesão", "Mantenho coesão", "Emprego conectivos"],
        ["Cito referência", "Trago referência", "Uso referência"],
        ["Descrevo contexto", "Contextualizo o tema", "Apresento contexto"],
        ["Discuto contraponto", "Aponto contraponto", "Analiso contraponto"],
        ["Proponho intervenção", "Detalho intervenção", "Proposta concreta"],
        ["Finalizo com síntese", "Fecho com síntese", "Síntese final"]
    ]
    lines = []
    for i in range(12):
        y_top = 0.10 + 0.04 * i
        # Synthetic confidence: a baseline that rises with the line index,
        # plus uniform jitter of +/-0.07, rounded and then clamped to
        # [0.15, 0.95].
        noisy = round(0.55 + 0.04 * i + random.uniform(-0.07, 0.07), 3)
        lines.append(LineHypothesis(
            y_top=round(y_top, 3),
            y_bottom=round(y_top + 0.03, 3),
            suggestions=dummy_bank[i % len(dummy_bank)],
            confidence=max(0.15, min(0.95, noisy)),
        ))
    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo from
    # an aware UTC "now" keeps the timestamp in the same naive ISO format.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    return RecognizeResponse(
        page_id=page_id,
        timestamp=now_utc.isoformat(),
        language=language,
        lines=lines
    )

from pydantic import BaseModel
from typing import List, Optional

class LabelLine(BaseModel):
    """One labelled text line inside a ``/save_labels`` request."""
    # Vertical extent of the line being labelled.
    y_top: float
    y_bottom: float
    # Text stored as the label for this line.
    chosen_text: str
    # Optional score; the active-learning queue sorts by it and substitutes
    # 0.5 when it is missing.
    confidence: Optional[float] = None
    # Optional list of strings — presumably the top-k suggestions offered to
    # the annotator (TODO confirm); no endpoint in this file reads it back.
    topk: Optional[List[str]] = None

class SaveLabelsBody(BaseModel):
    """JSON request body accepted by ``/save_labels``."""
    # Page identifier; also used as the stem of the label file name.
    page_id: str
    labels: List[LabelLine]

@app.post("/save_labels")
async def save_labels(payload: SaveLabelsBody):
    """Persist the labels of one page as ``<LABEL_DIR>/<page_id>.json``.

    An existing file for the same ``page_id`` is overwritten.

    Args:
        payload: page identifier plus the labelled lines.

    Returns:
        dict: ``{"status": "ok", "path": <path of the file written>}``.

    Raises:
        HTTPException: 400 when ``page_id`` cannot safely be used as a file
            name.
    """
    page_id = payload.page_id
    # page_id comes straight from the client and is spliced into a file path,
    # so anything that is not a plain file name ("../x", absolute paths,
    # embedded separators, NUL) must be refused or it could write outside
    # LABEL_DIR. Empty and dot-leading ids are refused too: they would produce
    # dot-files, which the "*.json" glob used by the reading endpoints skips.
    if (
        not page_id
        or page_id.startswith(".")
        or "\x00" in page_id
        or os.path.basename(page_id) != page_id
    ):
        raise HTTPException(status_code=400, detail="invalid page_id")

    path = os.path.join(LABEL_DIR, f"{page_id}.json")
    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo from
    # an aware UTC "now" keeps "saved_at" in the same naive ISO format.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    record = {
        "page_id": page_id,
        "saved_at": now_utc.isoformat(),
        "labels": [label.dict() for label in payload.labels]
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(record, f, ensure_ascii=False, indent=2)
    return {"status": "ok", "path": path}

@app.get("/queue/active-learning")
def active_learning_queue(limit: int = 20):
    """List saved label lines, least confident first, for re-review.

    Reads every ``*.json`` file in ``LABEL_DIR`` and flattens the labels into
    one list sorted by ascending confidence.

    Args:
        limit: maximum number of items to return; negative values are treated
            as zero.

    Returns:
        dict: ``{"items": [...], "total": <number of lines across all files>}``.
        Each item carries ``page_id``, ``line_index``, ``confidence``,
        ``y_top``, ``y_bottom`` and ``chosen_text``.
    """
    queue = []
    # Sorted because glob order is filesystem-dependent; with a stable sort
    # below, ties in confidence then come back in a repeatable order.
    for label_path in sorted(glob.glob(os.path.join(LABEL_DIR, "*.json"))):
        with open(label_path, "r", encoding="utf-8") as f:
            rec = json.load(f)
        # Records without a "page_id" key fall back to the file stem instead
        # of failing the whole request with a KeyError.
        fallback_id = os.path.splitext(os.path.basename(label_path))[0]
        page_id = rec.get("page_id", fallback_id)
        for idx, label in enumerate(rec.get("labels", [])):
            conf = label.get("confidence")
            # Only a missing/None confidence gets the neutral 0.5 default.
            # The former `... or 0.5` also rewrote a genuine 0.0 to 0.5, which
            # hid the least confident lines in the middle of the queue.
            if conf is None:
                conf = 0.5
            queue.append({
                "page_id": page_id,
                "line_index": idx,
                "confidence": conf,
                "y_top": label.get("y_top"),
                "y_bottom": label.get("y_bottom"),
                "chosen_text": label.get("chosen_text", ""),
            })
    queue.sort(key=lambda item: item["confidence"])
    # A negative limit would otherwise slice from the end (queue[:-1]).
    return {"items": queue[:max(limit, 0)], "total": len(queue)}

@app.post("/train/incremental")
def train_incremental():
    """Mock incremental training step.

    Counts the label files and labels found in ``LABEL_DIR`` and writes a
    summary to ``<MODEL_DIR>/model_state.json``; no model is actually trained.

    Returns:
        dict: ``{"status": "ok", "state": <the summary that was written>}``.
    """
    # One entry per label file: the number of labels it holds.
    labels_per_page = []
    for label_path in glob.glob(os.path.join(LABEL_DIR, "*.json")):
        with open(label_path, "r", encoding="utf-8") as f:
            labels_per_page.append(len(json.load(f).get("labels", [])))

    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo from
    # an aware UTC "now" keeps "updated_at" in the same naive ISO format.
    now_utc = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
    state = {
        "updated_at": now_utc.isoformat(),
        "pages": len(labels_per_page),
        "labels": sum(labels_per_page),
        "note": "Mock incremental training complete. Replace with real HTR fine-tuning.",
        "adapters": {"ptBR_escola": True}
    }
    with open(os.path.join(MODEL_DIR, "model_state.json"), "w", encoding="utf-8") as f:
        json.dump(state, f, ensure_ascii=False, indent=2)
    return {"status": "ok", "state": state}

@app.get("/metrics")
def metrics():
    """Report label counts and the last written model state.

    Returns:
        dict: ``{"pages": <label files>, "labels": <labels across files>,
        "model": <content of model_state.json, or {} when it is absent>}``.
    """
    # One entry per label file: the number of labels it holds.
    labels_per_page = []
    for label_path in glob.glob(os.path.join(LABEL_DIR, "*.json")):
        with open(label_path, "r", encoding="utf-8") as f:
            labels_per_page.append(len(json.load(f).get("labels", [])))

    ms_path = os.path.join(MODEL_DIR, "model_state.json")
    # EAFP: open directly instead of exists() followed by open(), so a state
    # file removed between the two calls cannot surface as an unhandled error.
    try:
        with open(ms_path, "r", encoding="utf-8") as f:
            model_state = json.load(f)
    except (FileNotFoundError, NotADirectoryError):
        model_state = {}
    return {
        "pages": len(labels_per_page),
        "labels": sum(labels_per_page),
        "model": model_state,
    }

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on every network
    # interface, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
