Spaces:
Sleeping
Sleeping
Cuong2004 commited on
Commit ·
25d12dc
0
Parent(s):
Deploy Worker from GitHub Actions
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- Dockerfile +46 -0
- README.md +11 -0
- agents/geometry_agent.py +120 -0
- agents/knowledge_agent.py +135 -0
- agents/ocr_agent.py +185 -0
- agents/orchestrator.py +249 -0
- agents/parser_agent.py +106 -0
- agents/renderer_agent.py +249 -0
- agents/solver_agent.py +107 -0
- agents/torch_ultralytics_compat.py +33 -0
- app/dependencies.py +62 -0
- app/errors.py +59 -0
- app/llm_client.py +104 -0
- app/logging_setup.py +112 -0
- app/logutil.py +67 -0
- app/main.py +125 -0
- app/models/schemas.py +66 -0
- app/routers/__init__.py +1 -0
- app/routers/auth.py +23 -0
- app/routers/sessions.py +165 -0
- app/routers/solve.py +204 -0
- app/runtime_env.py +12 -0
- app/session_cache.py +48 -0
- app/supabase_client.py +37 -0
- app/url_utils.py +23 -0
- app/websocket_manager.py +40 -0
- clean_ports.sh +22 -0
- migrations/v4_migration.sql +95 -0
- requirements.txt +34 -0
- run_api_test.sh +65 -0
- run_full_api_test.sh +60 -0
- scripts/backend_test_suite.py +97 -0
- scripts/generate_report.py +73 -0
- scripts/prepare_api_test.py +31 -0
- scripts/prewarm_models.py +42 -0
- scripts/test_engine_direct.py +36 -0
- setup.sh +43 -0
- solver/dsl_parser.py +210 -0
- solver/engine.py +426 -0
- solver/models.py +13 -0
- tests/test_3d_solver.py +85 -0
- tests/test_advanced_geometry.py +102 -0
- tests/test_api_full_suite.py +237 -0
- tests/test_api_metadata_real.py +56 -0
- tests/test_api_real_e2e.py +75 -0
- tests/test_direct_task.py +70 -0
- tests/test_full_pipeline.py +237 -0
- tests/test_openrouter.py +92 -0
- tests/test_real_llm.py +30 -0
- tests/test_solver.py +44 -0
Dockerfile
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Same runtime as API; runs health endpoint + Celery worker (see worker_health.py)
|
| 2 |
+
FROM python:3.11-slim-bookworm
|
| 3 |
+
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
+
PIP_NO_CACHE_DIR=1 \
|
| 7 |
+
PIP_ROOT_USER_ACTION=ignore \
|
| 8 |
+
NO_ALBUMENTATIONS_UPDATE=1 \
|
| 9 |
+
OMP_NUM_THREADS=1 \
|
| 10 |
+
MKL_NUM_THREADS=1 \
|
| 11 |
+
OPENBLAS_NUM_THREADS=1
|
| 12 |
+
|
| 13 |
+
WORKDIR /app
|
| 14 |
+
ENV PYTHONPATH=/app
|
| 15 |
+
|
| 16 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 17 |
+
ffmpeg \
|
| 18 |
+
pkg-config \
|
| 19 |
+
cmake \
|
| 20 |
+
libcairo2 \
|
| 21 |
+
libcairo2-dev \
|
| 22 |
+
libpango-1.0-0 \
|
| 23 |
+
libpango1.0-dev \
|
| 24 |
+
libpangocairo-1.0-0 \
|
| 25 |
+
libgdk-pixbuf-2.0-0 \
|
| 26 |
+
libffi-dev \
|
| 27 |
+
python3-dev \
|
| 28 |
+
texlive-latex-base \
|
| 29 |
+
texlive-fonts-recommended \
|
| 30 |
+
texlive-latex-extra \
|
| 31 |
+
build-essential \
|
| 32 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 33 |
+
|
| 34 |
+
COPY requirements.txt .
|
| 35 |
+
RUN pip install --upgrade pip setuptools wheel \
|
| 36 |
+
&& pip install -r requirements.txt
|
| 37 |
+
|
| 38 |
+
COPY . .
|
| 39 |
+
|
| 40 |
+
RUN python scripts/prewarm_models.py
|
| 41 |
+
|
| 42 |
+
ENV PORT=7860
|
| 43 |
+
EXPOSE 7860
|
| 44 |
+
|
| 45 |
+
ENTRYPOINT []
|
| 46 |
+
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
|
README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Math Solver Worker
|
| 3 |
+
emoji: 👷
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# Math Solver Worker
|
| 11 |
+
This space hosts the Celery background worker for video rendering.
|
agents/geometry_agent.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 12 |
+
from app.llm_client import get_llm_client
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GeometryAgent:
|
| 16 |
+
def __init__(self):
|
| 17 |
+
self.llm = get_llm_client()
|
| 18 |
+
|
| 19 |
+
async def generate_dsl(self, semantic_data: Dict[str, Any], previous_dsl: str = None) -> str:
|
| 20 |
+
logger.info("==[GeometryAgent] Generating DSL from semantic data==")
|
| 21 |
+
if previous_dsl:
|
| 22 |
+
logger.info(f"[GeometryAgent] Using previous DSL context (len={len(previous_dsl)})")
|
| 23 |
+
|
| 24 |
+
system_prompt = """
|
| 25 |
+
You are a Geometry DSL Generator. Convert semantic geometry data into a precise Geometry DSL program.
|
| 26 |
+
|
| 27 |
+
=== MULTI-TURN CONTEXT ===
|
| 28 |
+
If a PREVIOUS DSL is provided, your job is to UPDATE or EXTEND it.
|
| 29 |
+
1. DO NOT remove existing points unless the user explicitly asks to "redefine" or "move" them.
|
| 30 |
+
2. Ensure new segments/points connect correctly to existing ones.
|
| 31 |
+
3. Your output should be the ENTIRE updated DSL, not just the changes.
|
| 32 |
+
|
| 33 |
+
=== DSL COMMANDS ===
|
| 34 |
+
POINT(A) — declare a point
|
| 35 |
+
POINT(A, x, y, z) — declare a point with explicit coordinates
|
| 36 |
+
LENGTH(AB, 5) — distance between A and B is 5 (2D/3D)
|
| 37 |
+
ANGLE(A, 90) — interior angle at vertex A is 90° (2D/3D)
|
| 38 |
+
PARALLEL(AB, CD) — segment AB is parallel to CD (2D/3D)
|
| 39 |
+
PERPENDICULAR(AB, CD) — segment AB is perpendicular to CD (2D/3D)
|
| 40 |
+
MIDPOINT(M, AB) — M is the midpoint of segment AB
|
| 41 |
+
SECTION(E, A, C, k) — E satisfies vector AE = k * vector AC (k is decimal)
|
| 42 |
+
LINE(A, B) — infinite line passing through A and B
|
| 43 |
+
RAY(A, B) — ray starting at A and passing through B
|
| 44 |
+
CIRCLE(O, 5) — circle with center O and radius 5 (2D)
|
| 45 |
+
SPHERE(O, 5) — sphere with center O and radius 5 (3D)
|
| 46 |
+
SEGMENT(M, N) — auxiliary segment MN to be drawn
|
| 47 |
+
POLYGON_ORDER(A, B, C, D) — the order in which vertices form the polygon boundary
|
| 48 |
+
TRIANGLE(ABC) — equilateral/arbitrary triangle
|
| 49 |
+
PYRAMID(S_ABCD) — pyramid with apex S and base ABCD
|
| 50 |
+
PRISM(ABC_DEF) — triangular prism
|
| 51 |
+
|
| 52 |
+
=== RULES ===
|
| 53 |
+
1. 3D Coordinates: Use POINT(A, x, y, z) if specific coordinates are given in the problem.
|
| 54 |
+
2. Space Geometry: For pyramids/prisms, use the specialized commands.
|
| 55 |
+
3. Primary Vertices: Always declare the main vertices of the shape (e.g., A, B, C, D) using POINT(X).
|
| 56 |
+
4. POLYGON_ORDER: Always emit POLYGON_ORDER(...) for the main shape using ONLY these primary vertices.
|
| 57 |
+
5. All Points: EVERY point mentioned (A, B, C, H, M, etc.) MUST be declared with POINT(Name) first.
|
| 58 |
+
6. Altitudes/Perpendiculars: For an altitude AH to BC, use POINT(H) + PERPENDICULAR(AH, BC).
|
| 59 |
+
7. Format: Output ONLY DSL lines — NO explanation, NO markdown, NO code blocks.
|
| 60 |
+
|
| 61 |
+
=== SHAPE EXAMPLES ===
|
| 62 |
+
|
| 63 |
+
--- Case: Square Pyramid S.ABCD with side 10, height 15 ---
|
| 64 |
+
PYRAMID(S_ABCD)
|
| 65 |
+
POINT(A, 0, 0, 0)
|
| 66 |
+
POINT(B, 10, 0, 0)
|
| 67 |
+
POINT(C, 10, 10, 0)
|
| 68 |
+
POINT(D, 0, 10, 0)
|
| 69 |
+
POINT(S)
|
| 70 |
+
POINT(O)
|
| 71 |
+
SECTION(O, A, C, 0.5)
|
| 72 |
+
LENGTH(SO, 15)
|
| 73 |
+
PERPENDICULAR(SO, AC)
|
| 74 |
+
PERPENDICULAR(SO, AB)
|
| 75 |
+
POLYGON_ORDER(A, B, C, D)
|
| 76 |
+
|
| 77 |
+
--- Case: Right Triangle ABC at A, AB=3, AC=4, altitude AH ---
|
| 78 |
+
POLYGON_ORDER(A, B, C)
|
| 79 |
+
POINT(A)
|
| 80 |
+
POINT(B)
|
| 81 |
+
POINT(C)
|
| 82 |
+
POINT(H)
|
| 83 |
+
LENGTH(AB, 3)
|
| 84 |
+
LENGTH(AC, 4)
|
| 85 |
+
ANGLE(A, 90)
|
| 86 |
+
PERPENDICULAR(AH, BC)
|
| 87 |
+
SEGMENT(A, H)
|
| 88 |
+
|
| 89 |
+
--- Case: Rectangle ABCD with AB=5, AD=10 ---
|
| 90 |
+
POLYGON_ORDER(A, B, C, D)
|
| 91 |
+
POINT(A)
|
| 92 |
+
POINT(B)
|
| 93 |
+
POINT(C)
|
| 94 |
+
POINT(D)
|
| 95 |
+
LENGTH(AB, 5)
|
| 96 |
+
LENGTH(AD, 10)
|
| 97 |
+
PERPENDICULAR(AB, AD)
|
| 98 |
+
PARALLEL(AB, CD)
|
| 99 |
+
PARALLEL(AD, BC)
|
| 100 |
+
|
| 101 |
+
[Circle with center O radius 7]
|
| 102 |
+
POINT(O)
|
| 103 |
+
CIRCLE(O, 7)
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
user_content = f"Semantic Data: {json.dumps(semantic_data, ensure_ascii=False)}"
|
| 107 |
+
if previous_dsl:
|
| 108 |
+
user_content = f"PREVIOUS DSL:\n{previous_dsl}\n\nUPDATE WITH NEW DATA: {json.dumps(semantic_data, ensure_ascii=False)}"
|
| 109 |
+
|
| 110 |
+
logger.debug("[GeometryAgent] Calling LLM (Multi-Layer)...")
|
| 111 |
+
content = await self.llm.chat_completions_create(
|
| 112 |
+
messages=[
|
| 113 |
+
{"role": "system", "content": system_prompt},
|
| 114 |
+
{"role": "user", "content": user_content}
|
| 115 |
+
]
|
| 116 |
+
)
|
| 117 |
+
dsl = content.strip() if content else ""
|
| 118 |
+
logger.info(f"[GeometryAgent] DSL generated ({len(dsl.splitlines())} lines).")
|
| 119 |
+
logger.debug(f"[GeometryAgent] DSL output:\n{dsl}")
|
| 120 |
+
return dsl
|
agents/knowledge_agent.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Dict, Any
|
| 3 |
+
|
| 4 |
+
logger = logging.getLogger(__name__)
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# ─── Shape rule registry ────────────────────────────────────────────────────
|
| 8 |
+
# Each entry: keyword list → augmentation function
|
| 9 |
+
# Augmentation receives (values: dict, text: str) and returns updated values dict.
|
| 10 |
+
|
| 11 |
+
class KnowledgeAgent:
|
| 12 |
+
"""Knowledge Agent: Stores geometric theorems and common patterns to augment Parser output."""
|
| 13 |
+
|
| 14 |
+
def augment_semantic_data(self, semantic_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 15 |
+
logger.info("==[KnowledgeAgent] Augmenting semantic data==")
|
| 16 |
+
text = str(semantic_data.get("input_text", "")).lower()
|
| 17 |
+
logger.debug(f"[KnowledgeAgent] Input text for matching: '{text[:200]}'")
|
| 18 |
+
|
| 19 |
+
shape_type = self._detect_shape(text, semantic_data.get("type", ""))
|
| 20 |
+
if shape_type:
|
| 21 |
+
semantic_data["type"] = shape_type
|
| 22 |
+
values = semantic_data.get("values", {})
|
| 23 |
+
values = self._augment_values(shape_type, values, text)
|
| 24 |
+
semantic_data["values"] = values
|
| 25 |
+
else:
|
| 26 |
+
logger.info("[KnowledgeAgent] No special rule matched. Returning data unchanged.")
|
| 27 |
+
|
| 28 |
+
logger.debug(f"[KnowledgeAgent] Output semantic data: {semantic_data}")
|
| 29 |
+
return semantic_data
|
| 30 |
+
|
| 31 |
+
# ─── Shape detection ────────────────────────────────────────────────────
|
| 32 |
+
def _detect_shape(self, text: str, llm_type: str) -> str | None:
|
| 33 |
+
"""Detect shape from text keywords. LLM type provides a hint."""
|
| 34 |
+
checks = [
|
| 35 |
+
(["hình vuông", "square"], "square"),
|
| 36 |
+
(["hình chữ nhật", "rectangle"], "rectangle"),
|
| 37 |
+
(["hình thoi", "rhombus"], "rhombus"),
|
| 38 |
+
(["hình bình hành", "parallelogram"], "parallelogram"),
|
| 39 |
+
(["hình thang vuông"], "right_trapezoid"),
|
| 40 |
+
(["hình thang", "trapezoid", "trapezium"], "trapezoid"),
|
| 41 |
+
(["tam giác vuông", "right triangle"], "right_triangle"),
|
| 42 |
+
(["tam giác đều", "equilateral triangle", "equilateral"], "equilateral_triangle"),
|
| 43 |
+
(["tam giác cân", "isosceles"], "isosceles_triangle"),
|
| 44 |
+
(["tam giác", "triangle"], "triangle"),
|
| 45 |
+
(["đường tròn", "circle"], "circle"),
|
| 46 |
+
]
|
| 47 |
+
for keywords, shape in checks:
|
| 48 |
+
if any(kw in text for kw in keywords):
|
| 49 |
+
logger.info(f"[KnowledgeAgent] Rule MATCH: '{shape}' detected (keyword match).")
|
| 50 |
+
return shape
|
| 51 |
+
|
| 52 |
+
# Fallback: trust LLM-detected type if it's a known type
|
| 53 |
+
known = {
|
| 54 |
+
"rectangle", "square", "rhombus", "parallelogram",
|
| 55 |
+
"trapezoid", "right_trapezoid", "triangle", "right_triangle",
|
| 56 |
+
"equilateral_triangle", "isosceles_triangle", "circle",
|
| 57 |
+
}
|
| 58 |
+
if llm_type in known:
|
| 59 |
+
logger.info(f"[KnowledgeAgent] Using LLM-detected type '{llm_type}'.")
|
| 60 |
+
return llm_type
|
| 61 |
+
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
# ─── Value augmentation ──────────────────────────────────────────────────
|
| 65 |
+
def _augment_values(self, shape: str, values: dict, text: str) -> dict:
|
| 66 |
+
ab = values.get("AB")
|
| 67 |
+
ad = values.get("AD")
|
| 68 |
+
bc = values.get("BC")
|
| 69 |
+
cd = values.get("CD")
|
| 70 |
+
|
| 71 |
+
if shape == "rectangle":
|
| 72 |
+
if ab and ad:
|
| 73 |
+
values.setdefault("CD", ab)
|
| 74 |
+
values.setdefault("BC", ad)
|
| 75 |
+
values.setdefault("angle_A", 90)
|
| 76 |
+
logger.info(f"[KnowledgeAgent] Rectangle: AB=CD={ab}, AD=BC={ad}, angle_A=90°")
|
| 77 |
+
else:
|
| 78 |
+
values.setdefault("angle_A", 90)
|
| 79 |
+
|
| 80 |
+
elif shape == "square":
|
| 81 |
+
side = ab or ad or bc or cd or values.get("side")
|
| 82 |
+
if side:
|
| 83 |
+
values.update({"AB": side, "AD": side, "angle_A": 90})
|
| 84 |
+
logger.info(f"[KnowledgeAgent] Square: side={side}, angle_A=90°")
|
| 85 |
+
else:
|
| 86 |
+
values.setdefault("angle_A", 90)
|
| 87 |
+
|
| 88 |
+
elif shape == "rhombus":
|
| 89 |
+
side = ab or values.get("side")
|
| 90 |
+
if side:
|
| 91 |
+
values.update({"AB": side, "BC": side, "CD": side, "DA": side})
|
| 92 |
+
logger.info(f"[KnowledgeAgent] Rhombus: all sides={side}")
|
| 93 |
+
|
| 94 |
+
elif shape == "parallelogram":
|
| 95 |
+
if ab:
|
| 96 |
+
values.setdefault("CD", ab)
|
| 97 |
+
if ad:
|
| 98 |
+
values.setdefault("BC", ad)
|
| 99 |
+
logger.info(f"[KnowledgeAgent] Parallelogram: AB||CD, AD||BC")
|
| 100 |
+
|
| 101 |
+
elif shape == "trapezoid":
|
| 102 |
+
logger.info("[KnowledgeAgent] Trapezoid: AB||CD (bottom||top)")
|
| 103 |
+
|
| 104 |
+
elif shape == "right_trapezoid":
|
| 105 |
+
logger.info("[KnowledgeAgent] Right trapezoid: AB||CD, AD⊥AB")
|
| 106 |
+
values.setdefault("angle_A", 90)
|
| 107 |
+
|
| 108 |
+
elif shape == "equilateral_triangle":
|
| 109 |
+
side = ab or values.get("side")
|
| 110 |
+
if side:
|
| 111 |
+
values.update({"AB": side, "BC": side, "CA": side, "angle_A": 60})
|
| 112 |
+
logger.info(f"[KnowledgeAgent] Equilateral triangle: all sides={side}, angle_A=60°")
|
| 113 |
+
|
| 114 |
+
elif shape == "right_triangle":
|
| 115 |
+
# Try to infer which vertex is the right angle
|
| 116 |
+
rt_vertex = _detect_right_angle_vertex(text)
|
| 117 |
+
values.setdefault(f"angle_{rt_vertex}", 90)
|
| 118 |
+
logger.info(f"[KnowledgeAgent] Right triangle: angle_{rt_vertex}=90°")
|
| 119 |
+
|
| 120 |
+
elif shape == "isosceles_triangle":
|
| 121 |
+
logger.info("[KnowledgeAgent] Isosceles triangle: AB=AC (default, LLM may override)")
|
| 122 |
+
|
| 123 |
+
elif shape == "circle":
|
| 124 |
+
logger.info("[KnowledgeAgent] Circle detected — no side augmentation needed.")
|
| 125 |
+
|
| 126 |
+
return values
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _detect_right_angle_vertex(text: str) -> str:
|
| 130 |
+
"""Heuristic: detect which vertex is right angle from text."""
|
| 131 |
+
for vertex in ["A", "B", "C", "D"]:
|
| 132 |
+
patterns = [f"vuông tại {vertex}", f"góc {vertex} vuông", f"right angle at {vertex}"]
|
| 133 |
+
if any(p.lower() in text for p in patterns):
|
| 134 |
+
return vertex
|
| 135 |
+
return "A" # default
|
agents/ocr_agent.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
import asyncio
|
| 4 |
+
from typing import List, Dict, Any
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ImprovedOCRAgent:
|
| 10 |
+
"""
|
| 11 |
+
Advanced OCR Agent using a hybrid pipeline:
|
| 12 |
+
1. YOLO for layout analysis (text vs formula).
|
| 13 |
+
2. PaddleOCR for Vietnamese text extraction.
|
| 14 |
+
3. Pix2Tex for LaTeX formula extraction.
|
| 15 |
+
4. MegaLLM for semantic correction and formatting.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def __init__(self):
|
| 19 |
+
logger.info("[ImprovedOCRAgent] Initializing engines and client...")
|
| 20 |
+
|
| 21 |
+
from app.llm_client import get_llm_client
|
| 22 |
+
self.llm = get_llm_client()
|
| 23 |
+
logger.info("[ImprovedOCRAgent] Multi-Layer LLM Client initialized.")
|
| 24 |
+
|
| 25 |
+
try:
|
| 26 |
+
from agents.torch_ultralytics_compat import allow_ultralytics_weights
|
| 27 |
+
from ultralytics import YOLO
|
| 28 |
+
|
| 29 |
+
allow_ultralytics_weights()
|
| 30 |
+
logger.info("[ImprovedOCRAgent] Loading YOLO...")
|
| 31 |
+
self.layout_model = YOLO("yolov8n.pt")
|
| 32 |
+
logger.info("[ImprovedOCRAgent] YOLO initialized.")
|
| 33 |
+
except Exception as e:
|
| 34 |
+
logger.error("[ImprovedOCRAgent] YOLO init failed: %s", e)
|
| 35 |
+
self.layout_model = None
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
from paddleocr import PaddleOCR
|
| 39 |
+
|
| 40 |
+
logger.info("[ImprovedOCRAgent] Loading PaddleOCR...")
|
| 41 |
+
self.text_model = PaddleOCR(use_angle_cls=True, lang="vi")
|
| 42 |
+
logger.info("[ImprovedOCRAgent] PaddleOCR (vi) initialized.")
|
| 43 |
+
except Exception as e:
|
| 44 |
+
logger.error("[ImprovedOCRAgent] PaddleOCR init failed: %s", e)
|
| 45 |
+
self.text_model = None
|
| 46 |
+
|
| 47 |
+
try:
|
| 48 |
+
from pix2tex.cli import LatexOCR
|
| 49 |
+
|
| 50 |
+
logger.info("[ImprovedOCRAgent] Loading Pix2Tex...")
|
| 51 |
+
self.math_model = LatexOCR()
|
| 52 |
+
logger.info("[ImprovedOCRAgent] Pix2Tex initialized.")
|
| 53 |
+
except Exception as e:
|
| 54 |
+
logger.error("[ImprovedOCRAgent] Pix2Tex init failed: %s", e)
|
| 55 |
+
self.math_model = None
|
| 56 |
+
|
| 57 |
+
async def process_image(self, image_path: str) -> str:
|
| 58 |
+
logger.info("==[ImprovedOCRAgent] Processing: %s==", image_path)
|
| 59 |
+
|
| 60 |
+
if not os.path.exists(image_path):
|
| 61 |
+
return f"Error: File {image_path} not found."
|
| 62 |
+
|
| 63 |
+
raw_fragments: List[Dict[str, Any]] = []
|
| 64 |
+
|
| 65 |
+
if self.text_model:
|
| 66 |
+
logger.info("[ImprovedOCRAgent] Running PaddleOCR on %s...", image_path)
|
| 67 |
+
result = self.text_model.ocr(image_path)
|
| 68 |
+
logger.info("[ImprovedOCRAgent] PaddleOCR raw result: %s", result)
|
| 69 |
+
|
| 70 |
+
if not result:
|
| 71 |
+
logger.warning("[ImprovedOCRAgent] PaddleOCR returned no results.")
|
| 72 |
+
return ""
|
| 73 |
+
|
| 74 |
+
if isinstance(result[0], dict):
|
| 75 |
+
res_dict = result[0]
|
| 76 |
+
rec_texts = res_dict.get("rec_texts", [])
|
| 77 |
+
rec_scores = res_dict.get("rec_scores", [])
|
| 78 |
+
rec_polys = res_dict.get("rec_polys", [])
|
| 79 |
+
|
| 80 |
+
for i in range(len(rec_texts)):
|
| 81 |
+
text = rec_texts[i]
|
| 82 |
+
bbox = rec_polys[i]
|
| 83 |
+
_ = rec_scores[i]
|
| 84 |
+
|
| 85 |
+
y_top = int(min(p[1] for p in bbox)) if hasattr(bbox, "__iter__") else 0
|
| 86 |
+
|
| 87 |
+
is_math_hint = any(
|
| 88 |
+
c in text for c in ["\\", "^", "_", "{", "}", "=", "+", "-", "*", "/"]
|
| 89 |
+
)
|
| 90 |
+
if is_math_hint and self.math_model:
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
raw_fragments.append({"y": y_top, "content": text, "type": "text"})
|
| 94 |
+
elif isinstance(result[0], list):
|
| 95 |
+
for line in result[0]:
|
| 96 |
+
bbox = line[0]
|
| 97 |
+
text = line[1][0]
|
| 98 |
+
_ = line[1][1]
|
| 99 |
+
|
| 100 |
+
y_top = bbox[0][1]
|
| 101 |
+
raw_fragments.append({"y": y_top, "content": text, "type": "text"})
|
| 102 |
+
|
| 103 |
+
raw_fragments.sort(key=lambda x: x["y"])
|
| 104 |
+
combined_text = "\n".join([f["content"] for f in raw_fragments])
|
| 105 |
+
|
| 106 |
+
logger.info(
|
| 107 |
+
"[ImprovedOCRAgent] Raw OCR output assembled:\n---\n%s\n---", combined_text
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
if not combined_text.strip():
|
| 111 |
+
logger.warning("[ImprovedOCRAgent] No text detected to refine.")
|
| 112 |
+
return ""
|
| 113 |
+
|
| 114 |
+
try:
|
| 115 |
+
logger.info("[ImprovedOCRAgent] Sending to MegaLLM for refinement...")
|
| 116 |
+
refined_text = await asyncio.wait_for(
|
| 117 |
+
self.refine_with_llm(combined_text), timeout=30.0
|
| 118 |
+
)
|
| 119 |
+
return refined_text
|
| 120 |
+
except asyncio.TimeoutError:
|
| 121 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
|
| 122 |
+
return combined_text
|
| 123 |
+
except Exception as e:
|
| 124 |
+
logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
|
| 125 |
+
return combined_text
|
| 126 |
+
|
| 127 |
+
async def refine_with_llm(self, text: str) -> str:
|
| 128 |
+
if not text.strip():
|
| 129 |
+
return ""
|
| 130 |
+
|
| 131 |
+
prompt = f"""Bạn là một chuyên gia số hóa tài liệu toán học.
|
| 132 |
+
Dưới đây là kết quả OCR thô từ một trang sách toán Tiếng Việt.
|
| 133 |
+
Kết quả này có thể chứa lỗi chính tả, lỗi định dạng mã LaTeX, hoặc bị ngắt quãng không logic.
|
| 134 |
+
|
| 135 |
+
Nhiệm vụ của bạn:
|
| 136 |
+
1. Sửa lỗi chính tả tiếng Việt.
|
| 137 |
+
2. Đảm bảo các công thức toán học được viết đúng định dạng LaTeX và nằm trong cặp dấu $...$.
|
| 138 |
+
3. Giữ nguyên cấu trúc logic của bài toán.
|
| 139 |
+
4. Trả về nội dung đã được làm sạch dưới dạng Markdown.
|
| 140 |
+
|
| 141 |
+
Nội dung OCR thô:
|
| 142 |
+
---
|
| 143 |
+
{text}
|
| 144 |
+
---
|
| 145 |
+
|
| 146 |
+
Kết quả làm sạch:"""
|
| 147 |
+
|
| 148 |
+
try:
|
| 149 |
+
refined = await self.llm.chat_completions_create(
|
| 150 |
+
messages=[{"role": "user", "content": prompt}],
|
| 151 |
+
temperature=0.1,
|
| 152 |
+
)
|
| 153 |
+
logger.info("[ImprovedOCRAgent] LLM refinement complete.")
|
| 154 |
+
return refined
|
| 155 |
+
except Exception as e:
|
| 156 |
+
logger.error("[ImprovedOCRAgent] LLM refinement failed: %s", e)
|
| 157 |
+
return text
|
| 158 |
+
|
| 159 |
+
async def process_url(self, url: str) -> str:
|
| 160 |
+
import httpx
|
| 161 |
+
|
| 162 |
+
from app.url_utils import sanitize_url
|
| 163 |
+
|
| 164 |
+
url = sanitize_url(url)
|
| 165 |
+
if not url:
|
| 166 |
+
return "Error: Empty image URL after cleanup."
|
| 167 |
+
|
| 168 |
+
async with httpx.AsyncClient() as client:
|
| 169 |
+
resp = await client.get(url)
|
| 170 |
+
if resp.status_code == 200:
|
| 171 |
+
temp_path = "temp_url_image.png"
|
| 172 |
+
with open(temp_path, "wb") as f:
|
| 173 |
+
f.write(resp.content)
|
| 174 |
+
try:
|
| 175 |
+
return await self.process_image(temp_path)
|
| 176 |
+
finally:
|
| 177 |
+
if os.path.exists(temp_path):
|
| 178 |
+
os.remove(temp_path)
|
| 179 |
+
return f"Error: Failed to fetch image from URL {url}"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class OCRAgent(ImprovedOCRAgent):
|
| 183 |
+
"""Alias for compatibility with existing code."""
|
| 184 |
+
|
| 185 |
+
pass
|
agents/orchestrator.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
from agents.geometry_agent import GeometryAgent
|
| 6 |
+
from agents.knowledge_agent import KnowledgeAgent
|
| 7 |
+
from agents.ocr_agent import OCRAgent
|
| 8 |
+
from agents.parser_agent import ParserAgent
|
| 9 |
+
from agents.renderer_agent import RendererAgent
|
| 10 |
+
from agents.solver_agent import SolverAgent
|
| 11 |
+
from app.logutil import log_step
|
| 12 |
+
from solver.dsl_parser import DSLParser
|
| 13 |
+
from solver.engine import GeometryEngine
|
| 14 |
+
from worker.celery_app import BROKER_URL
|
| 15 |
+
from worker.tasks import render_geometry_video
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
_CLIP = 2000
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _clip(val: Any, n: int = _CLIP) -> str | None:
|
| 23 |
+
if val is None:
|
| 24 |
+
return None
|
| 25 |
+
if isinstance(val, str):
|
| 26 |
+
s = val
|
| 27 |
+
else:
|
| 28 |
+
s = json.dumps(val, ensure_ascii=False, default=str)
|
| 29 |
+
return s if len(s) <= n else s[:n] + "…"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _step_io(step: str, input_val: Any = None, output_val: Any = None) -> None:
|
| 33 |
+
"""Debug: chỉ input/output (đã cắt), tránh dump dài dòng không cần thiết."""
|
| 34 |
+
log_step(step, input=_clip(input_val), output=_clip(output_val))
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Orchestrator:
|
| 38 |
+
def __init__(self):
|
| 39 |
+
self.parser_agent = ParserAgent()
|
| 40 |
+
self.geometry_agent = GeometryAgent()
|
| 41 |
+
self.ocr_agent = OCRAgent()
|
| 42 |
+
self.knowledge_agent = KnowledgeAgent()
|
| 43 |
+
self.renderer_agent = RendererAgent()
|
| 44 |
+
self.solver_agent = SolverAgent()
|
| 45 |
+
self.solver_engine = GeometryEngine()
|
| 46 |
+
self.dsl_parser = DSLParser()
|
| 47 |
+
|
| 48 |
+
def _generate_step_description(self, semantic_json: Dict[str, Any], engine_result: Dict[str, Any]) -> str:
|
| 49 |
+
"""Tạo mô tả từng bước vẽ dựa trên kết quả của engine."""
|
| 50 |
+
analysis = semantic_json.get("analysis", "")
|
| 51 |
+
if not analysis:
|
| 52 |
+
analysis = f"Giải bài toán về {semantic_json.get('type', 'hình học')}."
|
| 53 |
+
|
| 54 |
+
steps = ["\n\n**Các bước dựng hình:**"]
|
| 55 |
+
drawing_phases = engine_result.get("drawing_phases", [])
|
| 56 |
+
|
| 57 |
+
for phase in drawing_phases:
|
| 58 |
+
label = phase.get("label", f"Giai đoạn {phase['phase']}")
|
| 59 |
+
points = ", ".join(phase.get("points", []))
|
| 60 |
+
segments = ", ".join([f"{s[0]}{s[1]}" for s in phase.get("segments", [])])
|
| 61 |
+
|
| 62 |
+
step_text = f"- **{label}**:"
|
| 63 |
+
if points:
|
| 64 |
+
step_text += f" Xác định các điểm {points}."
|
| 65 |
+
if segments:
|
| 66 |
+
step_text += f" Vẽ các đoạn thẳng {segments}."
|
| 67 |
+
steps.append(step_text)
|
| 68 |
+
|
| 69 |
+
circles = engine_result.get("circles", [])
|
| 70 |
+
for c in circles:
|
| 71 |
+
steps.append(f"- **Đường tròn**: Vẽ đường tròn tâm {c['center']} bán kính {c['radius']}.")
|
| 72 |
+
|
| 73 |
+
return analysis + "\n".join(steps)
|
| 74 |
+
|
| 75 |
+
async def run(
|
| 76 |
+
self,
|
| 77 |
+
text: str,
|
| 78 |
+
image_url: str = None,
|
| 79 |
+
job_id: str = None,
|
| 80 |
+
session_id: str = None,
|
| 81 |
+
status_callback=None,
|
| 82 |
+
request_video: bool = False,
|
| 83 |
+
history: list = None,
|
| 84 |
+
) -> Dict[str, Any]:
|
| 85 |
+
"""
|
| 86 |
+
Run the full pipeline. Optional history allows context-aware solving.
|
| 87 |
+
"""
|
| 88 |
+
_step_io(
|
| 89 |
+
"orchestrate_start",
|
| 90 |
+
input_val={
|
| 91 |
+
"job_id": job_id,
|
| 92 |
+
"text_len": len(text or ""),
|
| 93 |
+
"image_url": image_url,
|
| 94 |
+
"request_video": request_video,
|
| 95 |
+
"history_len": len(history or []),
|
| 96 |
+
},
|
| 97 |
+
output_val=None,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
if status_callback:
|
| 101 |
+
await status_callback("processing")
|
| 102 |
+
|
| 103 |
+
# 1. Extract context from history (if any)
|
| 104 |
+
previous_context = None
|
| 105 |
+
if history:
|
| 106 |
+
# Look for the last assistant message with geometry data
|
| 107 |
+
for msg in reversed(history):
|
| 108 |
+
if msg.get("role") == "assistant" and msg.get("metadata", {}).get("geometry_dsl"):
|
| 109 |
+
previous_context = {
|
| 110 |
+
"geometry_dsl": msg["metadata"]["geometry_dsl"],
|
| 111 |
+
"coordinates": msg["metadata"].get("coordinates", {}),
|
| 112 |
+
"analysis": msg.get("content", ""),
|
| 113 |
+
}
|
| 114 |
+
break
|
| 115 |
+
|
| 116 |
+
if previous_context:
|
| 117 |
+
_step_io("context_found", input_val=None, output_val={"dsl_len": len(previous_context["geometry_dsl"])})
|
| 118 |
+
|
| 119 |
+
# 2. Gather input text (OCR or direct)
|
| 120 |
+
input_text = text
|
| 121 |
+
if image_url:
|
| 122 |
+
input_text = await self.ocr_agent.process_url(image_url)
|
| 123 |
+
_step_io("step1_ocr", input_val=image_url, output_val=input_text)
|
| 124 |
+
else:
|
| 125 |
+
_step_io("step1_ocr", input_val="(no image)", output_val=text)
|
| 126 |
+
|
| 127 |
+
feedback = None
|
| 128 |
+
MAX_RETRIES = 2
|
| 129 |
+
|
| 130 |
+
for attempt in range(MAX_RETRIES + 1):
|
| 131 |
+
_step_io(
|
| 132 |
+
"attempt",
|
| 133 |
+
input_val=f"{attempt + 1}/{MAX_RETRIES + 1}",
|
| 134 |
+
output_val=None,
|
| 135 |
+
)
|
| 136 |
+
if status_callback:
|
| 137 |
+
await status_callback("solving")
|
| 138 |
+
|
| 139 |
+
# Parser with context
|
| 140 |
+
_step_io("step2_parse", input_val=f"{input_text[:50]}...", output_val=None)
|
| 141 |
+
semantic_json = await self.parser_agent.process(input_text, feedback=feedback, context=previous_context)
|
| 142 |
+
semantic_json["input_text"] = input_text
|
| 143 |
+
_step_io("step2_parse", input_val=None, output_val=semantic_json)
|
| 144 |
+
|
| 145 |
+
# Knowledge augmentation
|
| 146 |
+
_step_io("step3_knowledge", input_val=semantic_json, output_val=None)
|
| 147 |
+
semantic_json = self.knowledge_agent.augment_semantic_data(semantic_json)
|
| 148 |
+
_step_io("step3_knowledge", input_val=None, output_val=semantic_json)
|
| 149 |
+
|
| 150 |
+
# Geometry DSL with context (passing previous DSL to guide generation)
|
| 151 |
+
_step_io("step4_geometry_dsl", input_val=semantic_json, output_val=None)
|
| 152 |
+
dsl_code = await self.geometry_agent.generate_dsl(
|
| 153 |
+
semantic_json,
|
| 154 |
+
previous_dsl=previous_context["geometry_dsl"] if previous_context else None
|
| 155 |
+
)
|
| 156 |
+
_step_io("step4_geometry_dsl", input_val=None, output_val=dsl_code)
|
| 157 |
+
|
| 158 |
+
_step_io("step5_dsl_parse", input_val=dsl_code, output_val=None)
|
| 159 |
+
points, constraints, is_3d = self.dsl_parser.parse(dsl_code)
|
| 160 |
+
_step_io(
|
| 161 |
+
"step5_dsl_parse",
|
| 162 |
+
input_val=None,
|
| 163 |
+
output_val={
|
| 164 |
+
"points": len(points),
|
| 165 |
+
"constraints": len(constraints),
|
| 166 |
+
"is_3d": is_3d,
|
| 167 |
+
},
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
_step_io("step6_solve", input_val=f"{len(points)} pts / {len(constraints)} cons (is_3d={is_3d})", output_val=None)
|
| 171 |
+
import anyio
|
| 172 |
+
engine_result = await anyio.to_thread.run_sync(self.solver_engine.solve, points, constraints, is_3d)
|
| 173 |
+
|
| 174 |
+
if engine_result:
|
| 175 |
+
coordinates = engine_result.get("coordinates")
|
| 176 |
+
_step_io("step6_solve", input_val=None, output_val=coordinates)
|
| 177 |
+
break
|
| 178 |
+
|
| 179 |
+
feedback = "Geometry solver failed to find a valid solution for the given constraints. Parallelism or lengths might be inconsistent."
|
| 180 |
+
_step_io(
|
| 181 |
+
"step6_solve",
|
| 182 |
+
input_val=f"attempt {attempt + 1}",
|
| 183 |
+
output_val=feedback,
|
| 184 |
+
)
|
| 185 |
+
if attempt == MAX_RETRIES:
|
| 186 |
+
_step_io(
|
| 187 |
+
"orchestrate_abort",
|
| 188 |
+
input_val=None,
|
| 189 |
+
output_val="solver_exhausted_retries",
|
| 190 |
+
)
|
| 191 |
+
return {
|
| 192 |
+
"error": "Solver failed after multiple attempts.",
|
| 193 |
+
"last_dsl": dsl_code,
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
status = "success"
|
| 197 |
+
if request_video:
|
| 198 |
+
try:
|
| 199 |
+
result_payload = {
|
| 200 |
+
"geometry_dsl": dsl_code,
|
| 201 |
+
"coordinates": coordinates,
|
| 202 |
+
"polygon_order": engine_result.get("polygon_order", []),
|
| 203 |
+
"drawing_phases": engine_result.get("drawing_phases", []),
|
| 204 |
+
"circles": engine_result.get("circles", []),
|
| 205 |
+
"lines": engine_result.get("lines", []),
|
| 206 |
+
"rays": engine_result.get("rays", []),
|
| 207 |
+
"semantic": semantic_json,
|
| 208 |
+
"semantic_analysis": semantic_json.get("analysis") or semantic_json.get("input_text", ""),
|
| 209 |
+
"session_id": session_id,
|
| 210 |
+
}
|
| 211 |
+
task = render_geometry_video.delay(job_id, result_payload)
|
| 212 |
+
status = "rendering_queued"
|
| 213 |
+
_step_io(
|
| 214 |
+
"step7_video",
|
| 215 |
+
input_val={"job_id": job_id, "broker": BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL},
|
| 216 |
+
output_val={"task_id": str(task.id), "status": status},
|
| 217 |
+
)
|
| 218 |
+
except Exception as e:
|
| 219 |
+
logger.exception("Celery queue failed for job %s", job_id)
|
| 220 |
+
_step_io("step7_video", input_val=job_id, output_val=str(e))
|
| 221 |
+
status = "success"
|
| 222 |
+
else:
|
| 223 |
+
_step_io("step7_video", input_val=request_video, output_val="skipped")
|
| 224 |
+
|
| 225 |
+
_step_io("orchestrate_done", input_val=job_id, output_val=status)
|
| 226 |
+
|
| 227 |
+
# 8. Solution calculation (New in v5.1)
|
| 228 |
+
solution = None
|
| 229 |
+
if engine_result:
|
| 230 |
+
_step_io("step8_solve_math", input_val=semantic_json.get("target_question"), output_val=None)
|
| 231 |
+
solution = await self.solver_agent.solve(semantic_json, engine_result)
|
| 232 |
+
_step_io("step8_solve_math", input_val=None, output_val=solution.get("answer"))
|
| 233 |
+
|
| 234 |
+
final_analysis = self._generate_step_description(semantic_json, engine_result)
|
| 235 |
+
|
| 236 |
+
return {
|
| 237 |
+
"status": status,
|
| 238 |
+
"job_id": job_id,
|
| 239 |
+
"geometry_dsl": dsl_code,
|
| 240 |
+
"coordinates": coordinates,
|
| 241 |
+
"polygon_order": engine_result.get("polygon_order", []),
|
| 242 |
+
"circles": engine_result.get("circles", []),
|
| 243 |
+
"lines": engine_result.get("lines", []),
|
| 244 |
+
"rays": engine_result.get("rays", []),
|
| 245 |
+
"drawing_phases": engine_result.get("drawing_phases", []),
|
| 246 |
+
"semantic": semantic_json,
|
| 247 |
+
"semantic_analysis": final_analysis,
|
| 248 |
+
"solution": solution,
|
| 249 |
+
}
|
agents/parser_agent.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from openai import AsyncOpenAI
|
| 5 |
+
from typing import Dict, Any
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
|
| 8 |
+
load_dotenv()
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
from app.llm_client import get_llm_client
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ParserAgent:
|
| 18 |
+
def __init__(self):
|
| 19 |
+
self.llm = get_llm_client()
|
| 20 |
+
|
| 21 |
+
async def process(self, text: str, feedback: str = None, context: Dict[str, Any] = None) -> Dict[str, Any]:
|
| 22 |
+
logger.info(f"==[ParserAgent] Processing input (len={len(text)})==")
|
| 23 |
+
if feedback:
|
| 24 |
+
logger.warning(f"[ParserAgent] Feedback from previous attempt: {feedback}")
|
| 25 |
+
if context:
|
| 26 |
+
logger.info(f"[ParserAgent] Using previous context (dsl_len={len(context.get('geometry_dsl', ''))})")
|
| 27 |
+
|
| 28 |
+
system_prompt = """
|
| 29 |
+
You are a Geometry Parser Agent. Extract geometric entities and constraints from Vietnamese/LaTeX math problem text.
|
| 30 |
+
|
| 31 |
+
=== CONTEXT AWARENESS ===
|
| 32 |
+
If previous context is provided, it means this is a follow-up request.
|
| 33 |
+
- Combine old entities with new ones.
|
| 34 |
+
- Update 'analysis' to reflect the entire problem state.
|
| 35 |
+
|
| 36 |
+
Output ONLY a JSON object with this EXACT structure (no extra keys, no markdown):
|
| 37 |
+
{
|
| 38 |
+
"entities": ["Point A", "Point B", ...],
|
| 39 |
+
"type": "pyramid|prism|sphere|rectangle|triangle|circle|parallelogram|trapezoid|square|rhombus|general",
|
| 40 |
+
"values": {"AB": 5, "SO": 15, "radius": 3},
|
| 41 |
+
"target_question": "Câu hỏi cụ thể cần giải (ví dụ: 'Tính diện tích tam giác ABC'). NẾU KHÔNG CÓ CÂU HỎI THÌ ĐỂ null.",
|
| 42 |
+
"analysis": "Tóm tắt ngắn gọn toàn bộ bài toán sau khi đã cập nhật các yêu cầu mới bằng tiếng Việt."
|
| 43 |
+
}
|
| 44 |
+
Rules:
|
| 45 |
+
- "analysis" MUST be a meaningful and UP-TO-DATE summary of the problem in Vietnamese.
|
| 46 |
+
- "target_question" must be concise.
|
| 47 |
+
- Include midpoints, auxiliary points in "entities" if mentioned.
|
| 48 |
+
- If feedback is provided, correct your previous output accordingly.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
user_content = f"Text: {text}"
|
| 52 |
+
if context:
|
| 53 |
+
user_content = f"PREVIOUS ANALYSIS: {context.get('analysis')}\nNEW REQUEST: {text}"
|
| 54 |
+
|
| 55 |
+
if feedback:
|
| 56 |
+
user_content += f"\nFeedback from previous attempt: {feedback}. Please correct the constraints."
|
| 57 |
+
|
| 58 |
+
logger.debug("[ParserAgent] Calling LLM (Multi-Layer)...")
|
| 59 |
+
raw = await self.llm.chat_completions_create(
|
| 60 |
+
messages=[
|
| 61 |
+
{"role": "system", "content": system_prompt},
|
| 62 |
+
{"role": "user", "content": user_content}
|
| 63 |
+
],
|
| 64 |
+
response_format={"type": "json_object"}
|
| 65 |
+
)
|
| 66 |
+
|
| 67 |
+
# Pre-process raw string: extract the JSON block if present
|
| 68 |
+
import re
|
| 69 |
+
clean_raw = raw.strip()
|
| 70 |
+
# Handle potential markdown code blocks
|
| 71 |
+
if clean_raw.startswith("```"):
|
| 72 |
+
import re
|
| 73 |
+
match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
|
| 74 |
+
if match:
|
| 75 |
+
clean_raw = match.group(1).strip()
|
| 76 |
+
|
| 77 |
+
try:
|
| 78 |
+
result = json.loads(clean_raw)
|
| 79 |
+
except json.JSONDecodeError as e:
|
| 80 |
+
logger.error(f"[ParserAgent] JSON Parse Error: {e}. Attempting regex fallback...")
|
| 81 |
+
import re
|
| 82 |
+
json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
|
| 83 |
+
if json_match:
|
| 84 |
+
try:
|
| 85 |
+
# Handle single quotes if present (common LLM failure)
|
| 86 |
+
json_str = json_match.group(1)
|
| 87 |
+
if "'" in json_str and '"' not in json_str:
|
| 88 |
+
json_str = json_str.replace("'", '"')
|
| 89 |
+
result = json.loads(json_str)
|
| 90 |
+
except:
|
| 91 |
+
result = None
|
| 92 |
+
else:
|
| 93 |
+
result = None
|
| 94 |
+
|
| 95 |
+
if not result:
|
| 96 |
+
# Fallback for critical failure
|
| 97 |
+
result = {
|
| 98 |
+
"entities": [],
|
| 99 |
+
"type": "general",
|
| 100 |
+
"values": {},
|
| 101 |
+
"target_question": None,
|
| 102 |
+
"analysis": text
|
| 103 |
+
}
|
| 104 |
+
logger.info(f"[ParserAgent] LLM response received.")
|
| 105 |
+
logger.debug(f"[ParserAgent] Parsed JSON: {json.dumps(result, ensure_ascii=False, indent=2)}")
|
| 106 |
+
return result
|
agents/renderer_agent.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
import glob
|
| 4 |
+
import string
|
| 5 |
+
from typing import Dict, Any, List
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RendererAgent:
|
| 9 |
+
"""
|
| 10 |
+
Renderer Agent — generates Manim scripts from geometry data.
|
| 11 |
+
|
| 12 |
+
Drawing happens in phases:
|
| 13 |
+
Phase 1: Main polygon (base shape with correct vertex order)
|
| 14 |
+
Phase 2: Auxiliary points and segments (midpoints, derived segments)
|
| 15 |
+
Phase 3: Labels for all points
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
def generate_manim_script(self, data: Dict[str, Any]) -> str:
|
| 19 |
+
coords: Dict[str, List[float]] = data.get("coordinates", {})
|
| 20 |
+
polygon_order: List[str] = data.get("polygon_order", [])
|
| 21 |
+
circles_meta: List[Dict] = data.get("circles", [])
|
| 22 |
+
lines_meta: List[List[str]] = data.get("lines", [])
|
| 23 |
+
rays_meta: List[List[str]] = data.get("rays", [])
|
| 24 |
+
drawing_phases: List[Dict] = data.get("drawing_phases", [])
|
| 25 |
+
semantic: Dict[str, Any] = data.get("semantic", {})
|
| 26 |
+
shape_type = semantic.get("type", "").lower()
|
| 27 |
+
|
| 28 |
+
# ── Detect 3D Context ────────────────────────────────────────────────
|
| 29 |
+
is_3d = False
|
| 30 |
+
for pos in coords.values():
|
| 31 |
+
if len(pos) >= 3 and abs(pos[2]) > 0.001:
|
| 32 |
+
is_3d = True
|
| 33 |
+
break
|
| 34 |
+
if shape_type in ["pyramid", "prism", "sphere"]:
|
| 35 |
+
is_3d = True
|
| 36 |
+
|
| 37 |
+
# ── Fallback: infer polygon_order from coords keys (alphabetical uppercase) ──
|
| 38 |
+
if not polygon_order:
|
| 39 |
+
base = sorted(
|
| 40 |
+
[pid for pid in coords if pid in string.ascii_uppercase],
|
| 41 |
+
key=lambda p: string.ascii_uppercase.index(p)
|
| 42 |
+
)
|
| 43 |
+
polygon_order = base
|
| 44 |
+
|
| 45 |
+
# Separate base points from derived (multi-char or lowercase)
|
| 46 |
+
base_ids = [pid for pid in polygon_order if pid in coords]
|
| 47 |
+
derived_ids = [pid for pid in coords if pid not in polygon_order]
|
| 48 |
+
|
| 49 |
+
scene_base = "ThreeDScene" if is_3d else "MovingCameraScene"
|
| 50 |
+
lines = [
|
| 51 |
+
"from manim import *",
|
| 52 |
+
"",
|
| 53 |
+
f"class GeometryScene({scene_base}):",
|
| 54 |
+
" def construct(self):",
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
if is_3d:
|
| 58 |
+
lines.append(" # 3D Setup")
|
| 59 |
+
lines.append(" self.set_camera_orientation(phi=75*DEGREES, theta=-45*DEGREES)")
|
| 60 |
+
lines.append(" axes = ThreeDAxes(axis_config={'stroke_width': 1})")
|
| 61 |
+
lines.append(" axes.set_opacity(0.3)")
|
| 62 |
+
lines.append(" self.add(axes)")
|
| 63 |
+
lines.append(" self.begin_ambient_camera_rotation(rate=0.1)")
|
| 64 |
+
lines.append("")
|
| 65 |
+
|
| 66 |
+
# ── Declare all dots and labels ───────────────────────────────────────
|
| 67 |
+
for pid, pos in coords.items():
|
| 68 |
+
x, y, z = 0, 0, 0
|
| 69 |
+
if len(pos) >= 1: x = round(pos[0], 4)
|
| 70 |
+
if len(pos) >= 2: y = round(pos[1], 4)
|
| 71 |
+
if len(pos) >= 3: z = round(pos[2], 4)
|
| 72 |
+
|
| 73 |
+
dot_class = "Dot3D" if is_3d else "Dot"
|
| 74 |
+
lines.append(f" p_{pid} = {dot_class}(point=[{x}, {y}, {z}], color=WHITE, radius=0.08)")
|
| 75 |
+
|
| 76 |
+
if is_3d:
|
| 77 |
+
lines.append(
|
| 78 |
+
f" l_{pid} = Text('{pid}', font_size=20, color=WHITE)"
|
| 79 |
+
f".move_to(p_{pid}.get_center() + [0.2, 0.2, 0.2])"
|
| 80 |
+
)
|
| 81 |
+
# Ensure labels follow camera in 3D (fixed orientation)
|
| 82 |
+
lines.append(f" self.add_fixed_orientation_mobjects(l_{pid})")
|
| 83 |
+
else:
|
| 84 |
+
lines.append(
|
| 85 |
+
f" l_{pid} = Text('{pid}', font_size=22, color=WHITE)"
|
| 86 |
+
f".next_to(p_{pid}, UR, buff=0.15)"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# ── 3D Shape Special: Pyramid/Prism Faces ────────────────────────────
|
| 90 |
+
if is_3d and shape_type == "pyramid" and len(base_ids) >= 3:
|
| 91 |
+
# Find apex (usually 'S')
|
| 92 |
+
apex_id = "S" if "S" in coords else derived_ids[0] if derived_ids else None
|
| 93 |
+
if apex_id:
|
| 94 |
+
# Draw base face
|
| 95 |
+
base_pts = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
|
| 96 |
+
lines.append(f" base_face = Polygon({base_pts}, color=BLUE, fill_opacity=0.1)")
|
| 97 |
+
lines.append(" self.play(Create(base_face), run_time=1.0)")
|
| 98 |
+
|
| 99 |
+
# Draw side faces
|
| 100 |
+
for i in range(len(base_ids)):
|
| 101 |
+
p1 = base_ids[i]
|
| 102 |
+
p2 = base_ids[(i+1)%len(base_ids)]
|
| 103 |
+
face_pts = f"p_{apex_id}.get_center(), p_{p1}.get_center(), p_{p2}.get_center()"
|
| 104 |
+
lines.append(f" side_{i} = Polygon({face_pts}, color=BLUE, stroke_width=1, fill_opacity=0.05)")
|
| 105 |
+
lines.append(f" self.play(Create(side_{i}), run_time=0.5)")
|
| 106 |
+
|
| 107 |
+
# ── Circles ──────────────────────────────────────────────────────────
|
| 108 |
+
for i, c in enumerate(circles_meta):
|
| 109 |
+
center = c["center"]
|
| 110 |
+
r = c["radius"]
|
| 111 |
+
if center in coords:
|
| 112 |
+
cx, cy, cz = 0, 0, 0
|
| 113 |
+
pos = coords[center]
|
| 114 |
+
if len(pos) >= 1: cx = round(pos[0], 4)
|
| 115 |
+
if len(pos) >= 2: cy = round(pos[1], 4)
|
| 116 |
+
if len(pos) >= 3: cz = round(pos[2], 4)
|
| 117 |
+
lines.append(
|
| 118 |
+
f" circle_{i} = Circle(radius={r}, color=BLUE)"
|
| 119 |
+
f".move_to([{cx}, {cy}, {cz}])"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# ── Infinite Lines & Rays ────────────────────────────────────────────
|
| 123 |
+
# (Standard Line works for 3D coordinates in Manim)
|
| 124 |
+
for i, (p1, p2) in enumerate(lines_meta):
|
| 125 |
+
if p1 in coords and p2 in coords:
|
| 126 |
+
lines.append(
|
| 127 |
+
f" line_ext_{i} = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=GRAY_D, stroke_width=2)"
|
| 128 |
+
f".scale(20)"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
for i, (p1, p2) in enumerate(rays_meta):
|
| 132 |
+
if p1 in coords and p2 in coords:
|
| 133 |
+
lines.append(
|
| 134 |
+
f" ray_{i} = Line(p_{p1}.get_center(), p_{p1}.get_center() + 15 * (p_{p2}.get_center() - p_{p1}.get_center()),"
|
| 135 |
+
f" color=GRAY_C, stroke_width=2)"
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
# ── Camera auto-fit group (Only for 2D) ──────────────────────────────
|
| 139 |
+
if not is_3d:
|
| 140 |
+
all_dot_names = [f"p_{pid}" for pid in coords]
|
| 141 |
+
all_names_str = ", ".join(all_dot_names)
|
| 142 |
+
lines.append(f" _all = VGroup({all_names_str})")
|
| 143 |
+
lines.append(" self.camera.frame.set_width(max(_all.width * 2.0, 8))")
|
| 144 |
+
lines.append(" self.camera.frame.move_to(_all)")
|
| 145 |
+
lines.append("")
|
| 146 |
+
|
| 147 |
+
# ── Phase 1: Base polygon ─────────────────────────────────────────────
|
| 148 |
+
if len(base_ids) >= 3:
|
| 149 |
+
pts_str = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
|
| 150 |
+
lines.append(f" poly = Polygon({pts_str}, color=BLUE, fill_color=BLUE, fill_opacity=0.15)")
|
| 151 |
+
lines.append(" self.play(Create(poly), run_time=1.5)")
|
| 152 |
+
elif len(base_ids) == 2:
|
| 153 |
+
p1, p2 = base_ids
|
| 154 |
+
lines.append(f" base_line = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=BLUE)")
|
| 155 |
+
lines.append(" self.play(Create(base_line), run_time=1.0)")
|
| 156 |
+
|
| 157 |
+
# Draw base points
|
| 158 |
+
if base_ids:
|
| 159 |
+
base_dots_str = ", ".join([f"p_{pid}" for pid in base_ids])
|
| 160 |
+
lines.append(f" self.play(FadeIn(VGroup({base_dots_str})), run_time=0.5)")
|
| 161 |
+
lines.append(" self.wait(0.5)")
|
| 162 |
+
|
| 163 |
+
# ── Phase 2: Auxiliary points and segments ────────────────────────────
|
| 164 |
+
if derived_ids:
|
| 165 |
+
derived_dots_str = ", ".join([f"p_{pid}" for pid in derived_ids])
|
| 166 |
+
lines.append(f" self.play(FadeIn(VGroup({derived_dots_str})), run_time=0.8)")
|
| 167 |
+
|
| 168 |
+
# Segments from drawing_phases
|
| 169 |
+
segment_lines = []
|
| 170 |
+
for phase in drawing_phases:
|
| 171 |
+
if phase.get("phase") == 2:
|
| 172 |
+
for seg in phase.get("segments", []):
|
| 173 |
+
if len(seg) == 2 and seg[0] in coords and seg[1] in coords:
|
| 174 |
+
p1, p2 = seg[0], seg[1]
|
| 175 |
+
seg_var = f"seg_{p1}_{p2}"
|
| 176 |
+
lines.append(
|
| 177 |
+
f" {seg_var} = Line(p_{p1}.get_center(), p_{p2}.get_center(),"
|
| 178 |
+
f" color=YELLOW)"
|
| 179 |
+
)
|
| 180 |
+
segment_lines.append(seg_var)
|
| 181 |
+
|
| 182 |
+
if segment_lines:
|
| 183 |
+
segs_str = ", ".join([f"Create({sv})" for sv in segment_lines])
|
| 184 |
+
lines.append(f" self.play({segs_str}, run_time=1.2)")
|
| 185 |
+
|
| 186 |
+
if derived_ids or segment_lines:
|
| 187 |
+
lines.append(" self.wait(0.5)")
|
| 188 |
+
|
| 189 |
+
# ── Phase 3: All labels ───────────────────────────────────────────────
|
| 190 |
+
all_labels_str = ", ".join([f"l_{pid}" for pid in coords])
|
| 191 |
+
lines.append(f" self.play(FadeIn(VGroup({all_labels_str})), run_time=0.8)")
|
| 192 |
+
|
| 193 |
+
# ── Circles phase ─────────────────────────────────────────────────────
|
| 194 |
+
for i in range(len(circles_meta)):
|
| 195 |
+
lines.append(f" self.play(Create(circle_{i}), run_time=1.5)")
|
| 196 |
+
|
| 197 |
+
# ── Lines & Rays phase ────────────────────────────────────────────────
|
| 198 |
+
if lines_meta or rays_meta:
|
| 199 |
+
lr_anims = []
|
| 200 |
+
for i in range(len(lines_meta)): lr_anims.append(f"Create(line_ext_{i})")
|
| 201 |
+
for i in range(len(rays_meta)): lr_anims.append(f"Create(ray_{i})")
|
| 202 |
+
lines.append(f" self.play({', '.join(lr_anims)}, run_time=1.5)")
|
| 203 |
+
|
| 204 |
+
lines.append(" self.wait(2)")
|
| 205 |
+
|
| 206 |
+
return "\n".join(lines)
|
| 207 |
+
|
| 208 |
+
def run_manim(self, script_content: str, job_id: str) -> str:
|
| 209 |
+
import subprocess
|
| 210 |
+
import os
|
| 211 |
+
import glob
|
| 212 |
+
|
| 213 |
+
script_file = f"{job_id}.py"
|
| 214 |
+
with open(script_file, "w") as f:
|
| 215 |
+
f.write(script_content)
|
| 216 |
+
|
| 217 |
+
try:
|
| 218 |
+
if os.getenv("MOCK_VIDEO") == "true":
|
| 219 |
+
logger.info(f"MOCK_VIDEO is true. Skipping Manim for job {job_id}")
|
| 220 |
+
# Create a dummy file if needed, or just return a path that exists
|
| 221 |
+
dummy_path = f"videos/{job_id}.mp4"
|
| 222 |
+
os.makedirs("videos", exist_ok=True)
|
| 223 |
+
with open(dummy_path, "wb") as f:
|
| 224 |
+
f.write(b"dummy video content")
|
| 225 |
+
return dummy_path
|
| 226 |
+
|
| 227 |
+
print(f"Running Manim for job {job_id}...")
|
| 228 |
+
result = subprocess.run(
|
| 229 |
+
["manim", "-ql", "--media_dir", ".", "-o", f"{job_id}.mp4", script_file, "GeometryScene"],
|
| 230 |
+
capture_output=True,
|
| 231 |
+
text=True,
|
| 232 |
+
)
|
| 233 |
+
print(f"Manim STDOUT: {result.stdout}")
|
| 234 |
+
print(f"Manim STDERR: {result.stderr}")
|
| 235 |
+
|
| 236 |
+
for pattern in [f"**/videos/**/{job_id}.mp4", f"**/{job_id}*.mp4"]:
|
| 237 |
+
found = glob.glob(pattern, recursive=True)
|
| 238 |
+
if found:
|
| 239 |
+
print(f"Manim Success: Found {found[0]}")
|
| 240 |
+
return found[0]
|
| 241 |
+
|
| 242 |
+
print(f"Manim file not found for job {job_id}")
|
| 243 |
+
return ""
|
| 244 |
+
except Exception as e:
|
| 245 |
+
print(f"Manim Execution Error: {e}")
|
| 246 |
+
return ""
|
| 247 |
+
finally:
|
| 248 |
+
if os.path.exists(script_file):
|
| 249 |
+
os.remove(script_file)
|
agents/solver_agent.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
import sympy as sp
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
from app.llm_client import get_llm_client
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
class SolverAgent:
|
| 10 |
+
def __init__(self):
|
| 11 |
+
self.llm = get_llm_client()
|
| 12 |
+
|
| 13 |
+
async def solve(self, semantic_data: Dict[str, Any], engine_result: Dict[str, Any]) -> Dict[str, Any]:
|
| 14 |
+
"""
|
| 15 |
+
Solves the geometric problem based on coordinates and the target question.
|
| 16 |
+
Returns a 'solution' dictionary with answer, steps, and symbolic_expression.
|
| 17 |
+
"""
|
| 18 |
+
target_question = semantic_data.get("target_question")
|
| 19 |
+
if not target_question:
|
| 20 |
+
# If no question, just return an empty solution structure
|
| 21 |
+
return {
|
| 22 |
+
"answer": None,
|
| 23 |
+
"steps": [],
|
| 24 |
+
"symbolic_expression": None
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
logger.info(f"==[SolverAgent] Solving for: '{target_question}'==")
|
| 28 |
+
|
| 29 |
+
input_text = semantic_data.get("input_text", "")
|
| 30 |
+
coordinates = engine_result.get("coordinates", {})
|
| 31 |
+
|
| 32 |
+
# We provide the coordinates and semantic context to the LLM to help it reason.
|
| 33 |
+
# The LLM is tasked with generating the solution structure directly.
|
| 34 |
+
|
| 35 |
+
system_prompt = """
|
| 36 |
+
You are a Geometry Solver Agent. Your goal is to provide a step-by-step solution for a specific geometric question.
|
| 37 |
+
|
| 38 |
+
=== DATA PROVIDED ===
|
| 39 |
+
1. Target Question: The specific question to answer.
|
| 40 |
+
2. Geometry Data: Entities and values extracted from the problem.
|
| 41 |
+
3. Coordinates: Calculated coordinates for all points.
|
| 42 |
+
|
| 43 |
+
=== REQUIREMENTS ===
|
| 44 |
+
- Provide the solution in the SAME LANGUAGE as the user's input.
|
| 45 |
+
- Use SymPy concepts if appropriate.
|
| 46 |
+
- Steps should be clear, concise, and logical.
|
| 47 |
+
- The final answer should be numerically or symbolically accurate based on the coordinates and geometric properties.
|
| 48 |
+
- For geometric proofs (e.g., "Is AB perpendicular to AC?"), explain the reasoning based on the data.
|
| 49 |
+
|
| 50 |
+
Output ONLY a JSON object with this structure:
|
| 51 |
+
{
|
| 52 |
+
"answer": "Chuỗi văn bản kết quả cuối cùng (kèm đơn vị nếu có)",
|
| 53 |
+
"steps": [
|
| 54 |
+
"Bước 1: ...",
|
| 55 |
+
"Bước 2: ...",
|
| 56 |
+
...
|
| 57 |
+
],
|
| 58 |
+
"symbolic_expression": "Biểu thức toán học rút gọn (LaTeX format optional)"
|
| 59 |
+
}
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
user_content = f"""
|
| 63 |
+
INPUT_TEXT: {input_text}
|
| 64 |
+
TARGET_QUESTION: {target_question}
|
| 65 |
+
SEMANTIC_DATA: {json.dumps(semantic_data, ensure_ascii=False)}
|
| 66 |
+
COORDINATES: {json.dumps(coordinates)}
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
logger.debug("[SolverAgent] Requesting solution from LLM...")
|
| 70 |
+
try:
|
| 71 |
+
raw = await self.llm.chat_completions_create(
|
| 72 |
+
messages=[
|
| 73 |
+
{"role": "system", "content": system_prompt},
|
| 74 |
+
{"role": "user", "content": user_content}
|
| 75 |
+
],
|
| 76 |
+
response_format={"type": "json_object"}
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
clean_raw = raw.strip()
|
| 80 |
+
# Handle potential markdown code blocks if the response_format wasn't strictly honored
|
| 81 |
+
if clean_raw.startswith("```"):
|
| 82 |
+
import re
|
| 83 |
+
match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
|
| 84 |
+
if match:
|
| 85 |
+
clean_raw = match.group(1).strip()
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
solution = json.loads(clean_raw)
|
| 89 |
+
except json.JSONDecodeError:
|
| 90 |
+
# Last resort: try to find anything between { and }
|
| 91 |
+
import re
|
| 92 |
+
json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
|
| 93 |
+
if json_match:
|
| 94 |
+
solution = json.loads(json_match.group(1))
|
| 95 |
+
else:
|
| 96 |
+
raise
|
| 97 |
+
|
| 98 |
+
logger.info("[SolverAgent] Solution generated successfully.")
|
| 99 |
+
return solution
|
| 100 |
+
except Exception as e:
|
| 101 |
+
logger.error(f"[SolverAgent] Error generating solution: {e}")
|
| 102 |
+
logger.debug(f"[SolverAgent] Raw LLM output was: \n{raw if 'raw' in locals() else 'N/A'}")
|
| 103 |
+
return {
|
| 104 |
+
"answer": "Không thể tính toán lời giải tại thời điểm này.",
|
| 105 |
+
"steps": ["Đã xảy ra lỗi trong quá trình xử lý lời giải."],
|
| 106 |
+
"symbolic_expression": None
|
| 107 |
+
}
|
agents/torch_ultralytics_compat.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch 2.6+ defaults weights_only=True; Ultralytics YOLO .pt checkpoints unpickle full nn graphs (trusted official weights)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import functools
|
| 6 |
+
|
| 7 |
+
_torch_load_patched = False
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def allow_ultralytics_weights() -> None:
|
| 11 |
+
"""
|
| 12 |
+
Official yolov8n.pt is a trusted checkpoint. PyTorch 2.6+ safe unpickling would require
|
| 13 |
+
allowlisting many torch.nn globals; loading with weights_only=False matches Ultralytics
|
| 14 |
+
upstream behavior for local .pt files.
|
| 15 |
+
"""
|
| 16 |
+
global _torch_load_patched
|
| 17 |
+
if _torch_load_patched:
|
| 18 |
+
return
|
| 19 |
+
try:
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
_orig = torch.load
|
| 23 |
+
|
| 24 |
+
@functools.wraps(_orig)
|
| 25 |
+
def _load(*args, **kwargs):
|
| 26 |
+
if "weights_only" not in kwargs:
|
| 27 |
+
kwargs["weights_only"] = False
|
| 28 |
+
return _orig(*args, **kwargs)
|
| 29 |
+
|
| 30 |
+
torch.load = _load
|
| 31 |
+
_torch_load_patched = True
|
| 32 |
+
except Exception:
|
| 33 |
+
pass
|
app/dependencies.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import HTTPException, Header
|
| 2 |
+
|
| 3 |
+
from app.supabase_client import get_supabase, get_supabase_for_user_jwt
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
async def get_current_user_id(authorization: str = Header(...)):
|
| 7 |
+
"""
|
| 8 |
+
Authenticate user using Supabase JWT.
|
| 9 |
+
Expected Header: Authorization: Bearer <token>
|
| 10 |
+
"""
|
| 11 |
+
import os
|
| 12 |
+
if os.getenv("ALLOW_TEST_BYPASS") == "true" and authorization.startswith("Test "):
|
| 13 |
+
return authorization.split(" ")[1]
|
| 14 |
+
|
| 15 |
+
if not authorization or not authorization.startswith("Bearer "):
|
| 16 |
+
raise HTTPException(
|
| 17 |
+
status_code=401,
|
| 18 |
+
detail="Authorization header missing or invalid. Use 'Bearer <token>'",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
token = authorization.split(" ")[1]
|
| 22 |
+
supabase = get_supabase()
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
user_response = supabase.auth.get_user(token)
|
| 26 |
+
if not user_response or not user_response.user:
|
| 27 |
+
raise HTTPException(status_code=401, detail="Invalid session or token.")
|
| 28 |
+
|
| 29 |
+
return user_response.user.id
|
| 30 |
+
except HTTPException:
|
| 31 |
+
raise
|
| 32 |
+
except Exception as e:
|
| 33 |
+
raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def get_authenticated_supabase(authorization: str = Header(...)):
|
| 37 |
+
"""
|
| 38 |
+
Supabase client that carries the user's JWT (anon key + Authorization header).
|
| 39 |
+
Use for routes that should respect Row Level Security; pair with app logic as needed.
|
| 40 |
+
"""
|
| 41 |
+
if not authorization or not authorization.startswith("Bearer "):
|
| 42 |
+
raise HTTPException(
|
| 43 |
+
status_code=401,
|
| 44 |
+
detail="Authorization header missing or invalid. Use 'Bearer <token>'",
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
token = authorization.split(" ")[1]
|
| 48 |
+
supabase = get_supabase()
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
user_response = supabase.auth.get_user(token)
|
| 52 |
+
if not user_response or not user_response.user:
|
| 53 |
+
raise HTTPException(status_code=401, detail="Invalid session or token.")
|
| 54 |
+
except HTTPException:
|
| 55 |
+
raise
|
| 56 |
+
except Exception as e:
|
| 57 |
+
raise HTTPException(status_code=401, detail=f"Authentication failed: {str(e)}")
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
return get_supabase_for_user_jwt(token)
|
| 61 |
+
except RuntimeError as e:
|
| 62 |
+
raise HTTPException(status_code=503, detail=str(e))
|
app/errors.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Map exceptions to short, user-visible messages (avoid leaking HTML bodies from 404 proxies)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def _looks_like_html(text: str) -> bool:
|
| 11 |
+
t = text.lstrip()[:500].lower()
|
| 12 |
+
return t.startswith("<!doctype") or t.startswith("<html") or "<html" in t[:200]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def format_error_for_user(exc: BaseException) -> str:
|
| 16 |
+
"""
|
| 17 |
+
Produce a safe message for chat/UI. Full detail stays in server logs via logger.exception.
|
| 18 |
+
"""
|
| 19 |
+
# httpx: wrong URL often returns 404 HTML; don't show body
|
| 20 |
+
try:
|
| 21 |
+
import httpx
|
| 22 |
+
|
| 23 |
+
if isinstance(exc, httpx.HTTPStatusError):
|
| 24 |
+
req = exc.request
|
| 25 |
+
code = exc.response.status_code
|
| 26 |
+
url_hint = ""
|
| 27 |
+
try:
|
| 28 |
+
url_hint = str(req.url.host) if req and req.url else ""
|
| 29 |
+
except Exception:
|
| 30 |
+
pass
|
| 31 |
+
logger.warning(
|
| 32 |
+
"HTTPStatusError %s for %s (response not shown to user)",
|
| 33 |
+
code,
|
| 34 |
+
url_hint or "?",
|
| 35 |
+
)
|
| 36 |
+
return (
|
| 37 |
+
"Kiểm tra URL API, khóa bí mật và biến môi trường (OpenRouter/Supabase/Redis)."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
if isinstance(exc, httpx.RequestError):
|
| 41 |
+
return "Không kết nối được tới dịch vụ ngoài (mạng hoặc URL sai)."
|
| 42 |
+
except ImportError:
|
| 43 |
+
pass
|
| 44 |
+
|
| 45 |
+
raw = str(exc).strip()
|
| 46 |
+
if not raw:
|
| 47 |
+
return "Đã xảy ra lỗi không xác định."
|
| 48 |
+
|
| 49 |
+
if _looks_like_html(raw):
|
| 50 |
+
logger.warning("Suppressed HTML error body from user-facing message")
|
| 51 |
+
return (
|
| 52 |
+
"Dịch vụ trả về trang lỗi (thường là URL API sai hoặc endpoint không tồn tại — HTTP 404). "
|
| 53 |
+
"Kiểm tra OPENROUTER_MODEL và khóa API trên server."
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
if len(raw) > 800:
|
| 57 |
+
return raw[:800] + "…"
|
| 58 |
+
|
| 59 |
+
return raw
|
app/llm_client.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import asyncio
|
| 4 |
+
import logging
|
| 5 |
+
from openai import AsyncOpenAI
|
| 6 |
+
from typing import List, Dict, Any, Optional
|
| 7 |
+
from app.url_utils import openai_compatible_api_key, sanitize_env
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
class MultiLayerLLMClient:
|
| 12 |
+
def __init__(self):
|
| 13 |
+
# 1. OpenRouter (Primary with Rotation)
|
| 14 |
+
self.openrouter_model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-001")
|
| 15 |
+
self.keys = []
|
| 16 |
+
for i in range(1, 6):
|
| 17 |
+
key = os.getenv(f"OPENROUTER_API_KEY_{i}")
|
| 18 |
+
if key:
|
| 19 |
+
self.keys.append(key)
|
| 20 |
+
|
| 21 |
+
if not self.keys:
|
| 22 |
+
# Fallback to single key if exists (legacy)
|
| 23 |
+
single_key = os.getenv("OPENROUTER_API_KEY")
|
| 24 |
+
if single_key:
|
| 25 |
+
self.keys.append(single_key)
|
| 26 |
+
|
| 27 |
+
self.clients = [
|
| 28 |
+
AsyncOpenAI(
|
| 29 |
+
api_key=openai_compatible_api_key(k),
|
| 30 |
+
base_url="https://openrouter.ai/api/v1",
|
| 31 |
+
timeout=120.0,
|
| 32 |
+
default_headers={
|
| 33 |
+
"HTTP-Referer": "https://mathsolver.ai",
|
| 34 |
+
"X-Title": "MathSolver Backend",
|
| 35 |
+
}
|
| 36 |
+
) for k in self.keys
|
| 37 |
+
]
|
| 38 |
+
self.current_index = 0
|
| 39 |
+
|
| 40 |
+
async def chat_completions_create(
|
| 41 |
+
self,
|
| 42 |
+
messages: List[Dict[str, str]],
|
| 43 |
+
response_format: Optional[Dict[str, str]] = None,
|
| 44 |
+
**kwargs
|
| 45 |
+
) -> str:
|
| 46 |
+
"""
|
| 47 |
+
Rotates through OpenRouter API keys on every attempt (success or failure).
|
| 48 |
+
Tries up to 2 retries (total 3 attempts), with 0s delay.
|
| 49 |
+
"""
|
| 50 |
+
MAX_RETRIES = 2
|
| 51 |
+
RETRY_DELAY = 0 # seconds
|
| 52 |
+
|
| 53 |
+
if not self.clients:
|
| 54 |
+
logger.error("[LLM] No OpenRouter API keys found.")
|
| 55 |
+
raise ValueError("No API keys configured.")
|
| 56 |
+
|
| 57 |
+
for attempt in range(1, MAX_RETRIES + 2): # Up to 3 attempts total
|
| 58 |
+
client = self.clients[self.current_index]
|
| 59 |
+
key_id = self.current_index + 1
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
logger.info(f"[LLM] Attempt {attempt}/{MAX_RETRIES + 1} using Key #{key_id} ({self.openrouter_model})...")
|
| 63 |
+
response = await client.chat.completions.create(
|
| 64 |
+
model=self.openrouter_model,
|
| 65 |
+
messages=messages,
|
| 66 |
+
response_format=response_format,
|
| 67 |
+
**kwargs
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
if not response or not getattr(response, "choices", None):
|
| 71 |
+
raise ValueError("Invalid response structure from OpenRouter")
|
| 72 |
+
|
| 73 |
+
content = response.choices[0].message.content
|
| 74 |
+
if content:
|
| 75 |
+
logger.info(f"[LLM] SUCCESS on attempt {attempt} (Key #{key_id}).")
|
| 76 |
+
# Luôn xoay sang key tiếp theo sau khi thành công để chuẩn bị cho request tới
|
| 77 |
+
self.current_index = (self.current_index + 1) % len(self.clients)
|
| 78 |
+
return content
|
| 79 |
+
|
| 80 |
+
raise ValueError("Empty content from OpenRouter")
|
| 81 |
+
|
| 82 |
+
except Exception as e:
|
| 83 |
+
err_msg = f"{type(e).__name__}: {str(e)}"
|
| 84 |
+
logger.warning(f"[LLM] FAILED on attempt {attempt} (Key #{key_id}): {err_msg}")
|
| 85 |
+
|
| 86 |
+
# Xoay key kể cả khi thất bại để attempt tiếp theo dùng key khác
|
| 87 |
+
old_index = self.current_index
|
| 88 |
+
self.current_index = (self.current_index + 1) % len(self.clients)
|
| 89 |
+
|
| 90 |
+
if attempt <= MAX_RETRIES:
|
| 91 |
+
logger.info(f"[LLM] Switching from Key #{old_index + 1} to #{self.current_index + 1}. Retrying in {RETRY_DELAY}s...")
|
| 92 |
+
await asyncio.sleep(RETRY_DELAY)
|
| 93 |
+
else:
|
| 94 |
+
logger.error(f"[LLM] FINAL FAILURE after {attempt} attempts.")
|
| 95 |
+
raise e
|
| 96 |
+
|
| 97 |
+
# Global instance for easy reuse (singleton-ish)
|
| 98 |
+
_llm_client = None
|
| 99 |
+
|
| 100 |
+
def get_llm_client() -> MultiLayerLLMClient:
|
| 101 |
+
global _llm_client
|
| 102 |
+
if _llm_client is None:
|
| 103 |
+
_llm_client = MultiLayerLLMClient()
|
| 104 |
+
return _llm_client
|
app/logging_setup.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Logging theo một biến LOG_LEVEL: debug | info | warning | error."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
from typing import Final
|
| 8 |
+
|
| 9 |
+
_SETUP_DONE = False
|
| 10 |
+
|
| 11 |
+
PIPELINE_LOGGER_NAME: Final = "app.pipeline"
|
| 12 |
+
CACHE_LOGGER_NAME: Final = "app.cache"
|
| 13 |
+
STEPS_LOGGER_NAME: Final = "app.steps"
|
| 14 |
+
ACCESS_LOGGER_NAME: Final = "app.access"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _normalize_level() -> str:
|
| 18 |
+
raw = os.getenv("LOG_LEVEL", "info").strip().lower()
|
| 19 |
+
if raw in ("debug", "info", "warning", "error"):
|
| 20 |
+
return raw
|
| 21 |
+
return "info"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def setup_application_logging() -> None:
|
| 25 |
+
"""Idempotent; gọi khi khởi động process (uvicorn, celery, worker_health)."""
|
| 26 |
+
global _SETUP_DONE
|
| 27 |
+
if _SETUP_DONE:
|
| 28 |
+
return
|
| 29 |
+
_SETUP_DONE = True
|
| 30 |
+
|
| 31 |
+
mode = _normalize_level()
|
| 32 |
+
|
| 33 |
+
level_map = {
|
| 34 |
+
"debug": logging.DEBUG,
|
| 35 |
+
"info": logging.INFO,
|
| 36 |
+
"warning": logging.WARNING,
|
| 37 |
+
"error": logging.ERROR,
|
| 38 |
+
}
|
| 39 |
+
root_level = level_map[mode]
|
| 40 |
+
|
| 41 |
+
fmt_named = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
|
| 42 |
+
fmt_short = "%(asctime)s | %(levelname)-8s | %(message)s"
|
| 43 |
+
|
| 44 |
+
logging.basicConfig(
|
| 45 |
+
level=root_level,
|
| 46 |
+
format=fmt_named if mode == "debug" else fmt_short,
|
| 47 |
+
datefmt="%H:%M:%S",
|
| 48 |
+
force=True,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
logging.getLogger("httpx").setLevel(logging.WARNING)
|
| 52 |
+
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
| 53 |
+
logging.getLogger("openai").setLevel(logging.WARNING)
|
| 54 |
+
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
| 55 |
+
logging.getLogger("uvicorn.error").setLevel(logging.INFO)
|
| 56 |
+
# HTTP/2 stack (httpx/httpcore) — khi LOG_LEVEL=debug root=DEBUG sẽ tràn log hpack; không cần cho debug app
|
| 57 |
+
for _name in ("hpack", "h2", "hyperframe", "urllib3"):
|
| 58 |
+
logging.getLogger(_name).setLevel(logging.WARNING)
|
| 59 |
+
|
| 60 |
+
if mode == "debug":
|
| 61 |
+
logging.getLogger("agents").setLevel(logging.DEBUG)
|
| 62 |
+
logging.getLogger("solver").setLevel(logging.DEBUG)
|
| 63 |
+
logging.getLogger("app").setLevel(logging.DEBUG)
|
| 64 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.DEBUG)
|
| 65 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.DEBUG)
|
| 66 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.INFO)
|
| 67 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
|
| 68 |
+
logging.getLogger("app.main").setLevel(logging.INFO)
|
| 69 |
+
logging.getLogger("worker").setLevel(logging.INFO)
|
| 70 |
+
elif mode == "info":
|
| 71 |
+
# Chỉ HTTP access (app.access) + startup; ẩn chi tiết agents/orchestrator/pipeline SUCCESS
|
| 72 |
+
logging.getLogger("agents").setLevel(logging.INFO)
|
| 73 |
+
logging.getLogger("solver").setLevel(logging.WARNING)
|
| 74 |
+
logging.getLogger("app").setLevel(logging.INFO)
|
| 75 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 76 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 77 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 78 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
|
| 79 |
+
logging.getLogger("app.main").setLevel(logging.INFO)
|
| 80 |
+
logging.getLogger("worker").setLevel(logging.WARNING)
|
| 81 |
+
elif mode == "warning":
|
| 82 |
+
logging.getLogger("agents").setLevel(logging.WARNING)
|
| 83 |
+
logging.getLogger("solver").setLevel(logging.WARNING)
|
| 84 |
+
logging.getLogger("app.routers").setLevel(logging.WARNING)
|
| 85 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 86 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 87 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
|
| 88 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.WARNING)
|
| 89 |
+
logging.getLogger("app.main").setLevel(logging.WARNING)
|
| 90 |
+
logging.getLogger("worker").setLevel(logging.WARNING)
|
| 91 |
+
else: # error
|
| 92 |
+
logging.getLogger("agents").setLevel(logging.ERROR)
|
| 93 |
+
logging.getLogger("solver").setLevel(logging.ERROR)
|
| 94 |
+
logging.getLogger("app.routers").setLevel(logging.ERROR)
|
| 95 |
+
logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.ERROR)
|
| 96 |
+
logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.ERROR)
|
| 97 |
+
logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.ERROR)
|
| 98 |
+
logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.ERROR)
|
| 99 |
+
logging.getLogger("app.main").setLevel(logging.ERROR)
|
| 100 |
+
logging.getLogger("worker").setLevel(logging.ERROR)
|
| 101 |
+
|
| 102 |
+
logging.getLogger(__name__).debug(
|
| 103 |
+
"LOG_LEVEL=%s root=%s", mode, logging.getLevelName(root_level)
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_log_level() -> str:
|
| 108 |
+
return _normalize_level()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def is_debug_level() -> bool:
|
| 112 |
+
return _normalize_level() == "debug"
|
app/logutil.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""log_step (debug), pipeline (debug), access log ở middleware."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from app.logging_setup import PIPELINE_LOGGER_NAME, STEPS_LOGGER_NAME
|
| 11 |
+
|
| 12 |
+
_pipeline = logging.getLogger(PIPELINE_LOGGER_NAME)
|
| 13 |
+
_steps = logging.getLogger(STEPS_LOGGER_NAME)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def is_debug_mode() -> bool:
|
| 17 |
+
"""Chi tiết từng bước chỉ khi LOG_LEVEL=debug."""
|
| 18 |
+
return os.getenv("LOG_LEVEL", "info").strip().lower() == "debug"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _truncate(val: Any, max_len: int = 2000) -> Any:
|
| 22 |
+
if val is None:
|
| 23 |
+
return None
|
| 24 |
+
if isinstance(val, (int, float, bool)):
|
| 25 |
+
return val
|
| 26 |
+
s = str(val)
|
| 27 |
+
if len(s) > max_len:
|
| 28 |
+
return s[:max_len] + f"... (+{len(s) - max_len} chars)"
|
| 29 |
+
return s
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def log_step(step: str, **fields: Any) -> None:
|
| 33 |
+
"""Chỉ khi LOG_LEVEL=debug: DB / cache / orchestrator."""
|
| 34 |
+
if not is_debug_mode():
|
| 35 |
+
return
|
| 36 |
+
safe = {k: _truncate(v) for k, v in fields.items()}
|
| 37 |
+
try:
|
| 38 |
+
payload = json.dumps(safe, ensure_ascii=False, default=str)
|
| 39 |
+
except Exception:
|
| 40 |
+
payload = str(safe)
|
| 41 |
+
_steps.debug("[step:%s] %s", step, payload)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def log_pipeline_success(operation: str, **fields: Any) -> None:
|
| 45 |
+
"""Chỉ hiện khi debug (pipeline SUCCESS không dùng ở info — đã có app.access)."""
|
| 46 |
+
if not is_debug_mode():
|
| 47 |
+
return
|
| 48 |
+
safe = {k: _truncate(v, 500) for k, v in fields.items()}
|
| 49 |
+
_pipeline.info(
|
| 50 |
+
"SUCCESS %s %s",
|
| 51 |
+
operation,
|
| 52 |
+
json.dumps(safe, ensure_ascii=False, default=str),
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def log_pipeline_failure(operation: str, error: str | None = None, **fields: Any) -> None:
|
| 57 |
+
"""Thất bại pipeline: luôn dùng WARNING để vẫn thấy khi LOG_LEVEL=warning."""
|
| 58 |
+
if is_debug_mode():
|
| 59 |
+
safe = {k: _truncate(v, 500) for k, v in fields.items()}
|
| 60 |
+
_pipeline.warning(
|
| 61 |
+
"FAIL %s err=%s %s",
|
| 62 |
+
operation,
|
| 63 |
+
_truncate(error, 300),
|
| 64 |
+
json.dumps(safe, ensure_ascii=False, default=str),
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
_pipeline.warning("FAIL %s", operation)
|
app/main.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
import uuid
|
| 7 |
+
import warnings
|
| 8 |
+
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
from fastapi import FastAPI, File, HTTPException, UploadFile
|
| 11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 12 |
+
from starlette.requests import Request
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
|
| 16 |
+
from app.runtime_env import apply_runtime_env_defaults
|
| 17 |
+
|
| 18 |
+
apply_runtime_env_defaults()
|
| 19 |
+
|
| 20 |
+
os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
|
| 21 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
|
| 22 |
+
warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
|
| 23 |
+
|
| 24 |
+
from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging
|
| 25 |
+
|
| 26 |
+
setup_application_logging()
|
| 27 |
+
|
| 28 |
+
# Routers (after logging)
|
| 29 |
+
from app.routers import auth, sessions, solve
|
| 30 |
+
from agents.ocr_agent import OCRAgent
|
| 31 |
+
from app.routers.solve import get_orchestrator
|
| 32 |
+
from app.supabase_client import get_supabase
|
| 33 |
+
from app.websocket_manager import register_websocket_routes
|
| 34 |
+
|
| 35 |
+
logger = logging.getLogger("app.main")
|
| 36 |
+
_access = logging.getLogger(ACCESS_LOGGER_NAME)
|
| 37 |
+
|
| 38 |
+
app = FastAPI(title="Visual Math Solver API v4.0")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@app.middleware("http")
|
| 42 |
+
async def access_log_middleware(request: Request, call_next):
|
| 43 |
+
"""LOG_LEVEL=info/debug: mọi request; warning: chỉ 4xx/5xx; error: chỉ 4xx/5xx ở mức error."""
|
| 44 |
+
start = time.perf_counter()
|
| 45 |
+
response = await call_next(request)
|
| 46 |
+
ms = (time.perf_counter() - start) * 1000
|
| 47 |
+
mode = get_log_level()
|
| 48 |
+
method = request.method
|
| 49 |
+
path = request.url.path
|
| 50 |
+
status = response.status_code
|
| 51 |
+
|
| 52 |
+
if mode in ("debug", "info"):
|
| 53 |
+
_access.info("%s %s -> %s (%.0fms)", method, path, status, ms)
|
| 54 |
+
elif mode == "warning":
|
| 55 |
+
if status >= 500:
|
| 56 |
+
_access.error("%s %s -> %s (%.0fms)", method, path, status, ms)
|
| 57 |
+
elif status >= 400:
|
| 58 |
+
_access.warning("%s %s -> %s (%.0fms)", method, path, status, ms)
|
| 59 |
+
elif mode == "error":
|
| 60 |
+
if status >= 400:
|
| 61 |
+
_access.error("%s %s -> %s", method, path, status)
|
| 62 |
+
|
| 63 |
+
return response
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
from worker.celery_app import BROKER_URL
|
| 67 |
+
|
| 68 |
+
_broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL
|
| 69 |
+
if get_log_level() in ("debug", "info"):
|
| 70 |
+
logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)
|
| 71 |
+
else:
|
| 72 |
+
logger.warning(
|
| 73 |
+
"App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
app.add_middleware(
|
| 77 |
+
CORSMiddleware,
|
| 78 |
+
allow_origins=["*"],
|
| 79 |
+
allow_credentials=True,
|
| 80 |
+
allow_methods=["*"],
|
| 81 |
+
allow_headers=["*"],
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
app.include_router(auth.router)
|
| 85 |
+
app.include_router(sessions.router)
|
| 86 |
+
app.include_router(solve.router)
|
| 87 |
+
|
| 88 |
+
register_websocket_routes(app)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def get_ocr_agent() -> OCRAgent:
|
| 92 |
+
"""Same OCR instance as the solve pipeline (no duplicate model load)."""
|
| 93 |
+
return get_orchestrator().ocr_agent
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
supabase_client = get_supabase()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@app.get("/")
|
| 100 |
+
def read_root():
|
| 101 |
+
return {"message": "Visual Math Solver API v4.0 is running"}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@app.post("/api/v1/ocr")
|
| 105 |
+
async def upload_ocr(file: UploadFile = File(...)):
|
| 106 |
+
"""Legacy OCR endpoint (retained for now as it's stateless)"""
|
| 107 |
+
temp_path = f"temp_{uuid.uuid4()}.png"
|
| 108 |
+
with open(temp_path, "wb") as buffer:
|
| 109 |
+
buffer.write(await file.read())
|
| 110 |
+
|
| 111 |
+
try:
|
| 112 |
+
text = await get_ocr_agent().process_image(temp_path)
|
| 113 |
+
return {"text": text}
|
| 114 |
+
finally:
|
| 115 |
+
if os.path.exists(temp_path):
|
| 116 |
+
os.remove(temp_path)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@app.get("/api/v1/solve/{job_id}")
|
| 120 |
+
async def get_job_status(job_id: str):
|
| 121 |
+
"""Retrieve job status (can be used for polling if WS fails)"""
|
| 122 |
+
response = supabase_client.table("jobs").select("*").eq("id", job_id).execute()
|
| 123 |
+
if not response.data:
|
| 124 |
+
raise HTTPException(status_code=404, detail="Job not found")
|
| 125 |
+
return response.data[0]
|
app/models/schemas.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, EmailStr, field_validator
|
| 2 |
+
from typing import Optional, List, Any, Dict
|
| 3 |
+
from datetime import datetime
|
| 4 |
+
import uuid
|
| 5 |
+
|
| 6 |
+
from app.url_utils import sanitize_url
|
| 7 |
+
|
| 8 |
+
# --- Auth Schemas ---
|
| 9 |
+
class UserProfile(BaseModel):
|
| 10 |
+
id: uuid.UUID
|
| 11 |
+
display_name: Optional[str] = None
|
| 12 |
+
avatar_url: Optional[str] = None
|
| 13 |
+
created_at: datetime
|
| 14 |
+
|
| 15 |
+
class User(BaseModel):
|
| 16 |
+
id: uuid.UUID
|
| 17 |
+
email: EmailStr
|
| 18 |
+
|
| 19 |
+
# --- Session Schemas ---
|
| 20 |
+
class SessionBase(BaseModel):
|
| 21 |
+
title: str = "Bài toán mới"
|
| 22 |
+
|
| 23 |
+
class SessionCreate(SessionBase):
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
class Session(SessionBase):
|
| 27 |
+
id: uuid.UUID
|
| 28 |
+
user_id: uuid.UUID
|
| 29 |
+
created_at: datetime
|
| 30 |
+
updated_at: datetime
|
| 31 |
+
|
| 32 |
+
class Config:
|
| 33 |
+
from_attributes = True
|
| 34 |
+
|
| 35 |
+
# --- Message Schemas ---
|
| 36 |
+
class MessageBase(BaseModel):
|
| 37 |
+
role: str
|
| 38 |
+
type: str = "text"
|
| 39 |
+
content: str
|
| 40 |
+
metadata: Dict[str, Any] = {}
|
| 41 |
+
|
| 42 |
+
class MessageCreate(MessageBase):
|
| 43 |
+
session_id: uuid.UUID
|
| 44 |
+
|
| 45 |
+
class Message(MessageBase):
|
| 46 |
+
id: uuid.UUID
|
| 47 |
+
session_id: uuid.UUID
|
| 48 |
+
created_at: datetime
|
| 49 |
+
|
| 50 |
+
class Config:
|
| 51 |
+
from_attributes = True
|
| 52 |
+
|
| 53 |
+
# --- Solve Job Schemas ---
|
| 54 |
+
class SolveRequest(BaseModel):
|
| 55 |
+
text: str
|
| 56 |
+
image_url: Optional[str] = None
|
| 57 |
+
request_video: bool = False
|
| 58 |
+
|
| 59 |
+
@field_validator("image_url", mode="before")
|
| 60 |
+
@classmethod
|
| 61 |
+
def _clean_image_url(cls, v):
|
| 62 |
+
return sanitize_url(v) if v is not None else None
|
| 63 |
+
|
| 64 |
+
class SolveResponse(BaseModel):
|
| 65 |
+
job_id: str
|
| 66 |
+
status: str
|
app/routers/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from . import auth, sessions, solve
|
app/routers/auth.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 2 |
+
from app.dependencies import get_current_user_id
|
| 3 |
+
from app.supabase_client import get_supabase
|
| 4 |
+
from app.models.schemas import UserProfile
|
| 5 |
+
import uuid
|
| 6 |
+
|
| 7 |
+
router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])
|
| 8 |
+
|
| 9 |
+
@router.get("/me")
|
| 10 |
+
async def get_me(user_id=Depends(get_current_user_id)):
|
| 11 |
+
"""获取当前登录用户的信息 (Retrieve current user profile)"""
|
| 12 |
+
supabase = get_supabase()
|
| 13 |
+
res = supabase.table("profiles").select("*").eq("id", user_id).execute()
|
| 14 |
+
if not res.data:
|
| 15 |
+
raise HTTPException(status_code=404, detail="Profile not found.")
|
| 16 |
+
return res.data[0]
|
| 17 |
+
|
| 18 |
+
@router.patch("/me")
|
| 19 |
+
async def update_me(data: dict, user_id=Depends(get_current_user_id)):
|
| 20 |
+
"""Cập nhật profile hiện tại (Update current profile)"""
|
| 21 |
+
supabase = get_supabase()
|
| 22 |
+
res = supabase.table("profiles").update(data).eq("id", user_id).execute()
|
| 23 |
+
return res.data[0]
|
app/routers/sessions.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
from fastapi import APIRouter, Depends, HTTPException
|
| 6 |
+
|
| 7 |
+
from app.dependencies import get_current_user_id
|
| 8 |
+
from app.logutil import log_step
|
| 9 |
+
from app.session_cache import (
|
| 10 |
+
get_sessions_list_cached,
|
| 11 |
+
invalidate_for_user,
|
| 12 |
+
invalidate_session_owner,
|
| 13 |
+
session_owned_by_user,
|
| 14 |
+
)
|
| 15 |
+
from app.supabase_client import get_supabase
|
| 16 |
+
|
| 17 |
+
router = APIRouter(prefix="/api/v1/sessions", tags=["Sessions"])
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@router.get("", response_model=List[dict])
|
| 21 |
+
async def list_sessions(user_id=Depends(get_current_user_id)):
|
| 22 |
+
"""Danh sách các phiên chat của người dùng (List user's chat sessions)"""
|
| 23 |
+
supabase = get_supabase()
|
| 24 |
+
|
| 25 |
+
def fetch() -> list:
|
| 26 |
+
res = (
|
| 27 |
+
supabase.table("sessions")
|
| 28 |
+
.select("*")
|
| 29 |
+
.eq("user_id", user_id)
|
| 30 |
+
.order("updated_at", desc=True)
|
| 31 |
+
.execute()
|
| 32 |
+
)
|
| 33 |
+
log_step("db_select", table="sessions", op="list", user_id=str(user_id))
|
| 34 |
+
return res.data
|
| 35 |
+
|
| 36 |
+
return get_sessions_list_cached(str(user_id), fetch)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@router.post("", response_model=dict)
|
| 40 |
+
async def create_session(user_id=Depends(get_current_user_id)):
|
| 41 |
+
"""Tạo một phiên chat mới (Create a new chat session)"""
|
| 42 |
+
supabase = get_supabase()
|
| 43 |
+
res = supabase.table("sessions").insert(
|
| 44 |
+
{"user_id": user_id, "title": "Bài toán mới"}
|
| 45 |
+
).execute()
|
| 46 |
+
log_step("db_insert", table="sessions", op="create")
|
| 47 |
+
invalidate_for_user(str(user_id))
|
| 48 |
+
return res.data[0]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@router.get("/{session_id}/messages", response_model=List[dict])
|
| 52 |
+
async def get_session_messages(session_id: str, user_id=Depends(get_current_user_id)):
|
| 53 |
+
"""Lấy toàn bộ lịch sử tin nhắn của một phiên (Get chat history for a session)"""
|
| 54 |
+
supabase = get_supabase()
|
| 55 |
+
|
| 56 |
+
def owns() -> bool:
|
| 57 |
+
res = (
|
| 58 |
+
supabase.table("sessions")
|
| 59 |
+
.select("id")
|
| 60 |
+
.eq("id", session_id)
|
| 61 |
+
.eq("user_id", user_id)
|
| 62 |
+
.execute()
|
| 63 |
+
)
|
| 64 |
+
log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
|
| 65 |
+
return bool(res.data)
|
| 66 |
+
|
| 67 |
+
if not session_owned_by_user(session_id, str(user_id), owns):
|
| 68 |
+
raise HTTPException(
|
| 69 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
res = (
|
| 73 |
+
supabase.table("messages")
|
| 74 |
+
.select("*")
|
| 75 |
+
.eq("session_id", session_id)
|
| 76 |
+
.order("created_at", desc=False)
|
| 77 |
+
.execute()
|
| 78 |
+
)
|
| 79 |
+
log_step("db_select", table="messages", op="list", session_id=session_id)
|
| 80 |
+
return res.data
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@router.delete("/{session_id}")
|
| 84 |
+
async def delete_session(session_id: str, user_id=Depends(get_current_user_id)):
|
| 85 |
+
"""Xóa một phiên chat (Delete a chat session)"""
|
| 86 |
+
supabase = get_supabase()
|
| 87 |
+
|
| 88 |
+
def owns() -> bool:
|
| 89 |
+
res = (
|
| 90 |
+
supabase.table("sessions")
|
| 91 |
+
.select("id")
|
| 92 |
+
.eq("id", session_id)
|
| 93 |
+
.eq("user_id", user_id)
|
| 94 |
+
.execute()
|
| 95 |
+
)
|
| 96 |
+
return bool(res.data)
|
| 97 |
+
|
| 98 |
+
if not session_owned_by_user(session_id, str(user_id), owns):
|
| 99 |
+
raise HTTPException(
|
| 100 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
# jobs.session_id FK must be cleared before sessions row
|
| 104 |
+
supabase.table("jobs").delete().eq("session_id", session_id).eq("user_id", user_id).execute()
|
| 105 |
+
log_step("db_delete", table="jobs", op="by_session", session_id=session_id)
|
| 106 |
+
supabase.table("messages").delete().eq("session_id", session_id).execute()
|
| 107 |
+
log_step("db_delete", table="messages", op="by_session", session_id=session_id)
|
| 108 |
+
res = (
|
| 109 |
+
supabase.table("sessions")
|
| 110 |
+
.delete()
|
| 111 |
+
.eq("id", session_id)
|
| 112 |
+
.eq("user_id", user_id)
|
| 113 |
+
.execute()
|
| 114 |
+
)
|
| 115 |
+
log_step("db_delete", table="sessions", session_id=session_id)
|
| 116 |
+
invalidate_for_user(str(user_id))
|
| 117 |
+
invalidate_session_owner(session_id, str(user_id))
|
| 118 |
+
return {"status": "ok", "deleted_id": session_id}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@router.patch("/{session_id}/title")
|
| 122 |
+
async def update_session_title(title: str, session_id: str, user_id=Depends(get_current_user_id)):
|
| 123 |
+
"""Cập nhật tiêu đề phiên chat (Rename a chat session)"""
|
| 124 |
+
supabase = get_supabase()
|
| 125 |
+
res = (
|
| 126 |
+
supabase.table("sessions")
|
| 127 |
+
.update({"title": title})
|
| 128 |
+
.eq("id", session_id)
|
| 129 |
+
.eq("user_id", user_id)
|
| 130 |
+
.execute()
|
| 131 |
+
)
|
| 132 |
+
log_step("db_update", table="sessions", op="title", session_id=session_id)
|
| 133 |
+
invalidate_for_user(str(user_id))
|
| 134 |
+
return res.data[0]
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@router.get("/{session_id}/assets", response_model=List[dict])
|
| 138 |
+
async def get_session_assets(session_id: str, user_id=Depends(get_current_user_id)):
|
| 139 |
+
"""Lấy danh sách video đã render trong session (Get versioned assets for a session)"""
|
| 140 |
+
supabase = get_supabase()
|
| 141 |
+
|
| 142 |
+
def owns() -> bool:
|
| 143 |
+
res = (
|
| 144 |
+
supabase.table("sessions")
|
| 145 |
+
.select("id")
|
| 146 |
+
.eq("id", session_id)
|
| 147 |
+
.eq("user_id", user_id)
|
| 148 |
+
.execute()
|
| 149 |
+
)
|
| 150 |
+
return bool(res.data)
|
| 151 |
+
|
| 152 |
+
if not session_owned_by_user(session_id, str(user_id), owns):
|
| 153 |
+
raise HTTPException(
|
| 154 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
res = (
|
| 158 |
+
supabase.table("session_assets")
|
| 159 |
+
.select("*")
|
| 160 |
+
.eq("session_id", session_id)
|
| 161 |
+
.order("version", desc=True)
|
| 162 |
+
.execute()
|
| 163 |
+
)
|
| 164 |
+
log_step("db_select", table="session_assets", op="list", session_id=session_id)
|
| 165 |
+
return res.data
|
app/routers/solve.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import uuid
|
| 5 |
+
|
| 6 |
+
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
| 7 |
+
|
| 8 |
+
from agents.orchestrator import Orchestrator
|
| 9 |
+
from app.dependencies import get_current_user_id
|
| 10 |
+
from app.errors import format_error_for_user
|
| 11 |
+
from app.logutil import log_pipeline_failure, log_pipeline_success, log_step
|
| 12 |
+
from app.models.schemas import SolveRequest, SolveResponse
|
| 13 |
+
from app.session_cache import invalidate_for_user, session_owned_by_user
|
| 14 |
+
from app.supabase_client import get_supabase
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"])
|
| 18 |
+
|
| 19 |
+
# Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py).
|
| 20 |
+
ORCHESTRATOR = Orchestrator()
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_orchestrator() -> Orchestrator:
|
| 24 |
+
return ORCHESTRATOR
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@router.post("/{session_id}/solve", response_model=SolveResponse)
|
| 28 |
+
async def solve_problem(
|
| 29 |
+
session_id: str,
|
| 30 |
+
request: SolveRequest,
|
| 31 |
+
background_tasks: BackgroundTasks,
|
| 32 |
+
user_id=Depends(get_current_user_id),
|
| 33 |
+
):
|
| 34 |
+
"""
|
| 35 |
+
Gửi câu hỏi giải toán trong một session (Submit geometry problem in a session).
|
| 36 |
+
Lưu câu hỏi vào history và bắt đầu tiến trình giải.
|
| 37 |
+
"""
|
| 38 |
+
supabase = get_supabase()
|
| 39 |
+
uid = str(user_id)
|
| 40 |
+
|
| 41 |
+
def owns() -> bool:
|
| 42 |
+
res = (
|
| 43 |
+
supabase.table("sessions")
|
| 44 |
+
.select("id")
|
| 45 |
+
.eq("id", session_id)
|
| 46 |
+
.eq("user_id", user_id)
|
| 47 |
+
.execute()
|
| 48 |
+
)
|
| 49 |
+
log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
|
| 50 |
+
return bool(res.data)
|
| 51 |
+
|
| 52 |
+
if not session_owned_by_user(session_id, uid, owns):
|
| 53 |
+
log_pipeline_failure("solve_request", error="forbidden", session_id=session_id)
|
| 54 |
+
raise HTTPException(
|
| 55 |
+
status_code=403, detail="Forbidden: You do not own this session."
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# NEW: Giới hạn 5 queries mỗi session
|
| 59 |
+
msg_count_res = (
|
| 60 |
+
supabase.table("messages")
|
| 61 |
+
.select("id", count="exact")
|
| 62 |
+
.eq("session_id", session_id)
|
| 63 |
+
.eq("role", "user")
|
| 64 |
+
.execute()
|
| 65 |
+
)
|
| 66 |
+
current_count = msg_count_res.count if msg_count_res.count is not None else 0
|
| 67 |
+
import os
|
| 68 |
+
if current_count >= 5 and os.getenv("ALLOW_TEST_BYPASS") != "true":
|
| 69 |
+
raise HTTPException(
|
| 70 |
+
status_code=400,
|
| 71 |
+
detail="Bạn đã đạt giới hạn 5 câu hỏi cho phiên này. (Session limit reached: 5/5)"
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
supabase.table("messages").insert(
|
| 75 |
+
{
|
| 76 |
+
"session_id": session_id,
|
| 77 |
+
"role": "user",
|
| 78 |
+
"type": "text",
|
| 79 |
+
"content": request.text,
|
| 80 |
+
"metadata": {"image_url": request.image_url} if request.image_url else {},
|
| 81 |
+
}
|
| 82 |
+
).execute()
|
| 83 |
+
log_step("db_insert", table="messages", op="user_message", session_id=session_id)
|
| 84 |
+
|
| 85 |
+
job_id = str(uuid.uuid4())
|
| 86 |
+
supabase.table("jobs").insert(
|
| 87 |
+
{
|
| 88 |
+
"id": job_id,
|
| 89 |
+
"user_id": user_id,
|
| 90 |
+
"session_id": session_id,
|
| 91 |
+
"status": "processing",
|
| 92 |
+
"input_text": request.text,
|
| 93 |
+
}
|
| 94 |
+
).execute()
|
| 95 |
+
log_step("db_insert", table="jobs", job_id=job_id)
|
| 96 |
+
|
| 97 |
+
background_tasks.add_task(process_session_job, job_id, session_id, request, user_id)
|
| 98 |
+
|
| 99 |
+
title_check = supabase.table("sessions").select("title").eq("id", session_id).execute()
|
| 100 |
+
if title_check.data and title_check.data[0]["title"] == "Bài toán mới":
|
| 101 |
+
new_title = request.text[:50] + ("..." if len(request.text) > 50 else "")
|
| 102 |
+
supabase.table("sessions").update({"title": new_title}).eq("id", session_id).execute()
|
| 103 |
+
log_step("db_update", table="sessions", op="title_from_first_message")
|
| 104 |
+
invalidate_for_user(uid)
|
| 105 |
+
|
| 106 |
+
log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id)
|
| 107 |
+
return SolveResponse(job_id=job_id, status="processing")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
async def process_session_job(
|
| 111 |
+
job_id: str, session_id: str, request: SolveRequest, user_id: str
|
| 112 |
+
):
|
| 113 |
+
"""Tiến trình giải toán ngầm, cập nhật cả bảng jobs và bảng messages (history)."""
|
| 114 |
+
from app.websocket_manager import notify_status
|
| 115 |
+
|
| 116 |
+
async def status_update(status: str):
|
| 117 |
+
await notify_status(job_id, {"status": status})
|
| 118 |
+
|
| 119 |
+
supabase = get_supabase()
|
| 120 |
+
try:
|
| 121 |
+
# Fetch full history for the session
|
| 122 |
+
history_res = (
|
| 123 |
+
supabase.table("messages")
|
| 124 |
+
.select("*")
|
| 125 |
+
.eq("session_id", session_id)
|
| 126 |
+
.order("created_at", desc=False)
|
| 127 |
+
.execute()
|
| 128 |
+
)
|
| 129 |
+
history = history_res.data if history_res.data else []
|
| 130 |
+
|
| 131 |
+
result = await get_orchestrator().run(
|
| 132 |
+
request.text,
|
| 133 |
+
request.image_url,
|
| 134 |
+
job_id=job_id,
|
| 135 |
+
session_id=session_id,
|
| 136 |
+
status_callback=status_update,
|
| 137 |
+
request_video=request.request_video,
|
| 138 |
+
history=history,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
status = result.get("status", "error") if "error" not in result else "error"
|
| 142 |
+
|
| 143 |
+
supabase.table("jobs").update({"status": status, "result": result}).eq(
|
| 144 |
+
"id", job_id
|
| 145 |
+
).execute()
|
| 146 |
+
log_step("db_update", table="jobs", job_id=job_id, status=status)
|
| 147 |
+
|
| 148 |
+
if status != "rendering_queued":
|
| 149 |
+
supabase.table("messages").insert(
|
| 150 |
+
{
|
| 151 |
+
"session_id": session_id,
|
| 152 |
+
"role": "assistant",
|
| 153 |
+
"type": "analysis" if "error" not in result else "error",
|
| 154 |
+
"content": (
|
| 155 |
+
result.get("semantic_analysis", "Đã có lỗi xảy ra.")
|
| 156 |
+
if "error" not in result
|
| 157 |
+
else result["error"]
|
| 158 |
+
),
|
| 159 |
+
"metadata": {
|
| 160 |
+
"job_id": job_id,
|
| 161 |
+
"coordinates": result.get("coordinates"),
|
| 162 |
+
"geometry_dsl": result.get("geometry_dsl"),
|
| 163 |
+
"polygon_order": result.get("polygon_order", []),
|
| 164 |
+
"drawing_phases": result.get("drawing_phases", []),
|
| 165 |
+
"circles": result.get("circles", []),
|
| 166 |
+
"lines": result.get("lines", []),
|
| 167 |
+
"rays": result.get("rays", []),
|
| 168 |
+
"video_url": result.get("video_url"),
|
| 169 |
+
"solution": result.get("solution"),
|
| 170 |
+
},
|
| 171 |
+
}
|
| 172 |
+
).execute()
|
| 173 |
+
log_step("db_insert", table="messages", op="assistant", job_id=job_id)
|
| 174 |
+
|
| 175 |
+
await notify_status(job_id, {"status": status, "result": result})
|
| 176 |
+
|
| 177 |
+
if "error" in result:
|
| 178 |
+
log_pipeline_failure(
|
| 179 |
+
"solve_job", error=result.get("error"), job_id=job_id, session_id=session_id
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
log_pipeline_success(
|
| 183 |
+
"solve_job", job_id=job_id, session_id=session_id, status=status
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
except Exception as e:
|
| 187 |
+
logger.exception("Error processing session job %s", job_id)
|
| 188 |
+
safe = format_error_for_user(e)
|
| 189 |
+
supabase.table("jobs").update(
|
| 190 |
+
{"status": "error", "result": {"error": safe}}
|
| 191 |
+
).eq("id", job_id).execute()
|
| 192 |
+
|
| 193 |
+
supabase.table("messages").insert(
|
| 194 |
+
{
|
| 195 |
+
"session_id": session_id,
|
| 196 |
+
"role": "assistant",
|
| 197 |
+
"type": "error",
|
| 198 |
+
"content": f"Lỗi hệ thống: {safe}",
|
| 199 |
+
"metadata": {"job_id": job_id},
|
| 200 |
+
}
|
| 201 |
+
).execute()
|
| 202 |
+
|
| 203 |
+
await notify_status(job_id, {"status": "error", "message": safe})
|
| 204 |
+
log_pipeline_failure("solve_job", error=safe, job_id=job_id, session_id=session_id)
|
app/runtime_env.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Default process env vars (Paddle/OpenMP). Call as early as possible after load_dotenv."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def apply_runtime_env_defaults() -> None:
|
| 9 |
+
# Paddle respects OMP_NUM_THREADS at import; setdefault loses if platform already set 2+
|
| 10 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
| 11 |
+
os.environ["MKL_NUM_THREADS"] = "1"
|
| 12 |
+
os.environ["OPENBLAS_NUM_THREADS"] = "1"
|
app/session_cache.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""TTL in-memory cache để giảm truy vấn Supabase lặp lại (list session, quyền sở hữu session)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Callable
|
| 6 |
+
|
| 7 |
+
from cachetools import TTLCache
|
| 8 |
+
|
| 9 |
+
from app.logutil import log_step
|
| 10 |
+
|
| 11 |
+
_session_list: TTLCache[str, list[Any]] = TTLCache(maxsize=512, ttl=45)
|
| 12 |
+
_session_owner: TTLCache[tuple[str, str], bool] = TTLCache(maxsize=4096, ttl=45)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def invalidate_for_user(user_id: str) -> None:
|
| 16 |
+
"""Xoá cache list session của user (sau create / delete / rename / solve đổi title)."""
|
| 17 |
+
_session_list.pop(user_id, None)
|
| 18 |
+
log_step("cache_invalidate", target="session_list", user_id=user_id)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def invalidate_session_owner(session_id: str, user_id: str) -> None:
|
| 22 |
+
_session_owner.pop((session_id, user_id), None)
|
| 23 |
+
log_step("cache_invalidate", target="session_owner", session_id=session_id, user_id=user_id)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_sessions_list_cached(user_id: str, fetch: Callable[[], list[Any]]) -> list[Any]:
|
| 27 |
+
if user_id in _session_list:
|
| 28 |
+
log_step("cache_hit", kind="session_list", user_id=user_id)
|
| 29 |
+
return _session_list[user_id]
|
| 30 |
+
log_step("cache_miss", kind="session_list", user_id=user_id)
|
| 31 |
+
data = fetch()
|
| 32 |
+
_session_list[user_id] = data
|
| 33 |
+
return data
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def session_owned_by_user(
|
| 37 |
+
session_id: str,
|
| 38 |
+
user_id: str,
|
| 39 |
+
fetch: Callable[[], bool],
|
| 40 |
+
) -> bool:
|
| 41 |
+
key = (session_id, user_id)
|
| 42 |
+
if key in _session_owner:
|
| 43 |
+
log_step("cache_hit", kind="session_owner", session_id=session_id)
|
| 44 |
+
return _session_owner[key]
|
| 45 |
+
log_step("cache_miss", kind="session_owner", session_id=session_id)
|
| 46 |
+
ok = fetch()
|
| 47 |
+
_session_owner[key] = ok
|
| 48 |
+
return ok
|
app/supabase_client.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from supabase import Client, ClientOptions, create_client
|
| 3 |
+
from supabase_auth import SyncMemoryStorage
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
+
load_dotenv()
|
| 7 |
+
|
| 8 |
+
from app.url_utils import sanitize_env
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_supabase() -> Client:
|
| 12 |
+
"""Service-role client for server-side operations (bypasses RLS when policies expect service role)."""
|
| 13 |
+
url = sanitize_env(os.getenv("SUPABASE_URL"))
|
| 14 |
+
key = sanitize_env(os.getenv("SUPABASE_SERVICE_ROLE_KEY") or os.getenv("SUPABASE_KEY"))
|
| 15 |
+
if not url or not key:
|
| 16 |
+
raise RuntimeError(
|
| 17 |
+
"SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY (or SUPABASE_KEY) must be set"
|
| 18 |
+
)
|
| 19 |
+
return create_client(url, key)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_supabase_for_user_jwt(access_token: str) -> Client:
|
| 23 |
+
"""
|
| 24 |
+
Client scoped to the logged-in user: PostgREST sends the user's JWT so RLS applies.
|
| 25 |
+
Use SUPABASE_ANON_KEY (publishable), not the service role key.
|
| 26 |
+
"""
|
| 27 |
+
url = sanitize_env(os.getenv("SUPABASE_URL"))
|
| 28 |
+
anon = sanitize_env(os.getenv("SUPABASE_ANON_KEY") or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY"))
|
| 29 |
+
if not url or not anon:
|
| 30 |
+
raise RuntimeError(
|
| 31 |
+
"SUPABASE_URL and SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY) must be set "
|
| 32 |
+
"for user-scoped Supabase access"
|
| 33 |
+
)
|
| 34 |
+
base_opts = ClientOptions(storage=SyncMemoryStorage())
|
| 35 |
+
merged_headers = {**dict(base_opts.headers), "Authorization": f"Bearer {access_token}"}
|
| 36 |
+
opts = ClientOptions(storage=SyncMemoryStorage(), headers=merged_headers)
|
| 37 |
+
return create_client(url, anon, opts)
|
app/url_utils.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Normalize URLs / env strings (HF secrets and copy-paste often include trailing newlines)."""
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def sanitize_url(value: str | None) -> str | None:
|
| 5 |
+
if value is None:
|
| 6 |
+
return None
|
| 7 |
+
s = value.strip().replace("\r", "").replace("\n", "").replace("\t", "")
|
| 8 |
+
return s or None
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def sanitize_env(value: str | None) -> str | None:
|
| 12 |
+
"""Strip whitespace and line breaks from environment-backed strings."""
|
| 13 |
+
return sanitize_url(value)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# OpenAI SDK (>=1.x) requires a non-empty api_key at client construction (Docker build / prewarm has no secrets).
|
| 17 |
+
_OPENAI_API_KEY_BUILD_PLACEHOLDER = "build-placeholder-openrouter-not-for-production"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def openai_compatible_api_key(raw: str | None) -> str:
|
| 21 |
+
"""Return sanitized API key, or a placeholder so AsyncOpenAI() can be constructed without env at build time."""
|
| 22 |
+
k = sanitize_env(raw)
|
| 23 |
+
return k if k else _OPENAI_API_KEY_BUILD_PLACEHOLDER
|
app/websocket_manager.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""WebSocket connection registry and job status notifications (avoid circular imports with main)."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
from typing import Dict, List
|
| 7 |
+
|
| 8 |
+
from fastapi import WebSocket, WebSocketDisconnect
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
active_connections: Dict[str, List[WebSocket]] = {}
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
async def notify_status(job_id: str, data: dict) -> None:
|
| 16 |
+
if job_id not in active_connections:
|
| 17 |
+
return
|
| 18 |
+
for connection in list(active_connections[job_id]):
|
| 19 |
+
try:
|
| 20 |
+
await connection.send_json(data)
|
| 21 |
+
except Exception as e:
|
| 22 |
+
logger.error("WS error sending to %s: %s", job_id, e)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def register_websocket_routes(app) -> None:
|
| 26 |
+
"""Attach websocket endpoint to the FastAPI app."""
|
| 27 |
+
|
| 28 |
+
@app.websocket("/ws/{job_id}")
|
| 29 |
+
async def websocket_endpoint(websocket: WebSocket, job_id: str) -> None:
|
| 30 |
+
await websocket.accept()
|
| 31 |
+
if job_id not in active_connections:
|
| 32 |
+
active_connections[job_id] = []
|
| 33 |
+
active_connections[job_id].append(websocket)
|
| 34 |
+
try:
|
| 35 |
+
while True:
|
| 36 |
+
await websocket.receive_text()
|
| 37 |
+
except WebSocketDisconnect:
|
| 38 |
+
active_connections[job_id].remove(websocket)
|
| 39 |
+
if not active_connections[job_id]:
|
| 40 |
+
del active_connections[job_id]
|
clean_ports.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Script to kill all project-related processes for a clean restart
|
| 3 |
+
|
| 4 |
+
echo "🧹 Cleaning up project processes..."
|
| 5 |
+
|
| 6 |
+
# Kill things on ports 8000 (Backend) and 3000 (Frontend)
|
| 7 |
+
PORTS="8000 3000 11020"
|
| 8 |
+
for PORT in $PORTS; do
|
| 9 |
+
PIDS=$(lsof -ti :$PORT)
|
| 10 |
+
if [ ! -z "$PIDS" ]; then
|
| 11 |
+
echo "Killing processes on port $PORT: $PIDS"
|
| 12 |
+
kill -9 $PIDS 2>/dev/null
|
| 13 |
+
fi
|
| 14 |
+
done
|
| 15 |
+
|
| 16 |
+
# Kill by process name
|
| 17 |
+
echo "Killing any remaining Celery, Uvicorn, or Manim processes..."
|
| 18 |
+
pkill -9 -f "celery" 2>/dev/null
|
| 19 |
+
pkill -9 -f "uvicorn" 2>/dev/null
|
| 20 |
+
pkill -9 -f "manim" 2>/dev/null
|
| 21 |
+
|
| 22 |
+
echo "✅ Done. You can now restart your Backend, Worker, and Frontend."
|
migrations/v4_migration.sql
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- ============================================================
|
| 2 |
+
-- MATHSOLVER v4.0 - Migration Script (Multi-Session & History)
|
| 3 |
+
-- ============================================================
|
| 4 |
+
|
| 5 |
+
-- 1. Profiles Table (Extends Supabase Auth)
|
| 6 |
+
CREATE TABLE IF NOT EXISTS public.profiles (
|
| 7 |
+
id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE,
|
| 8 |
+
display_name TEXT,
|
| 9 |
+
avatar_url TEXT,
|
| 10 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
| 11 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 12 |
+
);
|
| 13 |
+
|
| 14 |
+
-- Function to handle new user signup and auto-create profile
|
| 15 |
+
CREATE OR REPLACE FUNCTION public.handle_new_user()
|
| 16 |
+
RETURNS TRIGGER AS $$
|
| 17 |
+
BEGIN
|
| 18 |
+
INSERT INTO public.profiles (id, display_name, avatar_url)
|
| 19 |
+
VALUES (
|
| 20 |
+
NEW.id,
|
| 21 |
+
COALESCE(NEW.raw_user_meta_data->>'full_name', NEW.email),
|
| 22 |
+
NEW.raw_user_meta_data->>'avatar_url'
|
| 23 |
+
);
|
| 24 |
+
RETURN NEW;
|
| 25 |
+
END;
|
| 26 |
+
$$ LANGUAGE plpgsql SECURITY DEFINER;
|
| 27 |
+
|
| 28 |
+
-- Trigger for profile creation
|
| 29 |
+
DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users;
|
| 30 |
+
CREATE TRIGGER on_auth_user_created
|
| 31 |
+
AFTER INSERT ON auth.users
|
| 32 |
+
FOR EACH ROW EXECUTE FUNCTION public.handle_new_user();
|
| 33 |
+
|
| 34 |
+
-- 2. Sessions Table
|
| 35 |
+
CREATE TABLE IF NOT EXISTS public.sessions (
|
| 36 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 37 |
+
user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
|
| 38 |
+
title TEXT DEFAULT 'Bài toán mới',
|
| 39 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
| 40 |
+
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 41 |
+
);
|
| 42 |
+
|
| 43 |
+
-- Index for sessions
|
| 44 |
+
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON public.sessions(user_id);
|
| 45 |
+
CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON public.sessions(updated_at DESC);
|
| 46 |
+
|
| 47 |
+
-- 3. Messages Table
|
| 48 |
+
CREATE TABLE IF NOT EXISTS public.messages (
|
| 49 |
+
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
| 50 |
+
session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
|
| 51 |
+
role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
|
| 52 |
+
type TEXT NOT NULL DEFAULT 'text',
|
| 53 |
+
content TEXT NOT NULL,
|
| 54 |
+
metadata JSONB DEFAULT '{}'::jsonb,
|
| 55 |
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
| 56 |
+
);
|
| 57 |
+
|
| 58 |
+
-- Index for messages
|
| 59 |
+
CREATE INDEX IF NOT EXISTS idx_messages_session_id ON public.messages(session_id);
|
| 60 |
+
CREATE INDEX IF NOT EXISTS idx_messages_created_at ON public.messages(session_id, created_at);
|
| 61 |
+
|
| 62 |
+
-- 4. Update Jobs Table
|
| 63 |
+
ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS user_id UUID REFERENCES auth.users(id);
|
| 64 |
+
ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS session_id UUID REFERENCES public.sessions(id);
|
| 65 |
+
|
| 66 |
+
-- 5. Row Level Security (RLS)
|
| 67 |
+
ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
|
| 68 |
+
ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
|
| 69 |
+
ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
|
| 70 |
+
|
| 71 |
+
-- Polices for public.profiles
|
| 72 |
+
DROP POLICY IF EXISTS "Users view own profile" ON public.profiles;
|
| 73 |
+
CREATE POLICY "Users view own profile" ON public.profiles FOR SELECT USING (auth.uid() = id);
|
| 74 |
+
DROP POLICY IF EXISTS "Users update own profile" ON public.profiles;
|
| 75 |
+
CREATE POLICY "Users update own profile" ON public.profiles FOR UPDATE USING (auth.uid() = id);
|
| 76 |
+
|
| 77 |
+
-- Policies for public.sessions
|
| 78 |
+
DROP POLICY IF EXISTS "Users manage own sessions" ON public.sessions;
|
| 79 |
+
CREATE POLICY "Users manage own sessions" ON public.sessions FOR ALL USING (auth.uid() = user_id);
|
| 80 |
+
|
| 81 |
+
-- Policies for public.messages
|
| 82 |
+
DROP POLICY IF EXISTS "Users view own messages" ON public.messages;
|
| 83 |
+
CREATE POLICY "Users view own messages" ON public.messages FOR ALL USING (
|
| 84 |
+
session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
|
| 85 |
+
);
|
| 86 |
+
|
| 87 |
+
-- Policies for public.jobs
|
| 88 |
+
DROP POLICY IF EXISTS "Users view own jobs" ON public.jobs;
|
| 89 |
+
CREATE POLICY "Users view own jobs" ON public.jobs FOR ALL USING (auth.uid() = user_id OR user_id IS NULL);
|
| 90 |
+
|
| 91 |
+
-- Grant permissions to public/authenticated
|
| 92 |
+
GRANT ALL ON public.profiles TO authenticated;
|
| 93 |
+
GRANT ALL ON public.sessions TO authenticated;
|
| 94 |
+
GRANT ALL ON public.messages TO authenticated;
|
| 95 |
+
GRANT ALL ON public.jobs TO authenticated;
|
requirements.txt
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Target: Python 3.11 (see Dockerfile). Used by: FastAPI API, Celery worker, Manim render, OCR/vision stack.
|
| 2 |
+
# Install: pip install -r requirements.txt
|
| 3 |
+
|
| 4 |
+
# --- HTTP API ---
|
| 5 |
+
cachetools>=5.3
|
| 6 |
+
fastapi>=0.115,<1
|
| 7 |
+
uvicorn[standard]>=0.30
|
| 8 |
+
python-multipart>=0.0.9
|
| 9 |
+
python-dotenv>=1.0
|
| 10 |
+
pydantic[email]>=2.4
|
| 11 |
+
email-validator>=2
|
| 12 |
+
|
| 13 |
+
# --- Auth / data / queue ---
|
| 14 |
+
openai>=1.40
|
| 15 |
+
supabase>=2.0
|
| 16 |
+
celery>=5.3
|
| 17 |
+
redis>=5
|
| 18 |
+
httpx>=0.27
|
| 19 |
+
websockets>=12
|
| 20 |
+
|
| 21 |
+
# --- Math & symbolic solver ---
|
| 22 |
+
sympy>=1.12
|
| 23 |
+
numpy>=1.26,<2
|
| 24 |
+
scipy>=1.11
|
| 25 |
+
opencv-python-headless>=4.8,<4.10
|
| 26 |
+
|
| 27 |
+
# --- Video (GeometryScene via CLI) ---
|
| 28 |
+
manim>=0.18,<0.20
|
| 29 |
+
|
| 30 |
+
# --- OCR & vision (orchestrator / legacy /ocr) ---
|
| 31 |
+
pix2tex>=0.1.4
|
| 32 |
+
paddleocr==2.7.3
|
| 33 |
+
paddlepaddle==2.6.2
|
| 34 |
+
ultralytics==8.2.2
|
run_api_test.sh
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
LOG_FILE="api_test_results.log"
|
| 4 |
+
echo "=== Starting API E2E Test Suite ($(date)) ===" > $LOG_FILE
|
| 5 |
+
|
| 6 |
+
# 1. Start BE Server in background
|
| 7 |
+
echo "[INFO] Starting Backend Server..." | tee -a $LOG_FILE
|
| 8 |
+
export ALLOW_TEST_BYPASS=true
|
| 9 |
+
export LOG_LEVEL=info
|
| 10 |
+
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
|
| 11 |
+
SERVER_PID=$!
|
| 12 |
+
|
| 13 |
+
# 2. Wait for server to be ready
|
| 14 |
+
echo "[INFO] Waiting for server (PID: $SERVER_PID) on port 8000..." | tee -a $LOG_FILE
|
| 15 |
+
MAX_RETRIES=15
|
| 16 |
+
READY=0
|
| 17 |
+
for i in $(seq 1 $MAX_RETRIES); do
|
| 18 |
+
if curl -s http://localhost:8000/ > /dev/null; then
|
| 19 |
+
READY=1
|
| 20 |
+
break
|
| 21 |
+
fi
|
| 22 |
+
sleep 2
|
| 23 |
+
done
|
| 24 |
+
|
| 25 |
+
if [ $READY -eq 0 ]; then
|
| 26 |
+
echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a $LOG_FILE
|
| 27 |
+
kill $SERVER_PID
|
| 28 |
+
exit 1
|
| 29 |
+
fi
|
| 30 |
+
echo "[INFO] Server is READY." | tee -a $LOG_FILE
|
| 31 |
+
|
| 32 |
+
# 3. Prepare Test Data
|
| 33 |
+
echo "[INFO] Preparing fresh test data..." | tee -a $LOG_FILE
|
| 34 |
+
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
|
| 35 |
+
echo "$PREP_OUTPUT" >> $LOG_FILE
|
| 36 |
+
|
| 37 |
+
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
|
| 38 |
+
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
|
| 39 |
+
|
| 40 |
+
if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then
|
| 41 |
+
echo "[ERROR] Failed to prepare test data." | tee -a $LOG_FILE
|
| 42 |
+
kill $SERVER_PID
|
| 43 |
+
exit 1
|
| 44 |
+
fi
|
| 45 |
+
|
| 46 |
+
echo "[INFO] Test Data: User=$TEST_USER_ID, Session=$TEST_SESSION_ID" | tee -a $LOG_FILE
|
| 47 |
+
|
| 48 |
+
# 4. Run Pytest
|
| 49 |
+
echo "[INFO] Running API E2E Tests..." | tee -a $LOG_FILE
|
| 50 |
+
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py -s >> $LOG_FILE 2>&1
|
| 51 |
+
TEST_EXIT_CODE=$?
|
| 52 |
+
|
| 53 |
+
# 5. Cleanup
|
| 54 |
+
echo "[INFO] Shutting down Server..." | tee -a $LOG_FILE
|
| 55 |
+
kill $SERVER_PID
|
| 56 |
+
|
| 57 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 58 |
+
if [ $TEST_EXIT_CODE -eq 0 ]; then
|
| 59 |
+
echo "FINAL RESULT: ✅ ALL API TESTS PASSED" | tee -a $LOG_FILE
|
| 60 |
+
else
|
| 61 |
+
echo "FINAL RESULT: ❌ SOME API TESTS FAILED (Code: $TEST_EXIT_CODE)" | tee -a $LOG_FILE
|
| 62 |
+
fi
|
| 63 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 64 |
+
|
| 65 |
+
exit $TEST_EXIT_CODE
|
run_full_api_test.sh
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Configuration and Cleanup
|
| 4 |
+
LOG_FILE="full_api_suite.log"
|
| 5 |
+
REPORT_FILE="full_api_test_report.md"
|
| 6 |
+
JSON_RESULTS="temp_suite_results.json"
|
| 7 |
+
|
| 8 |
+
echo "=== Starting Full API Suite Test ($(date)) ===" > $LOG_FILE
|
| 9 |
+
|
| 10 |
+
# Cleanup on exit
|
| 11 |
+
trap 'echo "[INFO] Cleaning up processes..."; kill $SERVER_PID 2>/dev/null; sleep 1' EXIT
|
| 12 |
+
|
| 13 |
+
# 1. Start Server in EAGER MODE + MOCK VIDEO (no Redis/Worker needed)
|
| 14 |
+
echo "[INFO] Starting Backend Server (EAGER + MOCK_VIDEO)..." | tee -a $LOG_FILE
|
| 15 |
+
export ALLOW_TEST_BYPASS=true
|
| 16 |
+
export LOG_LEVEL=info
|
| 17 |
+
export CELERY_TASK_ALWAYS_EAGER=true
|
| 18 |
+
export CELERY_RESULT_BACKEND=rpc://
|
| 19 |
+
export MOCK_VIDEO=true
|
| 20 |
+
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
|
| 21 |
+
SERVER_PID=$!
|
| 22 |
+
|
| 23 |
+
# 2. Wait for server
|
| 24 |
+
echo "[INFO] Waiting for server (PID: $SERVER_PID)..." | tee -a $LOG_FILE
|
| 25 |
+
for i in {1..20}; do
|
| 26 |
+
if curl -s http://localhost:8000/ > /dev/null; then
|
| 27 |
+
echo "[INFO] Server is READY." | tee -a $LOG_FILE
|
| 28 |
+
break
|
| 29 |
+
fi
|
| 30 |
+
sleep 2
|
| 31 |
+
done
|
| 32 |
+
|
| 33 |
+
# 3. Prepare Test Data
|
| 34 |
+
echo "[INFO] Preparing fresh test data..." | tee -a $LOG_FILE
|
| 35 |
+
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
|
| 36 |
+
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
|
| 37 |
+
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
|
| 38 |
+
|
| 39 |
+
if [ -z "$TEST_USER_ID" ]; then
|
| 40 |
+
echo "[ERROR] Failed to prepare test data." | tee -a $LOG_FILE
|
| 41 |
+
exit 1
|
| 42 |
+
fi
|
| 43 |
+
|
| 44 |
+
# 4. Run Pytest Suite
|
| 45 |
+
echo "[INFO] Executing Full API Suite..." | tee -a $LOG_FILE
|
| 46 |
+
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_full_suite.py -s >> $LOG_FILE 2>&1
|
| 47 |
+
TEST_EXIT_CODE=$?
|
| 48 |
+
|
| 49 |
+
# 5. Shut down server
|
| 50 |
+
echo "[INFO] Shutting down processes..." | tee -a $LOG_FILE
|
| 51 |
+
|
| 52 |
+
# 6. Generate Markdown Report
|
| 53 |
+
echo "[INFO] Generating Markdown Report..." | tee -a $LOG_FILE
|
| 54 |
+
PYTHONPATH=. venv/bin/python scripts/generate_report.py "$JSON_RESULTS" "$REPORT_FILE"
|
| 55 |
+
|
| 56 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 57 |
+
echo "DONE. Check $REPORT_FILE for results." | tee -a $LOG_FILE
|
| 58 |
+
echo "==========================================" | tee -a $LOG_FILE
|
| 59 |
+
|
| 60 |
+
exit $TEST_EXIT_CODE
|
scripts/backend_test_suite.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import requests
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
BASE_URL = "http://localhost:8000/api/v1"
|
| 7 |
+
|
| 8 |
+
TEST_CASES = [
|
| 9 |
+
{
|
| 10 |
+
"name": "Equilateral Triangle",
|
| 11 |
+
"payload": {"text": "Vẽ tam giác đều cạnh 5.", "request_video": True}
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"name": "Right Triangle (3-4-5)",
|
| 15 |
+
"payload": {"text": "Cho tam giác ABC vuông tại A có AB=3, AC=4. Tính BC.", "request_video": True}
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"name": "Isosceles Triangle",
|
| 19 |
+
"payload": {"text": "Cho tam giác ABC cân tại A có AB=5, BC=6.", "request_video": False}
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"name": "Square",
|
| 23 |
+
"payload": {"text": "Vẽ hình vuông ABCD cạnh 4.", "request_video": True}
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"name": "Invalid Input",
|
| 27 |
+
"payload": {"text": "abcxyz", "request_video": False}
|
| 28 |
+
}
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
def run_test(test_case):
|
| 32 |
+
print(f"\n[TEST] Running: {test_case['name']}...")
|
| 33 |
+
try:
|
| 34 |
+
start_time = time.time()
|
| 35 |
+
# Create job
|
| 36 |
+
response = requests.post(f"{BASE_URL}/solve", json=test_case['payload'])
|
| 37 |
+
if response.status_code != 200:
|
| 38 |
+
print(f" [FAIL] Initial request failed: {response.text}")
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
job_id = response.json().get("job_id")
|
| 42 |
+
print(f" [INFO] Job ID: {job_id}")
|
| 43 |
+
|
| 44 |
+
# Poll for completion
|
| 45 |
+
status = "processing"
|
| 46 |
+
max_attempts = 40
|
| 47 |
+
attempts = 0
|
| 48 |
+
while status in ["processing", "solving", "rendering_queued", "rendering"] and attempts < max_attempts:
|
| 49 |
+
time.sleep(5)
|
| 50 |
+
res = requests.get(f"{BASE_URL}/solve/{job_id}")
|
| 51 |
+
data = res.json()
|
| 52 |
+
status = data.get("status")
|
| 53 |
+
print(f" [INFO] Status: {status} (Attempt {attempts+1})")
|
| 54 |
+
if status == "success":
|
| 55 |
+
duration = time.time() - start_time
|
| 56 |
+
print(f" [SUCCESS] Completed in {duration:.2f}s")
|
| 57 |
+
if test_case['payload'].get('request_video'):
|
| 58 |
+
video_url = data.get("result", {}).get("video_url")
|
| 59 |
+
if video_url:
|
| 60 |
+
print(f" [INFO] Video URL: {video_url}")
|
| 61 |
+
else:
|
| 62 |
+
print(" [WARNING] Video requested but no URL found in result.")
|
| 63 |
+
return True
|
| 64 |
+
if status == "error":
|
| 65 |
+
print(f" [FAIL] Solver error: {data.get('result', {}).get('error')}")
|
| 66 |
+
return False
|
| 67 |
+
attempts += 1
|
| 68 |
+
|
| 69 |
+
if attempts >= max_attempts:
|
| 70 |
+
print(" [FAIL] Timeout reached.")
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f" [ERROR] Exception: {str(e)}")
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
|
| 78 |
+
results = []
|
| 79 |
+
print("=== MathSolver Backend Test Suite ===")
|
| 80 |
+
for tc in TEST_CASES:
|
| 81 |
+
success = run_test(tc)
|
| 82 |
+
results.append((tc['name'], success))
|
| 83 |
+
|
| 84 |
+
print("\n" + "="*40)
|
| 85 |
+
print("FINAL REPORT:")
|
| 86 |
+
all_passed = True
|
| 87 |
+
for name, success in results:
|
| 88 |
+
status_str = "PASS" if success else "FAIL"
|
| 89 |
+
print(f"- {name}: {status_str}")
|
| 90 |
+
if not success: all_passed = False
|
| 91 |
+
|
| 92 |
+
if all_passed:
|
| 93 |
+
print("\nALL TESTS PASSED! 🎉")
|
| 94 |
+
sys.exit(0)
|
| 95 |
+
else:
|
| 96 |
+
print("\nSOME TESTS FAILED. ❌")
|
| 97 |
+
sys.exit(1)
|
scripts/generate_report.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
def generate_report(json_path, report_path):
|
| 7 |
+
try:
|
| 8 |
+
with open(json_path, 'r') as f:
|
| 9 |
+
data = json.load(f)
|
| 10 |
+
|
| 11 |
+
with open(report_path, 'w') as f:
|
| 12 |
+
f.write('# Báo cáo Kiểm thử API Toàn diện (Full Suite API Report)\n\n')
|
| 13 |
+
f.write(f'**Thời gian chạy:** {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n')
|
| 14 |
+
f.write(f'**Kết quả chung:** {"✅ PASS" if all(r.get("success", False) for r in data) else "❌ FAIL"}\n\n')
|
| 15 |
+
|
| 16 |
+
f.write('| ID | Câu hỏi (Query) | Trạng thái | Thời gian (s) | Kết quả / Lỗi |\n')
|
| 17 |
+
f.write('| :--- | :--- | :--- | :--- | :--- |\n')
|
| 18 |
+
for r in data:
|
| 19 |
+
status = "✅ PASS" if r.get("success") else "❌ FAIL"
|
| 20 |
+
elapsed = f"{r.get('elapsed', 0):.2f}"
|
| 21 |
+
query = r.get('query', '-')
|
| 22 |
+
|
| 23 |
+
# Extract analysis or error
|
| 24 |
+
res = r.get('result', {})
|
| 25 |
+
if not isinstance(res, dict):
|
| 26 |
+
res = {}
|
| 27 |
+
|
| 28 |
+
analysis = res.get('semantic_analysis', '-')
|
| 29 |
+
if not r.get("success"):
|
| 30 |
+
analysis = f"**Lỗi:** {r.get('error', '-')}"
|
| 31 |
+
|
| 32 |
+
# Truncate long analysis for table
|
| 33 |
+
short_analysis = (analysis[:100] + '...') if len(analysis) > 100 else analysis
|
| 34 |
+
|
| 35 |
+
f.write(f'| {r["id"]} | {query} | {status} | {elapsed} | {short_analysis} |\n')
|
| 36 |
+
|
| 37 |
+
f.write('\n---\n**Chi tiết Output (DSL & Analysis):**\n')
|
| 38 |
+
for r in data:
|
| 39 |
+
if r.get('success'):
|
| 40 |
+
res = r.get('result', {})
|
| 41 |
+
if not isinstance(res, dict):
|
| 42 |
+
continue
|
| 43 |
+
|
| 44 |
+
f.write(f"\n### Case {r['id']}: {r.get('query')}\n")
|
| 45 |
+
f.write(f"**Semantic Analysis:**\n{res.get('semantic_analysis', '-')}\n\n")
|
| 46 |
+
f.write(f"**Geometry DSL:**\n```\n{res.get('geometry_dsl', '-')}\n```\n")
|
| 47 |
+
|
| 48 |
+
# v5.1 Solution info
|
| 49 |
+
sol = res.get('solution')
|
| 50 |
+
if sol and isinstance(sol, dict):
|
| 51 |
+
f.write("**Solution (v5.1):**\n")
|
| 52 |
+
f.write(f"- **Answer:** {sol.get('answer', 'N/A')}\n")
|
| 53 |
+
f.write("- **Steps:**\n")
|
| 54 |
+
steps = sol.get('steps', [])
|
| 55 |
+
if steps:
|
| 56 |
+
for step in steps:
|
| 57 |
+
f.write(f" - {step}\n")
|
| 58 |
+
else:
|
| 59 |
+
f.write(" - (Không có bước giải cụ thể)\n")
|
| 60 |
+
|
| 61 |
+
if sol.get('symbolic_expression'):
|
| 62 |
+
f.write(f"- **Symbolic:** `{sol.get('symbolic_expression')}`\n")
|
| 63 |
+
f.write("\n")
|
| 64 |
+
|
| 65 |
+
print(f'Report generated: {report_path}')
|
| 66 |
+
except Exception as e:
|
| 67 |
+
print(f'Error generating report: {e}')
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
if len(sys.argv) < 3:
|
| 71 |
+
print("Usage: python generate_report.py <json_results> <report_output>")
|
| 72 |
+
sys.exit(1)
|
| 73 |
+
generate_report(sys.argv[1], sys.argv[2])
|
scripts/prepare_api_test.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Add parent dir to path to import app modules
|
| 6 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 7 |
+
|
| 8 |
+
from app.supabase_client import get_supabase
|
| 9 |
+
|
| 10 |
+
def prepare():
|
| 11 |
+
supabase = get_supabase()
|
| 12 |
+
# Use existing valid user to avoid foreign key violation on sessions.user_id
|
| 13 |
+
user_id = "8cd3adb0-7964-4575-949c-d0cadcd8b679"
|
| 14 |
+
session_id = str(uuid.uuid4())
|
| 15 |
+
|
| 16 |
+
print(f"Using existing test user: {user_id}")
|
| 17 |
+
|
| 18 |
+
print(f"Creating fresh test session: {session_id}")
|
| 19 |
+
# Insert session
|
| 20 |
+
supabase.table("sessions").insert({
|
| 21 |
+
"id": session_id,
|
| 22 |
+
"user_id": user_id,
|
| 23 |
+
"title": f"Fresh API Test {session_id[:8]}"
|
| 24 |
+
}).execute()
|
| 25 |
+
|
| 26 |
+
# Return IDs for the test script
|
| 27 |
+
print(f"RESULT:USER_ID={user_id}")
|
| 28 |
+
print(f"RESULT:SESSION_ID={session_id}")
|
| 29 |
+
|
| 30 |
+
if __name__ == "__main__":
|
| 31 |
+
prepare()
|
scripts/prewarm_models.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Download and load all heavy models during Docker build (YOLO, PaddleOCR, Pix2Tex, agents).
|
| 4 |
+
Fails the image build if initialization fails.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
# Ensure imports work when run as `python scripts/prewarm_models.py` from WORKDIR
|
| 14 |
+
_APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
| 15 |
+
if _APP_ROOT not in sys.path:
|
| 16 |
+
sys.path.insert(0, _APP_ROOT)
|
| 17 |
+
|
| 18 |
+
os.chdir(_APP_ROOT)
|
| 19 |
+
|
| 20 |
+
from dotenv import load_dotenv
|
| 21 |
+
|
| 22 |
+
load_dotenv()
|
| 23 |
+
|
| 24 |
+
from app.runtime_env import apply_runtime_env_defaults
|
| 25 |
+
|
| 26 |
+
apply_runtime_env_defaults()
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s")
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger("prewarm")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def main() -> None:
|
| 34 |
+
from agents.orchestrator import Orchestrator
|
| 35 |
+
|
| 36 |
+
logger.info("Constructing Orchestrator (full agent + model load)...")
|
| 37 |
+
Orchestrator()
|
| 38 |
+
logger.info("Prewarm finished successfully.")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
main()
|
scripts/test_engine_direct.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import logging
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
# Add root directory to path to import app and agents
|
| 8 |
+
sys.path.append("/Volumes/WorkSpace/Project/MathSolver/backend")
|
| 9 |
+
|
| 10 |
+
# Configure logging to stdout
|
| 11 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
from agents.orchestrator import Orchestrator
|
| 15 |
+
|
| 16 |
+
async def main():
|
| 17 |
+
orch = Orchestrator()
|
| 18 |
+
text = "Vẽ tam giác đều cạnh 5."
|
| 19 |
+
job_id = "test_direct_equilateral"
|
| 20 |
+
|
| 21 |
+
print(f"\n--- Testing Orchestrator Direct: {text} ---")
|
| 22 |
+
|
| 23 |
+
async def status_cb(status):
|
| 24 |
+
print(f" [STATUS] {status}")
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
result = await orch.run(text, job_id=job_id, status_callback=status_cb, request_video=False)
|
| 28 |
+
print("\n--- Final Result ---")
|
| 29 |
+
print(json.dumps(result, indent=2))
|
| 30 |
+
except Exception as e:
|
| 31 |
+
print(f"\n--- ERROR ---")
|
| 32 |
+
import traceback
|
| 33 |
+
traceback.print_exc()
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
|
| 36 |
+
asyncio.run(main())
|
setup.sh
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# MathSolver v3.1 Setup Script for macOS
|
| 4 |
+
|
| 5 |
+
echo "🚀 Starting Environment Setup..."
|
| 6 |
+
|
| 7 |
+
# 1. System Dependencies (Homebrew)
|
| 8 |
+
if command -v brew >/dev/null 2>&1; then
|
| 9 |
+
echo "📦 Installing system dependencies via Homebrew..."
|
| 10 |
+
brew install pango pkg-config glib librsvg
|
| 11 |
+
else
|
| 12 |
+
echo "⚠️ Homebrew not found. Please install it first: https://brew.sh/"
|
| 13 |
+
exit 1
|
| 14 |
+
fi
|
| 15 |
+
|
| 16 |
+
# 2. Python SSL Certificates
|
| 17 |
+
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
| 18 |
+
CERT_FILE="/Applications/Python ${PYTHON_VERSION}/Install Certificates.command"
|
| 19 |
+
|
| 20 |
+
if [ -f "$CERT_FILE" ]; then
|
| 21 |
+
echo "🔐 Installing Python SSL certificates..."
|
| 22 |
+
sh "$CERT_FILE"
|
| 23 |
+
else
|
| 24 |
+
echo "ℹ️ SSL certificate installer not found at $CERT_FILE. Skipping..."
|
| 25 |
+
fi
|
| 26 |
+
|
| 27 |
+
# 3. Virtual Environment
|
| 28 |
+
echo "🐍 Setting up Python Virtual Environment..."
|
| 29 |
+
cd backend
|
| 30 |
+
python3 -m venv venv
|
| 31 |
+
source venv/bin/activate
|
| 32 |
+
|
| 33 |
+
# 4. Pip packages
|
| 34 |
+
echo "📦 Installing Python packages..."
|
| 35 |
+
pip install --upgrade pip
|
| 36 |
+
pip install -r requirements.txt
|
| 37 |
+
|
| 38 |
+
# 5. Fix ManimPango (Crucial for macOS arm64)
|
| 39 |
+
echo "🛠️ Rebuilding ManimPango from source to ensure library linking..."
|
| 40 |
+
pip install --no-cache-dir --force-reinstall --no-binary manimpango manimpango
|
| 41 |
+
|
| 42 |
+
echo "✅ Setup Complete!"
|
| 43 |
+
echo "To start the backend, run: source venv/bin/activate && uvicorn app.main:app --reload"
|
solver/dsl_parser.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List, Tuple, Dict, Any
|
| 4 |
+
from .models import Point, Constraint
|
| 5 |
+
|
| 6 |
+
logger = logging.getLogger(__name__)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DSLParser:
|
| 10 |
+
def parse(self, text: str) -> Tuple[List[Point], List[Constraint], bool]:
|
| 11 |
+
"""Parse DSL text into points and constraints. Stateless per call."""
|
| 12 |
+
points: Dict[str, Point] = {}
|
| 13 |
+
explicit_point_ids: List[str] = []
|
| 14 |
+
constraints: List[Constraint] = []
|
| 15 |
+
polygon_order: List[str] = []
|
| 16 |
+
circles: List[Dict[str, Any]] = []
|
| 17 |
+
segments: List[List[str]] = []
|
| 18 |
+
lines_ext: List[List[str]] = []
|
| 19 |
+
rays: List[List[str]] = []
|
| 20 |
+
is_3d = False
|
| 21 |
+
|
| 22 |
+
logger.info("==[DSLParser] Parsing DSL input==")
|
| 23 |
+
logger.debug(f"[DSLParser] Raw DSL:\n{text}")
|
| 24 |
+
|
| 25 |
+
lines = text.strip().split('\n')
|
| 26 |
+
for line in lines:
|
| 27 |
+
line = line.strip()
|
| 28 |
+
if not line or line.startswith('//') or line.startswith('#'):
|
| 29 |
+
continue
|
| 30 |
+
|
| 31 |
+
# POINT(A) or POINT(A, 0, 0, 5)
|
| 32 |
+
m = re.match(r'POINT\((\w+)(?:,\s*([\d\.-]+),\s*([\d\.-]+)(?:,\s*([\d\.-]+))?)?\)', line)
|
| 33 |
+
if m:
|
| 34 |
+
name = m.group(1)
|
| 35 |
+
x = float(m.group(2)) if m.group(2) else None
|
| 36 |
+
y = float(m.group(3)) if m.group(3) else None
|
| 37 |
+
z = float(m.group(4)) if m.group(4) else None
|
| 38 |
+
if z is not None:
|
| 39 |
+
is_3d = True
|
| 40 |
+
points[name] = Point(id=name, x=x, y=y, z=z)
|
| 41 |
+
if name not in explicit_point_ids:
|
| 42 |
+
explicit_point_ids.append(name)
|
| 43 |
+
logger.debug(f"[DSLParser] + POINT: {name} ({x}, {y}, {z})")
|
| 44 |
+
continue
|
| 45 |
+
|
| 46 |
+
# LENGTH(AB, 5)
|
| 47 |
+
m = re.match(r'LENGTH\((\w+),\s*([\d\.]+)\)', line)
|
| 48 |
+
if m:
|
| 49 |
+
target, value = m.group(1), float(m.group(2))
|
| 50 |
+
pts = [target[i:i+1] for i in range(len(target))]
|
| 51 |
+
constraints.append(Constraint(type='length', targets=pts, value=value))
|
| 52 |
+
logger.debug(f"[DSLParser] + LENGTH: {pts} = {value}")
|
| 53 |
+
continue
|
| 54 |
+
|
| 55 |
+
# ANGLE(A, 90) or ANGLE(A, 90deg)
|
| 56 |
+
m = re.match(r'ANGLE\((\w+),\s*([\d\.]+)(?:deg)?\)', line)
|
| 57 |
+
if m:
|
| 58 |
+
target, value = m.group(1), float(m.group(2))
|
| 59 |
+
constraints.append(Constraint(type='angle', targets=[target], value=value))
|
| 60 |
+
logger.debug(f"[DSLParser] + ANGLE: vertex={target}, degrees={value}")
|
| 61 |
+
continue
|
| 62 |
+
|
| 63 |
+
# PARALLEL(AB, CD)
|
| 64 |
+
m = re.match(r'PARALLEL\((\w+),\s*(\w+)\)', line)
|
| 65 |
+
if m:
|
| 66 |
+
seg1, seg2 = m.group(1), m.group(2)
|
| 67 |
+
constraints.append(Constraint(type='parallel', targets=list(seg1) + list(seg2), value=0))
|
| 68 |
+
logger.debug(f"[DSLParser] + PARALLEL: {seg1} || {seg2}")
|
| 69 |
+
continue
|
| 70 |
+
|
| 71 |
+
# PERPENDICULAR(AB, CD)
|
| 72 |
+
m = re.match(r'PERPENDICULAR\((\w+),\s*(\w+)\)', line)
|
| 73 |
+
if m:
|
| 74 |
+
seg1, seg2 = m.group(1), m.group(2)
|
| 75 |
+
constraints.append(Constraint(type='perpendicular', targets=list(seg1) + list(seg2), value=0))
|
| 76 |
+
logger.debug(f"[DSLParser] + PERPENDICULAR: {seg1} _|_ {seg2}")
|
| 77 |
+
continue
|
| 78 |
+
|
| 79 |
+
# MIDPOINT(M, AB) — M is midpoint of AB
|
| 80 |
+
m = re.match(r'MIDPOINT\((\w+),\s*(\w+)\)', line)
|
| 81 |
+
if m:
|
| 82 |
+
mid, seg = m.group(1), m.group(2)
|
| 83 |
+
if mid not in points:
|
| 84 |
+
points[mid] = Point(id=mid)
|
| 85 |
+
pts = [mid] + [seg[i:i+1] for i in range(len(seg))]
|
| 86 |
+
constraints.append(Constraint(type='midpoint', targets=pts, value=0))
|
| 87 |
+
logger.debug(f"[DSLParser] + MIDPOINT: {mid} = mid({seg})")
|
| 88 |
+
continue
|
| 89 |
+
|
| 90 |
+
# SECTION(E, A, C, 0.66) — E lies on AC s.t. AE = 0.66 * AC
|
| 91 |
+
m = re.match(r'SECTION\((\w+),\s*(\w+),\s*(\w+),\s*([\d\.-]+)\)', line)
|
| 92 |
+
if m:
|
| 93 |
+
target, p1, p2, k = m.group(1), m.group(2), m.group(3), float(m.group(4))
|
| 94 |
+
if target not in points:
|
| 95 |
+
points[target] = Point(id=target)
|
| 96 |
+
constraints.append(Constraint(type='section', targets=[target, p1, p2], value=k))
|
| 97 |
+
logger.debug(f"[DSLParser] + SECTION: {target} = {p1} + {k}({p2}-{p1})")
|
| 98 |
+
continue
|
| 99 |
+
|
| 100 |
+
# CIRCLE(O, r)
|
| 101 |
+
m = re.match(r'CIRCLE\((\w+),\s*([\d\.]+)\)', line)
|
| 102 |
+
if m:
|
| 103 |
+
center, radius = m.group(1), float(m.group(2))
|
| 104 |
+
if center not in points:
|
| 105 |
+
points[center] = Point(id=center)
|
| 106 |
+
constraints.append(Constraint(type='circle', targets=[center], value=radius))
|
| 107 |
+
circles.append({"center": center, "radius": radius})
|
| 108 |
+
logger.debug(f"[DSLParser] + CIRCLE: center={center}, r={radius}")
|
| 109 |
+
continue
|
| 110 |
+
|
| 111 |
+
# POLYGON_ORDER(A, B, C, D) — thứ tự nối điểm để vẽ đa giác
|
| 112 |
+
m = re.match(r'POLYGON_ORDER\(([^)]+)\)', line)
|
| 113 |
+
if m:
|
| 114 |
+
polygon_order = [p.strip() for p in m.group(1).split(',')]
|
| 115 |
+
logger.debug(f"[DSLParser] + POLYGON_ORDER: {polygon_order}")
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
# SEGMENT(M, N) — đoạn thẳng phụ cần vẽ
|
| 119 |
+
m = re.match(r'SEGMENT\((\w+),\s*(\w+)\)', line)
|
| 120 |
+
if m:
|
| 121 |
+
p1, p2 = m.group(1), m.group(2)
|
| 122 |
+
segments.append([p1, p2])
|
| 123 |
+
constraints.append(Constraint(type='segment', targets=[p1, p2], value=0))
|
| 124 |
+
logger.debug(f"[DSLParser] + SEGMENT: {p1}—{p2}")
|
| 125 |
+
continue
|
| 126 |
+
|
| 127 |
+
# LINE(A, B) — infinite line
|
| 128 |
+
m = re.match(r'LINE\((\w+),\s*(\w+)\)', line)
|
| 129 |
+
if m:
|
| 130 |
+
p1, p2 = m.group(1), m.group(2)
|
| 131 |
+
lines_ext.append([p1, p2])
|
| 132 |
+
constraints.append(Constraint(type='line', targets=[p1, p2], value=0))
|
| 133 |
+
logger.debug(f"[DSLParser] + LINE: {p1}-{p2}")
|
| 134 |
+
continue
|
| 135 |
+
|
| 136 |
+
# RAY(A, B) — ray AB starting at A
|
| 137 |
+
m = re.match(r'RAY\((\w+),\s*(\w+)\)', line)
|
| 138 |
+
if m:
|
| 139 |
+
p1, p2 = m.group(1), m.group(2)
|
| 140 |
+
rays.append([p1, p2])
|
| 141 |
+
constraints.append(Constraint(type='ray', targets=[p1, p2], value=0))
|
| 142 |
+
logger.debug(f"[DSLParser] + RAY: {p1}->{p2}")
|
| 143 |
+
continue
|
| 144 |
+
|
| 145 |
+
# TRIANGLE(ABC) / PYRAMID(S_ABCD) / PRISM(ABC_DEF)
|
| 146 |
+
m = re.match(r'(TRIANGLE|PYRAMID|PRISM)\(([^)]+)\)', line)
|
| 147 |
+
if m:
|
| 148 |
+
pt_type = m.group(1)
|
| 149 |
+
targets = m.group(2)
|
| 150 |
+
if pt_type in ["PYRAMID", "PRISM"]:
|
| 151 |
+
is_3d = True
|
| 152 |
+
if pt_type == "TRIANGLE":
|
| 153 |
+
if not polygon_order: polygon_order = list(targets)
|
| 154 |
+
elif pt_type == "PYRAMID":
|
| 155 |
+
# S_ABCD -> S is apex, ABCD is base
|
| 156 |
+
if "_" in targets:
|
| 157 |
+
apex, base = targets.split("_")
|
| 158 |
+
# Add segments from apex to all base points
|
| 159 |
+
for p in base:
|
| 160 |
+
segments.append([apex, p])
|
| 161 |
+
constraints.append(Constraint(type='segment', targets=[apex, p], value=0))
|
| 162 |
+
if not polygon_order: polygon_order = list(base)
|
| 163 |
+
elif pt_type == "PRISM":
|
| 164 |
+
# ABC_DEF -> two bases
|
| 165 |
+
if "_" in targets:
|
| 166 |
+
b1, b2 = targets.split("_")
|
| 167 |
+
for p1, p2 in zip(b1, b2):
|
| 168 |
+
segments.append([p1, p2])
|
| 169 |
+
constraints.append(Constraint(type='segment', targets=[p1, p2], value=0))
|
| 170 |
+
logger.debug(f"[DSLParser] + {pt_type}: {targets}")
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
# SPHERE(O, r)
|
| 174 |
+
m = re.match(r'SPHERE\((\w+),\s*([\d\.]+)\)', line)
|
| 175 |
+
if m:
|
| 176 |
+
is_3d = True
|
| 177 |
+
center, radius = m.group(1), float(m.group(2))
|
| 178 |
+
if center not in points:
|
| 179 |
+
points[center] = Point(id=center)
|
| 180 |
+
constraints.append(Constraint(type='sphere', targets=[center], value=radius))
|
| 181 |
+
logger.debug(f"[DSLParser] + SPHERE: center={center}, r={radius}")
|
| 182 |
+
continue
|
| 183 |
+
|
| 184 |
+
logger.warning(f"[DSLParser] ? Unrecognized DSL line: '{line}'")
|
| 185 |
+
|
| 186 |
+
logger.info(f"[DSLParser] Parsed {len(points)} points, {len(constraints)} constraints.")
|
| 187 |
+
|
| 188 |
+
# Safety sweep: Ensure all points referenced in constraints actually exist in the points dictionary
|
| 189 |
+
for c in constraints:
|
| 190 |
+
for pid in c.targets:
|
| 191 |
+
# Some targets might be values or comma-separated strings (handled elsewhere),
|
| 192 |
+
# but most are single-character point IDs.
|
| 193 |
+
if isinstance(pid, str) and len(pid) == 1 and pid not in points:
|
| 194 |
+
points[pid] = Point(id=pid)
|
| 195 |
+
logger.debug(f"[DSLParser] ! Auto-declared missing point from constraint: {pid}")
|
| 196 |
+
|
| 197 |
+
# Attach metadata to a synthetic constraint for downstream use
|
| 198 |
+
if polygon_order:
|
| 199 |
+
constraints.append(Constraint(type='polygon_order', targets=polygon_order, value=0))
|
| 200 |
+
elif explicit_point_ids:
|
| 201 |
+
# Re-use polygon_order as a carrier for explicit points IF no real order was specified
|
| 202 |
+
constraints.append(Constraint(type='explicit_points', targets=explicit_point_ids, value=0))
|
| 203 |
+
|
| 204 |
+
# Add auxiliary metadata for lines and rays
|
| 205 |
+
if lines_ext:
|
| 206 |
+
constraints.append(Constraint(type='lines_metadata', targets=[",".join(l) for l in lines_ext], value=0))
|
| 207 |
+
if rays:
|
| 208 |
+
constraints.append(Constraint(type='rays_metadata', targets=[",".join(l) for l in rays], value=0))
|
| 209 |
+
|
| 210 |
+
return list(points.values()), constraints, is_3d
|
solver/engine.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sympy as sp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import logging
|
| 4 |
+
import string
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
from .models import Point, Constraint
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class GeometryEngine:
|
| 12 |
+
def solve(self, points: List[Point], constraints: List[Constraint], is_3d: bool = False) -> Dict[str, Any] | None:
|
| 13 |
+
if not points:
|
| 14 |
+
logger.error("[GeometryEngine] No points to solve.")
|
| 15 |
+
return None
|
| 16 |
+
|
| 17 |
+
logger.info(f"==[GeometryEngine] Starting solve with {len(points)} points, {len(constraints)} constraints (is_3d={is_3d})==")
|
| 18 |
+
|
| 19 |
+
# ── Separate metadata constraints from real ones ──────────────────────
|
| 20 |
+
polygon_order: List[str] = []
|
| 21 |
+
circles_meta: List[Dict] = []
|
| 22 |
+
segments_meta: List[List[str]] = []
|
| 23 |
+
real_constraints: List[Constraint] = []
|
| 24 |
+
|
| 25 |
+
for c in constraints:
|
| 26 |
+
if c.type == 'polygon_order':
|
| 27 |
+
polygon_order = list(c.targets)
|
| 28 |
+
elif c.type == 'explicit_points' and not polygon_order:
|
| 29 |
+
polygon_order = list(c.targets)
|
| 30 |
+
elif c.type == 'circle':
|
| 31 |
+
circles_meta.append({"center": c.targets[0], "radius": float(c.value)})
|
| 32 |
+
real_constraints.append(c)
|
| 33 |
+
elif c.type == 'segment':
|
| 34 |
+
segments_meta.append(list(c.targets))
|
| 35 |
+
# don't add to equations — pure drawing annotation
|
| 36 |
+
elif c.type == 'lines_metadata':
|
| 37 |
+
lines_meta_list = [t.split(',') for t in c.targets]
|
| 38 |
+
real_constraints.append(c) # for passing to builder? or just keep here
|
| 39 |
+
elif c.type == 'rays_metadata':
|
| 40 |
+
rays_meta_list = [t.split(',') for t in c.targets]
|
| 41 |
+
real_constraints.append(c)
|
| 42 |
+
else:
|
| 43 |
+
real_constraints.append(c)
|
| 44 |
+
|
| 45 |
+
# ── Setup symbols ─────────────────────────────────────────────────────
|
| 46 |
+
point_vars: Dict[str, tuple] = {}
|
| 47 |
+
equations = []
|
| 48 |
+
|
| 49 |
+
# Convert to list for stable indexing and to handle both Dict and List inputs
|
| 50 |
+
pt_list = list(points.values()) if isinstance(points, dict) else points
|
| 51 |
+
|
| 52 |
+
for p in pt_list:
|
| 53 |
+
x = sp.Symbol(f"{p.id}_x")
|
| 54 |
+
y = sp.Symbol(f"{p.id}_y")
|
| 55 |
+
z = sp.Symbol(f"{p.id}_z")
|
| 56 |
+
point_vars[p.id] = (x, y, z)
|
| 57 |
+
logger.debug(f"[GeometryEngine] Symbol: ({p.id}_x, {p.id}_y, {p.id}_z)")
|
| 58 |
+
|
| 59 |
+
# If 2D problem, pin all z to 0 immediately
|
| 60 |
+
if not is_3d:
|
| 61 |
+
equations.append(z)
|
| 62 |
+
|
| 63 |
+
# ── Anchor logic to fix translation + rotation DOF ────────────────────
|
| 64 |
+
# Skip anchoring if points already have explicit coordinates that fix DOFs
|
| 65 |
+
|
| 66 |
+
if len(pt_list) > 0:
|
| 67 |
+
p1 = pt_list[0]
|
| 68 |
+
# Translation: fix p1 at (0,0) or (0,0,0)
|
| 69 |
+
if p1.x is None: equations.append(point_vars[p1.id][0]); logger.debug(f"Anchor {p1.id}_x=0")
|
| 70 |
+
if p1.y is None: equations.append(point_vars[p1.id][1]); logger.debug(f"Anchor {p1.id}_y=0")
|
| 71 |
+
if is_3d and p1.z is None:
|
| 72 |
+
equations.append(point_vars[p1.id][2]); logger.debug(f"Anchor {p1.id}_z=0")
|
| 73 |
+
|
| 74 |
+
if len(pt_list) > 1:
|
| 75 |
+
p2 = pt_list[1]
|
| 76 |
+
# Rotation: fix p2 on X-axis (y=0)
|
| 77 |
+
if p2.y is None: equations.append(point_vars[p2.id][1]); logger.debug(f"Anchor {p2.id}_y=0")
|
| 78 |
+
if is_3d and p2.z is None:
|
| 79 |
+
equations.append(point_vars[p2.id][2]); logger.debug(f"Anchor {p2.id}_z=0")
|
| 80 |
+
|
| 81 |
+
if is_3d and len(pt_list) > 2:
|
| 82 |
+
p3 = pt_list[2]
|
| 83 |
+
# Planar rotation: fix p3 on XY-plane (z=0)
|
| 84 |
+
if p3.z is None: equations.append(point_vars[p3.id][2]); logger.debug(f"Anchor {p3.id}_z=0")
|
| 85 |
+
|
| 86 |
+
# ── Build equations from explicit point coordinates ──────────────────
|
| 87 |
+
for p in pt_list:
|
| 88 |
+
if p.x is not None:
|
| 89 |
+
equations.append(point_vars[p.id][0] - p.x)
|
| 90 |
+
if p.y is not None:
|
| 91 |
+
equations.append(point_vars[p.id][1] - p.y)
|
| 92 |
+
if p.z is not None:
|
| 93 |
+
equations.append(point_vars[p.id][2] - p.z)
|
| 94 |
+
|
| 95 |
+
# ── Build equations from constraints ──────────────────────────────────
|
| 96 |
+
for c in real_constraints:
|
| 97 |
+
logger.debug(f"[GeometryEngine] Processing constraint: type={c.type}, targets={c.targets}, value={c.value}")
|
| 98 |
+
|
| 99 |
+
if c.type == 'length' and len(c.targets) == 2:
|
| 100 |
+
p1, p2 = c.targets
|
| 101 |
+
if p1 not in point_vars or p2 not in point_vars:
|
| 102 |
+
logger.warning(f"[GeometryEngine] Skip length: {c.targets} not in symbols.")
|
| 103 |
+
continue
|
| 104 |
+
v1, v2 = point_vars[p1], point_vars[p2]
|
| 105 |
+
# 3D distance
|
| 106 |
+
eq = (v2[0]-v1[0])**2 + (v2[1]-v1[1])**2 + (v2[2]-v1[2])**2 - float(c.value)**2
|
| 107 |
+
equations.append(eq)
|
| 108 |
+
logger.debug(f"[GeometryEngine] -> Length eq (3D): |{p1}{p2}|² = {c.value}²")
|
| 109 |
+
|
| 110 |
+
elif c.type == 'angle' and len(c.targets) >= 1:
|
| 111 |
+
# In 3D, 'angle' usually refers to the angle between two vectors (e.g., ∠BAC)
|
| 112 |
+
v_name = c.targets[0]
|
| 113 |
+
if v_name not in point_vars:
|
| 114 |
+
continue
|
| 115 |
+
# For simplicity, we assume the next two points in targets or fallback to first 2 others
|
| 116 |
+
if len(c.targets) >= 3:
|
| 117 |
+
p1_name, p2_name = c.targets[1], c.targets[2]
|
| 118 |
+
else:
|
| 119 |
+
other_pts = [p.id for p in pt_list if p.id != v_name][:2]
|
| 120 |
+
if len(other_pts) < 2: continue
|
| 121 |
+
p1_name, p2_name = other_pts
|
| 122 |
+
|
| 123 |
+
pV = point_vars[v_name]
|
| 124 |
+
p1_vars = point_vars[p1_name]
|
| 125 |
+
p2_vars = point_vars[p2_name]
|
| 126 |
+
|
| 127 |
+
# Vectors V1 and V2
|
| 128 |
+
v1 = [p1_vars[i] - pV[i] for i in range(3)]
|
| 129 |
+
v2 = [p2_vars[i] - pV[i] for i in range(3)]
|
| 130 |
+
|
| 131 |
+
# Dot product relation: v1.v2 = |v1||v2| cos(theta)
|
| 132 |
+
# But we use the tangent relation or square it to avoid sqrt if possible
|
| 133 |
+
# If 90 deg: dot product = 0
|
| 134 |
+
if abs(float(c.value) - 90.0) < 1e-9:
|
| 135 |
+
eq = sum(v1[i]*v2[i] for i in range(3))
|
| 136 |
+
logger.debug(f"[GeometryEngine] -> Angle eq at {v_name} (90° dot=0)")
|
| 137 |
+
else:
|
| 138 |
+
# Generic angle using law of cosines (squared)
|
| 139 |
+
cos_val = np.cos(np.deg2rad(float(c.value)))
|
| 140 |
+
d1_sq = sum(v1[i]**2 for i in range(3))
|
| 141 |
+
d2_sq = sum(v2[i]**2 for i in range(3))
|
| 142 |
+
dot = sum(v1[i]*v2[i] for i in range(3))
|
| 143 |
+
eq = dot**2 - (cos_val**2) * d1_sq * d2_sq
|
| 144 |
+
# Note: this allows theta and 180-theta.
|
| 145 |
+
# Better: dot - cos(theta) * sqrt(d1_sq * d2_sq) = 0, but that has sqrt.
|
| 146 |
+
logger.debug(f"[GeometryEngine] -> Angle eq at {v_name} ({c.value}° cos² relation)")
|
| 147 |
+
equations.append(eq)
|
| 148 |
+
|
| 149 |
+
elif c.type == 'parallel' and len(c.targets) == 4:
|
| 150 |
+
pA, pB, pC, pD = c.targets
|
| 151 |
+
if any(t not in point_vars for t in [pA, pB, pC, pD]): continue
|
| 152 |
+
va, vb, vc, vd = point_vars[pA], point_vars[pB], point_vars[pC], point_vars[pD]
|
| 153 |
+
# AB || CD means vector(AB) = lambda * vector(CD)
|
| 154 |
+
# In 3D, cross product = 0. (b-a) x (d-c) = 0
|
| 155 |
+
v1 = [vb[i]-va[i] for i in range(3)]
|
| 156 |
+
v2 = [vd[i]-vc[i] for i in range(3)]
|
| 157 |
+
# Cross product components:
|
| 158 |
+
equations.append(v1[1]*v2[2] - v1[2]*v2[1])
|
| 159 |
+
equations.append(v1[2]*v2[0] - v1[0]*v2[2])
|
| 160 |
+
equations.append(v1[0]*v2[1] - v1[1]*v2[0])
|
| 161 |
+
logger.debug(f"[GeometryEngine] -> Parallel eq (3D cross=0): {pA}{pB} || {pC}{pD}")
|
| 162 |
+
|
| 163 |
+
elif c.type == 'perpendicular' and len(c.targets) == 4:
|
| 164 |
+
pA, pB, pC, pD = c.targets
|
| 165 |
+
if any(t not in point_vars for t in [pA, pB, pC, pD]): continue
|
| 166 |
+
va, vb, vc, vd = point_vars[pA], point_vars[pB], point_vars[pC], point_vars[pD]
|
| 167 |
+
# Dot product = 0
|
| 168 |
+
dot = sum((vb[i]-va[i])*(vd[i]-vc[i]) for i in range(3))
|
| 169 |
+
equations.append(dot)
|
| 170 |
+
logger.debug(f"[GeometryEngine] -> Perpendicular eq (3D dot=0): {pA}{pB} ⊥ {pC}{pD}")
|
| 171 |
+
|
| 172 |
+
elif c.type == 'midpoint' and len(c.targets) == 3:
|
| 173 |
+
pM, pA, pB = c.targets
|
| 174 |
+
if any(t not in point_vars for t in [pM, pA, pB]): continue
|
| 175 |
+
vM, vA, vB = point_vars[pM], point_vars[pA], point_vars[pB]
|
| 176 |
+
for i in range(3):
|
| 177 |
+
equations.append(2*vM[i] - vA[i] - vB[i])
|
| 178 |
+
logger.debug(f"[GeometryEngine] -> Midpoint eq (3D): {pM} = mid({pA},{pB})")
|
| 179 |
+
|
| 180 |
+
elif c.type == 'section' and len(c.targets) == 3:
|
| 181 |
+
pE, pA, pC = c.targets
|
| 182 |
+
if any(t not in point_vars for t in [pE, pA, pC]): continue
|
| 183 |
+
vE, vA, vC = point_vars[pE], point_vars[pA], point_vars[pC]
|
| 184 |
+
k = float(c.value)
|
| 185 |
+
for i in range(3):
|
| 186 |
+
equations.append(vE[i] - (vA[i] + k * (vC[i] - vA[i])))
|
| 187 |
+
logger.debug(f"[GeometryEngine] -> Section eq (3D): {pE} = {pA} + {k}({pC}-{pA})")
|
| 188 |
+
|
| 189 |
+
elif c.type == 'circle':
|
| 190 |
+
# Circle doesn't add position constraints for center (already a point)
|
| 191 |
+
logger.debug(f"[GeometryEngine] -> Circle: center={c.targets[0]}, r={c.value} (meta only)")
|
| 192 |
+
|
| 193 |
+
all_vars = []
|
| 194 |
+
for v in point_vars.values():
|
| 195 |
+
all_vars.extend(v)
|
| 196 |
+
|
| 197 |
+
n_eqs = len(equations)
|
| 198 |
+
n_vars = len(all_vars)
|
| 199 |
+
logger.info(f"[GeometryEngine] Built {n_eqs} equations for {n_vars} unknowns.")
|
| 200 |
+
|
| 201 |
+
# ── Strategy 1: SymPy symbolic ───────────────────────────────────────
|
| 202 |
+
coords = self._try_symbolic(equations, all_vars, point_vars)
|
| 203 |
+
|
| 204 |
+
# Extract lines/rays from constraints for builder
|
| 205 |
+
lines_ext = []
|
| 206 |
+
rays_ext = []
|
| 207 |
+
for c in constraints:
|
| 208 |
+
if c.type == 'lines_metadata':
|
| 209 |
+
lines_ext = [t.split(',') for t in c.targets]
|
| 210 |
+
if c.type == 'rays_metadata':
|
| 211 |
+
rays_ext = [t.split(',') for t in c.targets]
|
| 212 |
+
|
| 213 |
+
if coords:
|
| 214 |
+
return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
|
| 215 |
+
|
| 216 |
+
# ── Strategy 2: Numerical nsolve ─────────────────────────────────────
|
| 217 |
+
if n_eqs == n_vars:
|
| 218 |
+
coords = self._try_nsolve(equations, all_vars, point_vars, n_vars)
|
| 219 |
+
if coords:
|
| 220 |
+
return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
|
| 221 |
+
|
| 222 |
+
# ── Strategy 3: Scipy least-squares ─────────────────────────────────
|
| 223 |
+
coords = self._try_lsq(equations, all_vars, point_vars, n_vars)
|
| 224 |
+
if coords:
|
| 225 |
+
return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
|
| 226 |
+
|
| 227 |
+
# ── Strategy 4: Differential evolution ──────────────────────────────
|
| 228 |
+
coords = self._try_global(equations, all_vars, point_vars, n_vars)
|
| 229 |
+
if coords:
|
| 230 |
+
return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
|
| 231 |
+
|
| 232 |
+
logger.error("[GeometryEngine] All strategies exhausted.")
|
| 233 |
+
return None
|
| 234 |
+
|
| 235 |
+
# ─── Solving strategies ──────────────────────────────────────────────────
|
| 236 |
+
|
| 237 |
+
def _try_symbolic(self, equations, all_vars, point_vars):
|
| 238 |
+
# Optimization: SymPy's symbolic solver becomes extremely slow for many variables.
|
| 239 |
+
# For 3D problems (usually 12-18+ variables), we prefer using numerical methods directly.
|
| 240 |
+
if len(all_vars) > 10:
|
| 241 |
+
logger.info(f"[GeometryEngine] Strategy 1: Skipping symbolic solve due to high variable count ({len(all_vars)}).")
|
| 242 |
+
return None
|
| 243 |
+
|
| 244 |
+
try:
|
| 245 |
+
solution = sp.solve(equations, all_vars, dict=True)
|
| 246 |
+
if solution:
|
| 247 |
+
res = solution[0]
|
| 248 |
+
logger.info("[GeometryEngine] Strategy 1 (SymPy symbolic): SUCCESS.")
|
| 249 |
+
logger.debug(f"[GeometryEngine] Symbolic solution: {res}")
|
| 250 |
+
return {pid: [float(res.get(vx, 0.0)), float(res.get(vy, 0.0)), float(res.get(vz, 0.0))]
|
| 251 |
+
for pid, (vx, vy, vz) in point_vars.items()}
|
| 252 |
+
else:
|
| 253 |
+
logger.warning("[GeometryEngine] Strategy 1 returned no solution. Trying numerical...")
|
| 254 |
+
except Exception as e:
|
| 255 |
+
logger.warning(f"[GeometryEngine] Strategy 1 threw exception: {e}. Trying numerical...")
|
| 256 |
+
return None
|
| 257 |
+
|
| 258 |
+
def _try_nsolve(self, equations, all_vars, point_vars, n_vars):
|
| 259 |
+
MAX_NSOLVE_ATTEMPTS = 15
|
| 260 |
+
logger.info(f"[GeometryEngine] Strategy 2 (nsolve): square system ({n_vars}x{n_vars}). Trying {MAX_NSOLVE_ATTEMPTS} random starts...")
|
| 261 |
+
import random
|
| 262 |
+
for attempt in range(MAX_NSOLVE_ATTEMPTS):
|
| 263 |
+
try:
|
| 264 |
+
# Use varying scales for the random guesses to handle different problem sizes
|
| 265 |
+
scale = 10 if attempt < 5 else (100 if attempt < 10 else 1)
|
| 266 |
+
guesses = [random.uniform(-scale, scale) for _ in all_vars]
|
| 267 |
+
sol_vals = sp.nsolve(equations, all_vars, guesses, tol=1e-6, maxsteps=1000)
|
| 268 |
+
res = {var: float(val) for var, val in zip(all_vars, sol_vals)}
|
| 269 |
+
logger.info(f"[GeometryEngine] Strategy 2 (nsolve): SUCCESS on attempt {attempt + 1}.")
|
| 270 |
+
return {pid: [float(res.get(vx, 0.0)), float(res.get(vy, 0.0)), float(res.get(vz, 0.0))]
|
| 271 |
+
for pid, (vx, vy, vz) in point_vars.items()}
|
| 272 |
+
except Exception as e:
|
| 273 |
+
logger.debug(f"[GeometryEngine] nsolve attempt {attempt + 1} failed: {e}")
|
| 274 |
+
return None
|
| 275 |
+
|
| 276 |
+
def _try_lsq(self, equations, all_vars, point_vars, n_vars):
|
| 277 |
+
logger.info("[GeometryEngine] Strategy 3 (scipy least-squares): minimizing residuals...")
|
| 278 |
+
try:
|
| 279 |
+
from scipy.optimize import minimize
|
| 280 |
+
eq_funcs = [sp.lambdify(all_vars, eq, 'numpy') for eq in equations]
|
| 281 |
+
|
| 282 |
+
def objective(x):
|
| 283 |
+
return sum(float(f(*x))**2 for f in eq_funcs)
|
| 284 |
+
|
| 285 |
+
best_res, best_val = None, float('inf')
|
| 286 |
+
# Increase restarts for better coverage of local minima
|
| 287 |
+
for i in range(12):
|
| 288 |
+
if i == 0:
|
| 289 |
+
x0 = [1.0]*n_vars
|
| 290 |
+
elif i < 4:
|
| 291 |
+
x0 = [np.random.uniform(-10, 10) for _ in range(n_vars)]
|
| 292 |
+
else:
|
| 293 |
+
x0 = [np.random.uniform(-100, 100) for _ in range(n_vars)]
|
| 294 |
+
|
| 295 |
+
res = minimize(objective, x0, method='L-BFGS-B')
|
| 296 |
+
if res.fun < best_val:
|
| 297 |
+
best_val, best_res = res.fun, res
|
| 298 |
+
if best_val < 1e-6:
|
| 299 |
+
break
|
| 300 |
+
|
| 301 |
+
TOLERANCE = 1e-4
|
| 302 |
+
logger.info(f"[GeometryEngine] Strategy 3: best residual = {best_val:.2e} (tol={TOLERANCE})")
|
| 303 |
+
if best_val < TOLERANCE:
|
| 304 |
+
res = {var: float(val) for var, val in zip(all_vars, best_res.x)}
|
| 305 |
+
logger.info("[GeometryEngine] Strategy 3 (least-squares): SUCCESS.")
|
| 306 |
+
return {pid: [float(res.get(vx, 0)), float(res.get(vy, 0)), float(res.get(vz, 0))]
|
| 307 |
+
for pid, (vx, vy, vz) in point_vars.items()}
|
| 308 |
+
else:
|
| 309 |
+
logger.warning(f"[GeometryEngine] Strategy 3 failed: residual {best_val:.2e} > {TOLERANCE}")
|
| 310 |
+
except Exception as e:
|
| 311 |
+
logger.error(f"[GeometryEngine] Strategy 3 threw exception: {e}")
|
| 312 |
+
return None
|
| 313 |
+
|
| 314 |
+
def _try_global(self, equations, all_vars, point_vars, n_vars):
|
| 315 |
+
logger.info("[GeometryEngine] Strategy 4 (Differential Evolution): global search...")
|
| 316 |
+
try:
|
| 317 |
+
from scipy.optimize import differential_evolution
|
| 318 |
+
bounds = [(-20, 20)] * n_vars
|
| 319 |
+
eq_funcs = [sp.lambdify(all_vars, eq, 'numpy') for eq in equations]
|
| 320 |
+
|
| 321 |
+
def obj(x):
|
| 322 |
+
s = 0.0
|
| 323 |
+
for f in eq_funcs:
|
| 324 |
+
try:
|
| 325 |
+
s += float(f(*x))**2
|
| 326 |
+
except:
|
| 327 |
+
s += 1e6
|
| 328 |
+
return s
|
| 329 |
+
|
| 330 |
+
result = differential_evolution(obj, bounds, maxiter=500, popsize=15, mutation=(0.5, 1), recombination=0.7)
|
| 331 |
+
TOLERANCE = 1e-3
|
| 332 |
+
logger.info(f"[GeometryEngine] Strategy 4: best residual = {result.fun:.2e} (tol={TOLERANCE})")
|
| 333 |
+
if result.fun < TOLERANCE:
|
| 334 |
+
res = {var: float(val) for var, val in zip(all_vars, result.x)}
|
| 335 |
+
logger.info("[GeometryEngine] Strategy 4 (global opt): SUCCESS.")
|
| 336 |
+
return {pid: [float(res.get(vx, 0)), float(res.get(vy, 0)), float(res.get(vz, 0))]
|
| 337 |
+
for pid, (vx, vy, vz) in point_vars.items()}
|
| 338 |
+
except Exception as e:
|
| 339 |
+
logger.error(f"[GeometryEngine] Strategy 4 threw exception: {e}")
|
| 340 |
+
return None
|
| 341 |
+
|
| 342 |
+
# ─── Result builder ──────────────────────────────────────────────────────
|
| 343 |
+
|
| 344 |
+
def _build_result(
|
| 345 |
+
self,
|
| 346 |
+
coords: Dict[str, List[float]],
|
| 347 |
+
polygon_order: List[str],
|
| 348 |
+
circles_meta: List[Dict],
|
| 349 |
+
segments_meta: List[List[str]],
|
| 350 |
+
lines_meta: List[List[str]],
|
| 351 |
+
rays_meta: List[List[str]],
|
| 352 |
+
pt_list: List[Point],
|
| 353 |
+
) -> Dict[str, Any]:
|
| 354 |
+
"""
|
| 355 |
+
Build structured result including drawing phases for the renderer.
|
| 356 |
+
|
| 357 |
+
drawing_phases:
|
| 358 |
+
Phase 1 — Base shape (main polygon)
|
| 359 |
+
Phase 2 — Auxiliary/derived points and segments
|
| 360 |
+
"""
|
| 361 |
+
all_ids = [p.id for p in pt_list]
|
| 362 |
+
|
| 363 |
+
# 1. Infer/clean polygon_order
|
| 364 |
+
if not polygon_order:
|
| 365 |
+
# Fallback: use all declared point IDs sorted by conventional uppercase order.
|
| 366 |
+
# This is far safer than only looking for A/B/C/D.
|
| 367 |
+
base_pts = sorted(
|
| 368 |
+
all_ids,
|
| 369 |
+
key=lambda p: (string.ascii_uppercase.index(p) if p in string.ascii_uppercase else 100, p)
|
| 370 |
+
)
|
| 371 |
+
polygon_order = base_pts
|
| 372 |
+
|
| 373 |
+
base_ids = [pid for pid in polygon_order if pid in all_ids]
|
| 374 |
+
derived_ids = [pid for pid in all_ids if pid not in polygon_order]
|
| 375 |
+
|
| 376 |
+
# 2. Collect unique segments to avoid redundancy (AB == BA)
|
| 377 |
+
drawn_segments = set()
|
| 378 |
+
|
| 379 |
+
def add_segment(p1, p2, target_list):
|
| 380 |
+
if p1 == p2:
|
| 381 |
+
return
|
| 382 |
+
s = frozenset([p1, p2])
|
| 383 |
+
if s not in drawn_segments:
|
| 384 |
+
drawn_segments.add(s)
|
| 385 |
+
target_list.append([p1, p2])
|
| 386 |
+
|
| 387 |
+
# Phase 1: Main polygon boundary
|
| 388 |
+
phase1_segments = []
|
| 389 |
+
if len(base_ids) >= 2:
|
| 390 |
+
# Connect in sequence: A-B, B-C, etc.
|
| 391 |
+
for i in range(len(base_ids) - 1):
|
| 392 |
+
add_segment(base_ids[i], base_ids[i+1], phase1_segments)
|
| 393 |
+
|
| 394 |
+
# ONLY close the loop if we have 3 or more points (a real polygon)
|
| 395 |
+
if len(base_ids) > 2:
|
| 396 |
+
add_segment(base_ids[-1], base_ids[0], phase1_segments)
|
| 397 |
+
|
| 398 |
+
# Phase 2: Auxiliary segments from DSL
|
| 399 |
+
phase2_segments = []
|
| 400 |
+
for p1, p2 in segments_meta:
|
| 401 |
+
add_segment(p1, p2, phase2_segments)
|
| 402 |
+
|
| 403 |
+
drawing_phases = [
|
| 404 |
+
{
|
| 405 |
+
"phase": 1,
|
| 406 |
+
"label": "Hình cơ bản",
|
| 407 |
+
"points": base_ids,
|
| 408 |
+
"segments": phase1_segments,
|
| 409 |
+
}
|
| 410 |
+
]
|
| 411 |
+
if derived_ids or phase2_segments:
|
| 412 |
+
drawing_phases.append({
|
| 413 |
+
"phase": 2,
|
| 414 |
+
"label": "Điểm và đoạn phụ",
|
| 415 |
+
"points": derived_ids,
|
| 416 |
+
"segments": phase2_segments,
|
| 417 |
+
})
|
| 418 |
+
|
| 419 |
+
return {
|
| 420 |
+
"coordinates": coords,
|
| 421 |
+
"polygon_order": polygon_order,
|
| 422 |
+
"circles": circles_meta,
|
| 423 |
+
"lines": lines_meta,
|
| 424 |
+
"rays": rays_meta,
|
| 425 |
+
"drawing_phases": drawing_phases,
|
| 426 |
+
}
|
solver/models.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Dict, Union, Optional
|
| 3 |
+
|
| 4 |
+
class Point(BaseModel):
|
| 5 |
+
id: str
|
| 6 |
+
x: Optional[float] = None
|
| 7 |
+
y: Optional[float] = None
|
| 8 |
+
z: Optional[float] = None
|
| 9 |
+
|
| 10 |
+
class Constraint(BaseModel):
|
| 11 |
+
type: str # 'length', 'angle', 'parallel', etc.
|
| 12 |
+
targets: List[str]
|
| 13 |
+
value: Union[float, str]
|
tests/test_3d_solver.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from solver.dsl_parser import DSLParser
|
| 3 |
+
from solver.engine import GeometryEngine
|
| 4 |
+
from solver.models import Point, Constraint
|
| 5 |
+
|
| 6 |
+
def test_solve_square_pyramid():
|
| 7 |
+
"""
|
| 8 |
+
Test solving for a square pyramid S.ABCD.
|
| 9 |
+
Base ABCD is a square with side 10.
|
| 10 |
+
Height SO = 15, where O is the center of ABCD.
|
| 11 |
+
"""
|
| 12 |
+
dsl = """
|
| 13 |
+
POINT(A, 0, 0, 0)
|
| 14 |
+
POINT(B, 10, 0, 0)
|
| 15 |
+
POINT(C, 10, 10, 0)
|
| 16 |
+
POINT(D, 0, 10, 0)
|
| 17 |
+
POINT(S)
|
| 18 |
+
POINT(O)
|
| 19 |
+
MIDPOINT(M1, AB)
|
| 20 |
+
MIDPOINT(M2, AC)
|
| 21 |
+
SECTION(O, A, C, 0.5)
|
| 22 |
+
LENGTH(SO, 15)
|
| 23 |
+
PERPENDICULAR(SO, AC)
|
| 24 |
+
PERPENDICULAR(SO, AB)
|
| 25 |
+
PYRAMID(S_ABCD)
|
| 26 |
+
"""
|
| 27 |
+
parser = DSLParser()
|
| 28 |
+
engine = GeometryEngine()
|
| 29 |
+
|
| 30 |
+
points, constraints = parser.parse(dsl)
|
| 31 |
+
result = engine.solve(points, constraints)
|
| 32 |
+
|
| 33 |
+
assert result is not None
|
| 34 |
+
coords = result["coordinates"]
|
| 35 |
+
|
| 36 |
+
# Check base points
|
| 37 |
+
assert coords["A"] == [0.0, 0.0, 0.0]
|
| 38 |
+
assert coords["B"] == [10.0, 0.0, 0.0]
|
| 39 |
+
assert coords["C"] == [10.0, 10.0, 0.0]
|
| 40 |
+
assert coords["D"] == [0.0, 10.0, 0.0]
|
| 41 |
+
|
| 42 |
+
# Check center O (should be (5, 5, 0))
|
| 43 |
+
assert coords["O"][0] == pytest.approx(5.0)
|
| 44 |
+
assert coords["O"][1] == pytest.approx(5.0)
|
| 45 |
+
assert coords["O"][2] == pytest.approx(0.0)
|
| 46 |
+
|
| 47 |
+
# Check apex S (should be (5, 5, 15) or (5, 5, -15))
|
| 48 |
+
assert coords["S"][0] == pytest.approx(5.0)
|
| 49 |
+
assert coords["S"][1] == pytest.approx(5.0)
|
| 50 |
+
assert abs(coords["S"][2]) == pytest.approx(15.0)
|
| 51 |
+
|
| 52 |
+
def test_solve_prism():
|
| 53 |
+
"""
|
| 54 |
+
Triangular prism ABC_DEF.
|
| 55 |
+
Base ABC is right triangle at A. AB=3, AC=4.
|
| 56 |
+
Height AD=10.
|
| 57 |
+
"""
|
| 58 |
+
dsl = """
|
| 59 |
+
POINT(A, 0, 0, 0)
|
| 60 |
+
POINT(B, 3, 0, 0)
|
| 61 |
+
POINT(C, 0, 4, 0)
|
| 62 |
+
POINT(D)
|
| 63 |
+
POINT(E)
|
| 64 |
+
POINT(F)
|
| 65 |
+
LENGTH(AD, 10)
|
| 66 |
+
PERPENDICULAR(AD, AB)
|
| 67 |
+
PERPENDICULAR(AD, AC)
|
| 68 |
+
PRISM(ABC_DEF)
|
| 69 |
+
"""
|
| 70 |
+
parser = DSLParser()
|
| 71 |
+
engine = GeometryEngine()
|
| 72 |
+
|
| 73 |
+
points, constraints = parser.parse(dsl)
|
| 74 |
+
result = engine.solve(points, constraints)
|
| 75 |
+
|
| 76 |
+
assert result is not None
|
| 77 |
+
coords = result["coordinates"]
|
| 78 |
+
|
| 79 |
+
# D should be (0, 0, 10)
|
| 80 |
+
assert coords["D"][0] == pytest.approx(0.0)
|
| 81 |
+
assert coords["D"][1] == pytest.approx(0.0)
|
| 82 |
+
assert abs(coords["D"][2]) == pytest.approx(10.0)
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
pytest.main([__file__])
|
tests/test_advanced_geometry.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import asyncio
|
| 3 |
+
import logging
|
| 4 |
+
from solver.dsl_parser import DSLParser
|
| 5 |
+
from solver.engine import GeometryEngine
|
| 6 |
+
|
| 7 |
+
logging.basicConfig(level=logging.DEBUG)
|
| 8 |
+
|
| 9 |
+
@pytest.mark.asyncio
|
| 10 |
+
async def test_section_internal():
|
| 11 |
+
print("\n--- Test: Section Point (Internal AE=2/3 AC) ---")
|
| 12 |
+
dsl = """
|
| 13 |
+
POINT(A)
|
| 14 |
+
POINT(B)
|
| 15 |
+
POINT(C)
|
| 16 |
+
LENGTH(AB, 6)
|
| 17 |
+
LENGTH(BC, 6)
|
| 18 |
+
ANGLE(B, 90)
|
| 19 |
+
SECTION(E, A, C, 0.6667)
|
| 20 |
+
"""
|
| 21 |
+
parser = DSLParser()
|
| 22 |
+
engine = GeometryEngine()
|
| 23 |
+
|
| 24 |
+
pts, constraints = parser.parse(dsl)
|
| 25 |
+
result = engine.solve(pts, constraints)
|
| 26 |
+
|
| 27 |
+
if result:
|
| 28 |
+
coords = result['coordinates']
|
| 29 |
+
print(f" A: {coords['A']}")
|
| 30 |
+
print(f" C: {coords['C']}")
|
| 31 |
+
print(f" E: {coords['E']}")
|
| 32 |
+
|
| 33 |
+
# Verify AE = 0.6667 * AC
|
| 34 |
+
import math
|
| 35 |
+
def dist(p1, p2): return math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)
|
| 36 |
+
|
| 37 |
+
d_ac = dist(coords['A'], coords['C'])
|
| 38 |
+
d_ae = dist(coords['A'], coords['E'])
|
| 39 |
+
ratio = d_ae / d_ac
|
| 40 |
+
print(f" Calculated Ratio AE/AC: {ratio:.4f} (Expected: 0.6667)")
|
| 41 |
+
assert abs(ratio - 0.6667) < 1e-4
|
| 42 |
+
else:
|
| 43 |
+
print(" ❌ Solve failed")
|
| 44 |
+
|
| 45 |
+
@pytest.mark.asyncio
|
| 46 |
+
async def test_section_external():
|
| 47 |
+
print("\n--- Test: Section Point (External AE=2*AC) ---")
|
| 48 |
+
dsl = """
|
| 49 |
+
POINT(A)
|
| 50 |
+
POINT(C)
|
| 51 |
+
LENGTH(AC, 5)
|
| 52 |
+
SECTION(E, A, C, 2.0)
|
| 53 |
+
"""
|
| 54 |
+
parser = DSLParser()
|
| 55 |
+
engine = GeometryEngine()
|
| 56 |
+
|
| 57 |
+
pts, constraints = parser.parse(dsl)
|
| 58 |
+
result = engine.solve(pts, constraints)
|
| 59 |
+
|
| 60 |
+
if result:
|
| 61 |
+
coords = result['coordinates']
|
| 62 |
+
print(f" A: {coords['A']}")
|
| 63 |
+
print(f" C: {coords['C']}")
|
| 64 |
+
print(f" E: {coords['E']}")
|
| 65 |
+
|
| 66 |
+
import math
|
| 67 |
+
def dist(p1, p2): return math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)
|
| 68 |
+
d_ac = dist(coords['A'], coords['C'])
|
| 69 |
+
d_ae = dist(coords['A'], coords['E'])
|
| 70 |
+
print(f" AE: {d_ae}, AC: {d_ac}, Ratio: {d_ae/d_ac}")
|
| 71 |
+
assert abs(d_ae/d_ac - 2.0) < 1e-4
|
| 72 |
+
else:
|
| 73 |
+
print(" ❌ Solve failed")
|
| 74 |
+
|
| 75 |
+
@pytest.mark.asyncio
|
| 76 |
+
async def test_line_ray_metadata():
|
| 77 |
+
print("\n--- Test: Line and Ray Metadata ---")
|
| 78 |
+
dsl = """
|
| 79 |
+
POINT(A)
|
| 80 |
+
POINT(B)
|
| 81 |
+
LINE(A, B)
|
| 82 |
+
RAY(A, B)
|
| 83 |
+
"""
|
| 84 |
+
parser = DSLParser()
|
| 85 |
+
engine = GeometryEngine()
|
| 86 |
+
|
| 87 |
+
pts, constraints = parser.parse(dsl)
|
| 88 |
+
result = engine.solve(pts, constraints)
|
| 89 |
+
|
| 90 |
+
if result:
|
| 91 |
+
print(f" Lines: {result.get('lines')}")
|
| 92 |
+
print(f" Rays: {result.get('rays')}")
|
| 93 |
+
assert ['A', 'B'] in result.get('lines', [])
|
| 94 |
+
assert ['A', 'B'] in result.get('rays', [])
|
| 95 |
+
print(" ✅ Metadata present")
|
| 96 |
+
else:
|
| 97 |
+
print(" ❌ Solve failed")
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
|
| 100 |
+
asyncio.run(test_section_internal())
|
| 101 |
+
asyncio.run(test_section_external())
|
| 102 |
+
asyncio.run(test_line_ray_metadata())
|
tests/test_api_full_suite.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import httpx
|
| 3 |
+
import time
|
| 4 |
+
import asyncio
|
| 5 |
+
import pytest
|
| 6 |
+
import logging
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
# Configuration
|
| 10 |
+
BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000")
|
| 11 |
+
USER_ID = os.getenv("TEST_USER_ID")
|
| 12 |
+
SESSION_ID = os.getenv("TEST_SESSION_ID")
|
| 13 |
+
|
| 14 |
+
logging.basicConfig(level=logging.INFO)
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
QUERIES = [
|
| 18 |
+
{
|
| 19 |
+
"id": "Q1",
|
| 20 |
+
"text": "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10",
|
| 21 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 22 |
+
"expect_phases": 1,
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"id": "Q2",
|
| 26 |
+
"text": "Tam giác ABC có AB=6, BC=8, AC=10",
|
| 27 |
+
"expect_pts": ["A", "B", "C"],
|
| 28 |
+
"expect_phases": 1,
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"id": "Q3",
|
| 32 |
+
"text": "Cho hình chữ nhật ABCD có AB bằng 10 và AD bằng 20. Vẽ điểm M là trung điểm của AB và N là trung điểm của AD.",
|
| 33 |
+
"expect_pts": ["A", "B", "C", "D", "M", "N"],
|
| 34 |
+
"expect_phases": 2,
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"id": "Q4",
|
| 38 |
+
"text": "Cho hình thang ABCD vuông tại A và D. AB=4, CD=8, AD=5.",
|
| 39 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 40 |
+
"expect_phases": 1,
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"id": "Q5",
|
| 44 |
+
"text": "Cho hình vuông ABCD có cạnh bằng 6.",
|
| 45 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 46 |
+
"expect_phases": 1,
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"id": "Q6",
|
| 50 |
+
"text": "Cho tam giác ABC vuông tại A. AB=3, AC=4. Vẽ đường cao AH.",
|
| 51 |
+
"expect_pts": ["A", "B", "C", "H"],
|
| 52 |
+
"expect_phases": 2,
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"id": "Q7",
|
| 56 |
+
"text": "Cho hình thoi ABCD có cạnh bằng 5 và góc A bằng 60 độ.",
|
| 57 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 58 |
+
"expect_phases": 1,
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"id": "Q8",
|
| 62 |
+
"text": "Cho đường tròn tâm O bán kính bằng 7.",
|
| 63 |
+
"expect_pts": ["O"],
|
| 64 |
+
"expect_phases": 1,
|
| 65 |
+
},
|
| 66 |
+
{
|
| 67 |
+
"id": "Q9",
|
| 68 |
+
"text": "Cho hình bình hành ABCD có AB=8, AD=6. Gọi E là trung điểm của CD. Vẽ đoạn thẳng AE.",
|
| 69 |
+
"expect_pts": ["A", "B", "C", "D", "E"],
|
| 70 |
+
"expect_phases": 2,
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"id": "Q10-Step1",
|
| 74 |
+
"text": "Cho hình chữ nhật ABCD có AB=10, AD=5.",
|
| 75 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 76 |
+
"expect_phases": 1,
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"id": "Q11-Video",
|
| 80 |
+
"text": "Cho tam giác ABC đều cạnh 5. Vẽ đường tròn ngoại tiếp tam giác.",
|
| 81 |
+
"expect_pts": ["A", "B", "C"],
|
| 82 |
+
"expect_phases": 2,
|
| 83 |
+
"request_video": True
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"id": "Q12-3D",
|
| 87 |
+
"text": "Cho hình chóp S.ABCD có đáy ABCD là hình vuông cạnh 10, đường cao SO=15 với O là tâm đáy.",
|
| 88 |
+
"expect_pts": ["S", "A", "B", "C", "D", "O"],
|
| 89 |
+
"expect_phases": 2,
|
| 90 |
+
}
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
Q10_FOLLOW_UP = {
|
| 94 |
+
"id": "Q10-Step2",
|
| 95 |
+
"text": "Vẽ thêm đường chéo AC.",
|
| 96 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 97 |
+
"expect_phases": 2,
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
test_stats = []
|
| 101 |
+
|
| 102 |
+
async def run_single_api_query(client, q, headers):
|
| 103 |
+
print(f"\n🚀 [RUNNING] {q['id']}: {q['text']}")
|
| 104 |
+
start_time = time.time()
|
| 105 |
+
|
| 106 |
+
# 1. Submit Request
|
| 107 |
+
payload = {
|
| 108 |
+
"text": q["text"],
|
| 109 |
+
"request_video": q.get("request_video", False)
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
if q.get("isolate", True):
|
| 114 |
+
# Create a fresh session for isolation
|
| 115 |
+
session_resp = await client.post("/api/v1/sessions", headers=headers)
|
| 116 |
+
if session_resp.status_code != 200:
|
| 117 |
+
return {"id": q["id"], "query": q["text"], "success": False, "error": f"Session creation failed: {session_resp.text}"}
|
| 118 |
+
session_id = session_resp.json()["id"]
|
| 119 |
+
else:
|
| 120 |
+
session_id = q.get("session_id", SESSION_ID)
|
| 121 |
+
|
| 122 |
+
res = await client.post(f"/api/v1/sessions/{session_id}/solve", json=payload, headers=headers)
|
| 123 |
+
if res.status_code != 200:
|
| 124 |
+
print(f" ❌ FAILED: Status {res.status_code} - {res.text}")
|
| 125 |
+
return {"id": q["id"], "query": q["text"], "success": False, "error": f"HTTP {res.status_code}: {res.text}"}
|
| 126 |
+
|
| 127 |
+
job_id = res.json()["job_id"]
|
| 128 |
+
print(f" ✅ Job Created: {job_id}")
|
| 129 |
+
|
| 130 |
+
# 2. Polling result
|
| 131 |
+
max_attempts = 45 # Increased for video rendering
|
| 132 |
+
result_data = None
|
| 133 |
+
for i in range(max_attempts):
|
| 134 |
+
await asyncio.sleep(4)
|
| 135 |
+
res = await client.get(f"/api/v1/solve/{job_id}", headers=headers)
|
| 136 |
+
data = res.json()
|
| 137 |
+
status = data.get("status")
|
| 138 |
+
print(f" - Polling ({i+1}): {status}")
|
| 139 |
+
|
| 140 |
+
if status == "success":
|
| 141 |
+
result_data = data["result"]
|
| 142 |
+
break
|
| 143 |
+
if status == "error":
|
| 144 |
+
print(f" ❌ ERROR: {data.get('result', {}).get('error')}")
|
| 145 |
+
return {"id": q["id"], "query": q["text"], "success": False, "error": data.get("result", {}).get("error")}
|
| 146 |
+
|
| 147 |
+
if i == max_attempts - 1:
|
| 148 |
+
print(" ❌ TIMEOUT")
|
| 149 |
+
return {"id": q["id"], "query": q["text"], "success": False, "error": "Timeout"}
|
| 150 |
+
|
| 151 |
+
# 3. Strict Validation
|
| 152 |
+
elapsed = time.time() - start_time
|
| 153 |
+
errors = []
|
| 154 |
+
|
| 155 |
+
# Validation: Coordinates
|
| 156 |
+
coords = result_data.get("coordinates", {})
|
| 157 |
+
for pt in q["expect_pts"]:
|
| 158 |
+
if pt not in coords:
|
| 159 |
+
errors.append(f"Missing point {pt}")
|
| 160 |
+
|
| 161 |
+
# Validation: Non-zero coords (generic check)
|
| 162 |
+
# Only fail if there are MULTIPLE points and all are at origin.
|
| 163 |
+
# A single point (like a circle center) at origin is perfectly valid.
|
| 164 |
+
if coords and len(coords) > 1 and all(v == [0,0,0] for v in coords.values()):
|
| 165 |
+
errors.append("All points are at [0,0,0]")
|
| 166 |
+
|
| 167 |
+
# Validation: Drawing Phases
|
| 168 |
+
phases = result_data.get("drawing_phases", [])
|
| 169 |
+
if len(phases) < q["expect_phases"]:
|
| 170 |
+
errors.append(f"Expected {q['expect_phases']} phases, got {len(phases)}")
|
| 171 |
+
|
| 172 |
+
# Validation: Video URL if requested
|
| 173 |
+
if q.get("request_video") and not result_data.get("video_url"):
|
| 174 |
+
# We allow video fail if it's environment issue, but log it
|
| 175 |
+
print(" ⚠️ Video requested but no URL found (Expected in some test envs)")
|
| 176 |
+
# errors.append("Video URL missing")
|
| 177 |
+
|
| 178 |
+
if errors:
|
| 179 |
+
print(f" ❌ VALIDATION FAILED: {', '.join(errors)}")
|
| 180 |
+
return {"id": q["id"], "query": q["text"], "success": False, "error": "; ".join(errors), "elapsed": elapsed, "result": result_data}
|
| 181 |
+
|
| 182 |
+
print(f" ✅ PASS ({elapsed:.2f}s)")
|
| 183 |
+
return {"id": q['id'], "query": q["text"], "success": True, "elapsed": elapsed, "job_id": job_id, "result": result_data}
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
print(f" ❌ EXCEPTION: {str(e)}")
|
| 187 |
+
return {"id": q["id"], "query": q["text"], "success": False, "error": str(e)}
|
| 188 |
+
|
| 189 |
+
@pytest.mark.asyncio
|
| 190 |
+
async def test_full_api_suite():
|
| 191 |
+
if not USER_ID or not SESSION_ID:
|
| 192 |
+
pytest.fail("TEST_USER_ID and TEST_SESSION_ID must be set")
|
| 193 |
+
|
| 194 |
+
headers = {"Authorization": f"Test {USER_ID}"}
|
| 195 |
+
|
| 196 |
+
async with httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) as client:
|
| 197 |
+
# Run standard queries
|
| 198 |
+
import uuid
|
| 199 |
+
for q in QUERIES:
|
| 200 |
+
if q["id"] == "Q10-Step1": continue
|
| 201 |
+
# Isolated by default
|
| 202 |
+
res = await run_single_api_query(client, q, headers)
|
| 203 |
+
test_stats.append(res)
|
| 204 |
+
|
| 205 |
+
# Run Multi-turn Q10
|
| 206 |
+
print("\n--- Testing Multi-turn API Flow (Q10) ---")
|
| 207 |
+
# Create a shared session for Q10
|
| 208 |
+
shared_session_resp = await client.post("/api/v1/sessions", headers=headers)
|
| 209 |
+
shared_session = shared_session_resp.json()["id"]
|
| 210 |
+
|
| 211 |
+
q10_1 = next(q for q in QUERIES if q["id"] == "Q10-Step1")
|
| 212 |
+
q10_1["session_id"] = shared_session
|
| 213 |
+
q10_1["isolate"] = False
|
| 214 |
+
res10_1 = await run_single_api_query(client, q10_1, headers)
|
| 215 |
+
test_stats.append(res10_1)
|
| 216 |
+
|
| 217 |
+
if res10_1["success"]:
|
| 218 |
+
Q10_FOLLOW_UP["session_id"] = shared_session
|
| 219 |
+
Q10_FOLLOW_UP["isolate"] = False
|
| 220 |
+
res10_2 = await run_single_api_query(client, Q10_FOLLOW_UP, headers)
|
| 221 |
+
|
| 222 |
+
# Additional check for Q10-Step2: check if DSL contains combined logic
|
| 223 |
+
if res10_2["success"]:
|
| 224 |
+
dsl = res10_2["result"].get("geometry_dsl", "")
|
| 225 |
+
if "POLYGON_ORDER" not in dsl or "SEGMENT" not in dsl:
|
| 226 |
+
res10_2["success"] = False
|
| 227 |
+
res10_2["error"] = "DSL did not merge history correctly"
|
| 228 |
+
|
| 229 |
+
test_stats.append(res10_2)
|
| 230 |
+
|
| 231 |
+
# Save Results to JSON for the runner script to generate Markdown
|
| 232 |
+
with open("temp_suite_results.json", "w") as f:
|
| 233 |
+
json.dump(test_stats, f)
|
| 234 |
+
|
| 235 |
+
if __name__ == "__main__":
|
| 236 |
+
import asyncio
|
| 237 |
+
asyncio.run(test_full_api_suite())
|
tests/test_api_metadata_real.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import asyncio
|
| 3 |
+
import uuid
|
| 4 |
+
import time
|
| 5 |
+
from app.routers.solve import process_session_job
|
| 6 |
+
from app.models.schemas import SolveRequest
|
| 7 |
+
from app.supabase_client import get_supabase
|
| 8 |
+
|
| 9 |
+
@pytest.mark.asyncio
|
| 10 |
+
async def test_metadata_persistence():
|
| 11 |
+
session_id = "81f87517-88f2-40bd-96a9-7b34f1d14b6a"
|
| 12 |
+
user_id = "8cd3adb0-7964-4575-949c-d0cadcd8b679"
|
| 13 |
+
job_id = str(uuid.uuid4())
|
| 14 |
+
|
| 15 |
+
print(f"🚀 Starting sub-pipeline test for job {job_id}...")
|
| 16 |
+
|
| 17 |
+
request = SolveRequest(
|
| 18 |
+
text="Cho hình chữ nhật ABCD có AB=10, AD=20. Vẽ đường thẳng d đi qua A và B.",
|
| 19 |
+
request_video=False
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Trigger the process_session_job directly
|
| 23 |
+
await process_session_job(job_id, session_id, request, user_id)
|
| 24 |
+
|
| 25 |
+
print("⏳ Waiting for database sync (3s)...")
|
| 26 |
+
await asyncio.sleep(3)
|
| 27 |
+
|
| 28 |
+
# Verify the results in Supabase
|
| 29 |
+
supabase = get_supabase()
|
| 30 |
+
res = supabase.table("messages") \
|
| 31 |
+
.select("metadata, created_at") \
|
| 32 |
+
.eq("session_id", session_id) \
|
| 33 |
+
.eq("role", "assistant") \
|
| 34 |
+
.order("created_at", desc=True) \
|
| 35 |
+
.limit(1) \
|
| 36 |
+
.execute()
|
| 37 |
+
|
| 38 |
+
if not res.data:
|
| 39 |
+
print("❌ FAIL: No assistant message found in database.")
|
| 40 |
+
return
|
| 41 |
+
|
| 42 |
+
metadata = res.data[0].get("metadata", {})
|
| 43 |
+
required_fields = ["job_id", "coordinates", "polygon_order", "drawing_phases", "circles", "lines", "rays"]
|
| 44 |
+
missing = [f for f in required_fields if f not in metadata]
|
| 45 |
+
|
| 46 |
+
if not missing:
|
| 47 |
+
print("✅ SUCCESS: All metadata fields (including lines/rays) persisted correctly.")
|
| 48 |
+
print(f" job_id: {metadata.get('job_id')}")
|
| 49 |
+
print(f" polygon_order: {metadata.get('polygon_order')}")
|
| 50 |
+
print(f" lines: {metadata.get('lines')}")
|
| 51 |
+
print(f" phases: {len(metadata.get('drawing_phases', []))}")
|
| 52 |
+
else:
|
| 53 |
+
print(f"❌ FAIL: Missing fields in metadata: {missing}")
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
asyncio.run(test_metadata_persistence())
|
tests/test_api_real_e2e.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import httpx
|
| 3 |
+
import time
|
| 4 |
+
import pytest
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
# Configuration from environment
|
| 8 |
+
BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000")
|
| 9 |
+
USER_ID = os.getenv("TEST_USER_ID")
|
| 10 |
+
SESSION_ID = os.getenv("TEST_SESSION_ID")
|
| 11 |
+
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
@pytest.mark.asyncio
|
| 16 |
+
async def test_api_e2e_flow():
|
| 17 |
+
if not USER_ID or not SESSION_ID:
|
| 18 |
+
pytest.fail("TEST_USER_ID and TEST_SESSION_ID must be set")
|
| 19 |
+
|
| 20 |
+
auth_headers = {"Authorization": f"Test {USER_ID}"}
|
| 21 |
+
|
| 22 |
+
async with httpx.AsyncClient(base_url=BASE_URL, timeout=30.0) as client:
|
| 23 |
+
# 1. Health check
|
| 24 |
+
print("\n[1/3] Checking API Health...")
|
| 25 |
+
res = await client.get("/")
|
| 26 |
+
assert res.status_code == 200
|
| 27 |
+
assert "running" in res.json()["message"].lower()
|
| 28 |
+
print(" ✅ Health check passed")
|
| 29 |
+
|
| 30 |
+
# 2. Submit Solve Request
|
| 31 |
+
print(f"\n[2/3] Submitting solve request for session {SESSION_ID}...")
|
| 32 |
+
payload = {
|
| 33 |
+
"text": "Cho hình chữ nhật ABCD có AB=5, AD=10. Tính diện tích.",
|
| 34 |
+
"request_video": False
|
| 35 |
+
}
|
| 36 |
+
res = await client.post(f"/api/v1/sessions/{SESSION_ID}/solve", json=payload, headers=auth_headers)
|
| 37 |
+
|
| 38 |
+
if res.status_code != 200:
|
| 39 |
+
print(f" ❌ FAILED: {res.text}")
|
| 40 |
+
assert res.status_code == 200
|
| 41 |
+
|
| 42 |
+
data = res.json()
|
| 43 |
+
job_id = data["job_id"]
|
| 44 |
+
assert job_id is not None
|
| 45 |
+
print(f" ✅ Request accepted. Job ID: {job_id}")
|
| 46 |
+
|
| 47 |
+
# 3. Polling Job Status
|
| 48 |
+
print("\n[3/3] Polling job status...")
|
| 49 |
+
max_attempts = 15
|
| 50 |
+
for i in range(max_attempts):
|
| 51 |
+
time.sleep(2) # Simple sleep between polls
|
| 52 |
+
res = await client.get(f"/api/v1/solve/{job_id}")
|
| 53 |
+
assert res.status_code == 200
|
| 54 |
+
job_data = res.json()
|
| 55 |
+
status = job_data["status"]
|
| 56 |
+
print(f" Attempt {i+1}: Status = {status}")
|
| 57 |
+
|
| 58 |
+
if status == "success":
|
| 59 |
+
print(" ✅ SUCCESS: API pipeline completed successfully.")
|
| 60 |
+
result = job_data.get("result", {})
|
| 61 |
+
assert "coordinates" in result
|
| 62 |
+
assert "geometry_dsl" in result
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
if status == "error":
|
| 66 |
+
error_msg = job_data.get("result", {}).get("error", "Unknown error")
|
| 67 |
+
pytest.fail(f"Job failed with error: {error_msg}")
|
| 68 |
+
|
| 69 |
+
if i == max_attempts - 1:
|
| 70 |
+
pytest.fail("Timeout waiting for job completion")
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
# This allows running the script directly if needed
|
| 74 |
+
import asyncio
|
| 75 |
+
asyncio.run(test_api_e2e_flow())
|
tests/test_direct_task.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
from dotenv import load_dotenv
|
| 5 |
+
|
| 6 |
+
# Ensure we can import from backend
|
| 7 |
+
sys.path.append(os.getcwd())
|
| 8 |
+
|
| 9 |
+
from app.supabase_client import get_supabase
|
| 10 |
+
from worker.tasks import render_geometry_video
|
| 11 |
+
|
| 12 |
+
def test_celery_task_directly():
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# Mock data for a square
|
| 16 |
+
data = {
|
| 17 |
+
"session_id": "88888888-8888-8888-8888-888888888888", # Fake uuid
|
| 18 |
+
"coordinates": {
|
| 19 |
+
"A": [0, 0],
|
| 20 |
+
"B": [5, 0],
|
| 21 |
+
"C": [5, 5],
|
| 22 |
+
"D": [0, 5]
|
| 23 |
+
},
|
| 24 |
+
"polygon_order": ["A", "B", "C", "D"],
|
| 25 |
+
"drawing_phases": [
|
| 26 |
+
{
|
| 27 |
+
"phase": 1,
|
| 28 |
+
"label": "Base",
|
| 29 |
+
"points": ["A", "B", "C", "D"],
|
| 30 |
+
"segments": [["A","B"],["B","C"],["C","D"],["D","A"]]
|
| 31 |
+
}
|
| 32 |
+
],
|
| 33 |
+
"semantic_analysis": "Test squere video rendering."
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
job_id = f"manual-direct-test-{int(os.time.time()) if hasattr(os, 'time') else 123}"
|
| 37 |
+
# Just use a static ID or similar
|
| 38 |
+
import time
|
| 39 |
+
job_id = f"manual-test-{int(time.time())}"
|
| 40 |
+
|
| 41 |
+
print(f"🚀 Running render_geometry_video directly for job {job_id}...")
|
| 42 |
+
|
| 43 |
+
try:
|
| 44 |
+
# We need to mock Supabase calls if we don't want to actually hit the DB,
|
| 45 |
+
# but here we WANT to test the real task logic.
|
| 46 |
+
# This will fail on DB update if job_id doesn't exist in 'jobs' table.
|
| 47 |
+
# So let's create a dummy job first.
|
| 48 |
+
supabase = get_supabase()
|
| 49 |
+
supabase.table("jobs").insert({
|
| 50 |
+
"id": job_id,
|
| 51 |
+
"user_id": None,
|
| 52 |
+
"status": "processing",
|
| 53 |
+
"type": "solve"
|
| 54 |
+
}).execute()
|
| 55 |
+
|
| 56 |
+
# Run the task function directly (not via .delay)
|
| 57 |
+
video_url = render_geometry_video(job_id, data)
|
| 58 |
+
|
| 59 |
+
if video_url:
|
| 60 |
+
print(f"✅ SUCCESS! Video URL: {video_url}")
|
| 61 |
+
else:
|
| 62 |
+
print("❌ FAIL: No video URL returned.")
|
| 63 |
+
|
| 64 |
+
except NameError as e:
|
| 65 |
+
print(f"❌ NameError Caught: {e}")
|
| 66 |
+
except Exception as e:
|
| 67 |
+
print(f"❌ Error during manual task execution: {e}")
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
|
| 70 |
+
test_celery_task_directly()
|
tests/test_full_pipeline.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import logging
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import math
|
| 6 |
+
import time
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
from app.logging_setup import setup_application_logging
|
| 10 |
+
setup_application_logging()
|
| 11 |
+
logging.getLogger("agents").setLevel(logging.DEBUG)
|
| 12 |
+
logging.getLogger("solver").setLevel(logging.DEBUG)
|
| 13 |
+
logging.getLogger("app").setLevel(logging.DEBUG)
|
| 14 |
+
|
| 15 |
+
from agents.orchestrator import Orchestrator
|
| 16 |
+
from app.supabase_client import get_supabase
|
| 17 |
+
|
| 18 |
+
QUERIES = [
|
| 19 |
+
{
|
| 20 |
+
"id": "Q1",
|
| 21 |
+
"text": "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10",
|
| 22 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 23 |
+
"expect_phases": 1,
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"id": "Q2",
|
| 27 |
+
"text": "Tam giác ABC có AB=6, BC=8, AC=10",
|
| 28 |
+
"expect_pts": ["A", "B", "C"],
|
| 29 |
+
"expect_phases": 1,
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"id": "Q3",
|
| 33 |
+
"text": "Cho hình chữ nhật ABCD có AB bằng 10 và AD bằng 20. Vẽ điểm M là trung điểm của AB và N là trung điểm của AD.",
|
| 34 |
+
"expect_pts": ["A", "B", "C", "D", "M", "N"],
|
| 35 |
+
"expect_phases": 2,
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"id": "Q4",
|
| 39 |
+
"text": "Cho hình thang ABCD vuông tại A và D. AB=4, CD=8, AD=5.",
|
| 40 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 41 |
+
"expect_phases": 1,
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"id": "Q5",
|
| 45 |
+
"text": "Cho hình vuông ABCD có cạnh bằng 6.",
|
| 46 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 47 |
+
"expect_phases": 1,
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"id": "Q6",
|
| 51 |
+
"text": "Cho tam giác ABC vuông tại A. AB=3, AC=4. Vẽ đường cao AH.",
|
| 52 |
+
"expect_pts": ["A", "B", "C", "H"],
|
| 53 |
+
"expect_phases": 2,
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"id": "Q7",
|
| 57 |
+
"text": "Cho hình thoi ABCD có cạnh bằng 5 và góc A bằng 60 độ.",
|
| 58 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 59 |
+
"expect_phases": 1,
|
| 60 |
+
},
|
| 61 |
+
{
|
| 62 |
+
"id": "Q8",
|
| 63 |
+
"text": "Cho đường tròn tâm O bán kính bằng 7.",
|
| 64 |
+
"expect_pts": ["O"],
|
| 65 |
+
"expect_phases": 1,
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"id": "Q9",
|
| 69 |
+
"text": "Cho hình bình hành ABCD có AB=8, AD=6. Gọi E là trung điểm của CD. Vẽ đoạn thẳng AE.",
|
| 70 |
+
"expect_pts": ["A", "B", "C", "D", "E"],
|
| 71 |
+
"expect_phases": 2,
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"id": "Q10-Step1",
|
| 75 |
+
"text": "Cho hình chữ nhật ABCD có AB=10, AD=5.",
|
| 76 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 77 |
+
"expect_phases": 1,
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"id": "Q11-Video",
|
| 81 |
+
"text": "Cho tam giác ABC đều cạnh 5. Vẽ đường tròn ngoại tiếp tam giác.",
|
| 82 |
+
"expect_pts": ["A", "B", "C"],
|
| 83 |
+
"expect_phases": 2,
|
| 84 |
+
"request_video": True
|
| 85 |
+
}
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
# Q10-Step2 is a follow-up to Q10-Step1
|
| 89 |
+
Q10_FOLLOW_UP = {
|
| 90 |
+
"id": "Q10-Step2",
|
| 91 |
+
"text": "Vẽ thêm đường chéo AC.",
|
| 92 |
+
"expect_pts": ["A", "B", "C", "D"],
|
| 93 |
+
"expect_phases": 2, # Main polygon + diagonal segment
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
def dist(p1, p2):
|
| 97 |
+
return math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)
|
| 98 |
+
|
| 99 |
+
async def run_query(orchestrator, q, history=None):
|
| 100 |
+
print(f"\n{'='*60}")
|
| 101 |
+
print(f"[{q['id']}] {q['text']}")
|
| 102 |
+
if history:
|
| 103 |
+
print(f" (With history context of {len(history)} messages)")
|
| 104 |
+
if q.get("request_video"):
|
| 105 |
+
print(" 🎥 VIDEO RENDERING REQUESTED")
|
| 106 |
+
print('='*60)
|
| 107 |
+
try:
|
| 108 |
+
result = await orchestrator.run(
|
| 109 |
+
text=q["text"],
|
| 110 |
+
job_id=f"test-{q['id']}-{int(time.time())}",
|
| 111 |
+
request_video=q.get("request_video", False),
|
| 112 |
+
history=history,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
if "error" in result:
|
| 116 |
+
print(f" ❌ PIPELINE ERROR: {result['error']}")
|
| 117 |
+
return None
|
| 118 |
+
|
| 119 |
+
# Check 1: semantic_analysis != original query
|
| 120 |
+
analysis = result.get("semantic_analysis", "")
|
| 121 |
+
if analysis.strip() == q["text"].strip():
|
| 122 |
+
print(f" ❌ FAIL: semantic_analysis is identical to input query")
|
| 123 |
+
else:
|
| 124 |
+
print(f" ✅ semantic_analysis: {analysis[:100]}...")
|
| 125 |
+
|
| 126 |
+
# Check 2: all expected points are in coordinates
|
| 127 |
+
coords = result.get("coordinates", {})
|
| 128 |
+
missing = [pt for pt in q["expect_pts"] if pt not in coords]
|
| 129 |
+
if missing:
|
| 130 |
+
print(f" ❌ FAIL: Missing points in coordinates: {missing}")
|
| 131 |
+
else:
|
| 132 |
+
print(f" ✅ All expected points present: {list(coords.keys())}")
|
| 133 |
+
|
| 134 |
+
# Check 4: drawing_phases
|
| 135 |
+
phases = result.get("drawing_phases", [])
|
| 136 |
+
if len(phases) >= q["expect_phases"]:
|
| 137 |
+
print(f" ✅ drawing_phases: {len(phases)} phase(s)")
|
| 138 |
+
else:
|
| 139 |
+
print(f" ❌ FAIL: expected {q['expect_phases']} drawing phase(s), got {len(phases)}")
|
| 140 |
+
|
| 141 |
+
# Check 5: Video Polling (if requested)
|
| 142 |
+
if q.get("request_video"):
|
| 143 |
+
job_id = result.get("job_id")
|
| 144 |
+
if not job_id:
|
| 145 |
+
print(" ❌ FAIL: Video requested but no job_id returned")
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
print(f" ⏳ Waiting for Manim video (job_id: {job_id})...")
|
| 149 |
+
supabase = get_supabase()
|
| 150 |
+
max_retries = 24 # 2 minutes (24 * 5s)
|
| 151 |
+
success = False
|
| 152 |
+
for _ in range(max_retries):
|
| 153 |
+
job_res = supabase.table("jobs").select("*").eq("id", job_id).execute()
|
| 154 |
+
if job_res.data:
|
| 155 |
+
job_data = job_res.data[0]
|
| 156 |
+
status = job_data.get("status")
|
| 157 |
+
if status == "success":
|
| 158 |
+
video_url = job_data.get("result", {}).get("video_url")
|
| 159 |
+
if video_url:
|
| 160 |
+
print(f" ✅ VIDEO READY: {video_url}")
|
| 161 |
+
result["video_url"] = video_url
|
| 162 |
+
success = True
|
| 163 |
+
break
|
| 164 |
+
else:
|
| 165 |
+
print(" ❌ FAIL: Job success but no video_url in result")
|
| 166 |
+
return None
|
| 167 |
+
elif status == "failed":
|
| 168 |
+
print(f" ❌ FAIL: Manim worker job failed")
|
| 169 |
+
return None
|
| 170 |
+
await asyncio.sleep(5)
|
| 171 |
+
|
| 172 |
+
if not success:
|
| 173 |
+
print(" ❌ FAIL: Timeout waiting for video rendering")
|
| 174 |
+
return None
|
| 175 |
+
|
| 176 |
+
dsl = result.get('geometry_dsl', '')
|
| 177 |
+
print(f" DSL ({len(dsl.splitlines())} lines):\n{dsl}")
|
| 178 |
+
|
| 179 |
+
# Specific check for Q10-Step2: must contain BOTH ABCD and AC segment
|
| 180 |
+
if q["id"] == "Q10-Step2":
|
| 181 |
+
if "POLYGON_ORDER(A, B, C, D)" in dsl and "SEGMENT(A, C)" in dsl:
|
| 182 |
+
print(f" ✅ Multi-turn Success: DSL merged correctly.")
|
| 183 |
+
else:
|
| 184 |
+
print(f" ❌ Multi-turn Fail: DSL missing component.")
|
| 185 |
+
|
| 186 |
+
return result
|
| 187 |
+
|
| 188 |
+
except Exception as e:
|
| 189 |
+
import traceback
|
| 190 |
+
print(f" ❌ EXCEPTION: {type(e).__name__}: {e}")
|
| 191 |
+
traceback.print_exc()
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
+
async def main():
|
| 195 |
+
load_dotenv()
|
| 196 |
+
orchestrator = Orchestrator()
|
| 197 |
+
|
| 198 |
+
results = []
|
| 199 |
+
# Run Q1 to Q9 and Q11 (Video)
|
| 200 |
+
queries_to_test = QUERIES[:-1] + [QUERIES[-1]] # All except Step1
|
| 201 |
+
|
| 202 |
+
# Actually let's just iterate over all and handle Q10 special
|
| 203 |
+
for q in QUERIES:
|
| 204 |
+
if q["id"] == "Q10-Step1":
|
| 205 |
+
continue
|
| 206 |
+
res = await run_query(orchestrator, q)
|
| 207 |
+
results.append((q["id"], res is not None))
|
| 208 |
+
|
| 209 |
+
# Run Q10 Flow (Multi-turn)
|
| 210 |
+
print("\n--- Starting Multi-turn Flow (Q10) ---")
|
| 211 |
+
q10_1 = next(q for q in QUERIES if q["id"] == "Q10-Step1")
|
| 212 |
+
res10_1 = await run_query(orchestrator, q10_1)
|
| 213 |
+
results.append((q10_1["id"], res10_1 is not None))
|
| 214 |
+
|
| 215 |
+
if res10_1:
|
| 216 |
+
# Construct message history to pass to step 2
|
| 217 |
+
history = [
|
| 218 |
+
{"role": "user", "content": q10_1["text"]},
|
| 219 |
+
{
|
| 220 |
+
"role": "assistant",
|
| 221 |
+
"content": res10_1["semantic_analysis"],
|
| 222 |
+
"metadata": {
|
| 223 |
+
"geometry_dsl": res10_1["geometry_dsl"],
|
| 224 |
+
"coordinates": res10_1["coordinates"]
|
| 225 |
+
}
|
| 226 |
+
}
|
| 227 |
+
]
|
| 228 |
+
res10_2 = await run_query(orchestrator, Q10_FOLLOW_UP, history=history)
|
| 229 |
+
results.append((Q10_FOLLOW_UP["id"], res10_2 is not None))
|
| 230 |
+
|
| 231 |
+
print(f"\n{'='*60}")
|
| 232 |
+
print("SUMMARY:")
|
| 233 |
+
for qid, ok in results:
|
| 234 |
+
print(f" [{qid}] {'✅ PASS' if ok else '❌ FAIL'}")
|
| 235 |
+
|
| 236 |
+
if __name__ == "__main__":
|
| 237 |
+
asyncio.run(main())
|
tests/test_openrouter.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import httpx
|
| 3 |
+
import json
|
| 4 |
+
import time
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
|
| 7 |
+
# Load môi trường từ backend/.env
|
| 8 |
+
load_dotenv(dotenv_path="./backend/.env")
|
| 9 |
+
|
| 10 |
+
MODELS = [
|
| 11 |
+
"nvidia/nemotron-3-super-120b-a12b:free",
|
| 12 |
+
"meta-llama/llama-3.3-70b-instruct:free",
|
| 13 |
+
"openai/gpt-oss-120b:free",
|
| 14 |
+
"z-ai/glm-4.5-air:free",
|
| 15 |
+
"minimax/minimax-m2.5:free",
|
| 16 |
+
"google/gemma-4-26b-a4b-it:free",
|
| 17 |
+
"google/gemma-4-31b-it:free",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
PROMPT = "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10. Gọi E là điểm nằm trong đoạn CD sao cho CE = 2ED. Vẽ đoạn thẳng AE. Vẽ thêm P là điểm nằm trên đường thẳng BC sao cho BP = 2PC, tính chu vi tam giác PEA"
|
| 21 |
+
|
| 22 |
+
def test_models():
|
| 23 |
+
api_key = os.getenv("OPENROUTER_API_KEY_1")
|
| 24 |
+
base_url = "https://openrouter.ai/api/v1/chat/completions"
|
| 25 |
+
|
| 26 |
+
if not api_key:
|
| 27 |
+
print("❌ Lỗi: Không tìm thấy OPENROUTER_API_KEY trong file .env")
|
| 28 |
+
return
|
| 29 |
+
|
| 30 |
+
print("🚀 Bắt đầu benchmark các model OpenRouter...")
|
| 31 |
+
print(f"📝 Prompt: {PROMPT}\n")
|
| 32 |
+
|
| 33 |
+
results = []
|
| 34 |
+
|
| 35 |
+
for model in MODELS:
|
| 36 |
+
print(f"📡 Đang gọi model: {model}...", end="", flush=True)
|
| 37 |
+
|
| 38 |
+
headers = {
|
| 39 |
+
"Authorization": f"Bearer {api_key}",
|
| 40 |
+
"Content-Type": "application/json",
|
| 41 |
+
"HTTP-Referer": "https://mathsolver.io",
|
| 42 |
+
"X-Title": "MathSolver Benchmark Tool"
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
payload = {
|
| 46 |
+
"model": model,
|
| 47 |
+
"messages": [{"role": "user", "content": PROMPT}]
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
start_time = time.time()
|
| 51 |
+
try:
|
| 52 |
+
with httpx.Client(timeout=60.0) as client:
|
| 53 |
+
response = client.post(base_url, headers=headers, json=payload)
|
| 54 |
+
response.raise_for_status()
|
| 55 |
+
|
| 56 |
+
duration = time.time() - start_time
|
| 57 |
+
data = response.json()
|
| 58 |
+
answer = data['choices'][0]['message']['content']
|
| 59 |
+
|
| 60 |
+
results.append({
|
| 61 |
+
"model": model,
|
| 62 |
+
"duration": duration,
|
| 63 |
+
"answer": answer,
|
| 64 |
+
"status": "success"
|
| 65 |
+
})
|
| 66 |
+
print(f" ✅ DONE ({duration:.2f}s)")
|
| 67 |
+
|
| 68 |
+
except Exception as e:
|
| 69 |
+
duration = time.time() - start_time
|
| 70 |
+
print(f" ❌ FAILED ({duration:.2f}s)")
|
| 71 |
+
results.append({
|
| 72 |
+
"model": model,
|
| 73 |
+
"duration": duration,
|
| 74 |
+
"error": str(e),
|
| 75 |
+
"status": "error"
|
| 76 |
+
})
|
| 77 |
+
|
| 78 |
+
print("\n" + "="*80)
|
| 79 |
+
print("📊 BÁO CÁO CHI TIẾT BENCHMARK")
|
| 80 |
+
print("="*80)
|
| 81 |
+
|
| 82 |
+
for res in results:
|
| 83 |
+
print(f"\n🔹 MODEL: {res['model']}")
|
| 84 |
+
print(f"⏱ Thời gian: {res['duration']:.2f} giây")
|
| 85 |
+
if res['status'] == "success":
|
| 86 |
+
print(f"🤖 Phản hồi:\n{res['answer']}")
|
| 87 |
+
else:
|
| 88 |
+
print(f"❌ Lỗi: {res.get('error')}")
|
| 89 |
+
print("-" * 40)
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
test_models()
|
tests/test_real_llm.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
import asyncio
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
from app.llm_client import get_llm_client
|
| 7 |
+
|
| 8 |
+
# Setup logging to see the fallback process
|
| 9 |
+
logging.basicConfig(level=logging.INFO)
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
@pytest.mark.asyncio
|
| 13 |
+
async def test_real_llm():
|
| 14 |
+
client = get_llm_client()
|
| 15 |
+
|
| 16 |
+
print("\n--- Testing LLM Call (Complex Prompt) ---")
|
| 17 |
+
try:
|
| 18 |
+
content = await client.chat_completions_create(
|
| 19 |
+
messages=[
|
| 20 |
+
{"role": "system", "content": "You are a Geometry Expert. Formulate a step-by-step reasoning for calculating the distance between two points M and N where M is the midpoint of AB (len=10) and N is the midpoint of AD (len=20) in a rectangle ABCD. Use LaTeX for formulas."},
|
| 21 |
+
{"role": "user", "content": "Solve it carefully."}
|
| 22 |
+
]
|
| 23 |
+
)
|
| 24 |
+
print(f"\nResponse: {content}")
|
| 25 |
+
print("\n--- Test Completed Successfully ---")
|
| 26 |
+
except Exception as e:
|
| 27 |
+
print(f"\n--- Test Failed: {type(e).__name__}: {e} ---")
|
| 28 |
+
|
| 29 |
+
if __name__ == "__main__":
|
| 30 |
+
asyncio.run(test_real_llm())
|
tests/test_solver.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
| 4 |
+
|
| 5 |
+
from solver.engine import GeometryEngine
|
| 6 |
+
from solver.models import Point, Constraint
|
| 7 |
+
|
| 8 |
+
def test_triangle_abc():
|
| 9 |
+
engine = GeometryEngine()
|
| 10 |
+
|
| 11 |
+
# Triangle ABC: AB=5, AC=7, angle A=60
|
| 12 |
+
points = [
|
| 13 |
+
Point(id="A"),
|
| 14 |
+
Point(id="B"),
|
| 15 |
+
Point(id="C")
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
constraints = [
|
| 19 |
+
Constraint(type="length", targets=["A", "B"], value=5.0),
|
| 20 |
+
Constraint(type="length", targets=["A", "C"], value=7.0),
|
| 21 |
+
Constraint(type="angle", targets=["A"], value=60.0) # Angle at A
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
print("Solving for Triangle ABC (AB=5, AC=7, angle A=60)...")
|
| 25 |
+
results = engine.solve(points, constraints)
|
| 26 |
+
|
| 27 |
+
if results:
|
| 28 |
+
coords = results["coordinates"]
|
| 29 |
+
print("Success! Coordinates:")
|
| 30 |
+
for pid, c in coords.items():
|
| 31 |
+
print(f"Point {pid}: {c}")
|
| 32 |
+
|
| 33 |
+
# Verify distance AB
|
| 34 |
+
dist_ab = ((coords["B"][0] - coords["A"][0])**2 + (coords["B"][1] - coords["A"][1])**2)**0.5
|
| 35 |
+
print(f"Verified AB distance: {dist_ab:.2f}")
|
| 36 |
+
|
| 37 |
+
# Verify distance AC
|
| 38 |
+
dist_ac = ((coords["C"][0] - coords["A"][0])**2 + (coords["C"][1] - coords["A"][1])**2)**0.5
|
| 39 |
+
print(f"Verified AC distance: {dist_ac:.2f}")
|
| 40 |
+
else:
|
| 41 |
+
print("Solver failed.")
|
| 42 |
+
|
| 43 |
+
if __name__ == "__main__":
|
| 44 |
+
test_triangle_abc()
|