Cuong2004 committed on
Commit
25d12dc
·
0 Parent(s):

Deploy Worker from GitHub Actions

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +46 -0
  2. README.md +11 -0
  3. agents/geometry_agent.py +120 -0
  4. agents/knowledge_agent.py +135 -0
  5. agents/ocr_agent.py +185 -0
  6. agents/orchestrator.py +249 -0
  7. agents/parser_agent.py +106 -0
  8. agents/renderer_agent.py +249 -0
  9. agents/solver_agent.py +107 -0
  10. agents/torch_ultralytics_compat.py +33 -0
  11. app/dependencies.py +62 -0
  12. app/errors.py +59 -0
  13. app/llm_client.py +104 -0
  14. app/logging_setup.py +112 -0
  15. app/logutil.py +67 -0
  16. app/main.py +125 -0
  17. app/models/schemas.py +66 -0
  18. app/routers/__init__.py +1 -0
  19. app/routers/auth.py +23 -0
  20. app/routers/sessions.py +165 -0
  21. app/routers/solve.py +204 -0
  22. app/runtime_env.py +12 -0
  23. app/session_cache.py +48 -0
  24. app/supabase_client.py +37 -0
  25. app/url_utils.py +23 -0
  26. app/websocket_manager.py +40 -0
  27. clean_ports.sh +22 -0
  28. migrations/v4_migration.sql +95 -0
  29. requirements.txt +34 -0
  30. run_api_test.sh +65 -0
  31. run_full_api_test.sh +60 -0
  32. scripts/backend_test_suite.py +97 -0
  33. scripts/generate_report.py +73 -0
  34. scripts/prepare_api_test.py +31 -0
  35. scripts/prewarm_models.py +42 -0
  36. scripts/test_engine_direct.py +36 -0
  37. setup.sh +43 -0
  38. solver/dsl_parser.py +210 -0
  39. solver/engine.py +426 -0
  40. solver/models.py +13 -0
  41. tests/test_3d_solver.py +85 -0
  42. tests/test_advanced_geometry.py +102 -0
  43. tests/test_api_full_suite.py +237 -0
  44. tests/test_api_metadata_real.py +56 -0
  45. tests/test_api_real_e2e.py +75 -0
  46. tests/test_direct_task.py +70 -0
  47. tests/test_full_pipeline.py +237 -0
  48. tests/test_openrouter.py +92 -0
  49. tests/test_real_llm.py +30 -0
  50. tests/test_solver.py +44 -0
Dockerfile ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Same runtime as API; runs health endpoint + Celery worker (see worker_health.py)
FROM python:3.11-slim-bookworm

# Python/pip runtime switches, plus single-thread caps for OpenMP, MKL and OpenBLAS.
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_ROOT_USER_ACTION=ignore \
    NO_ALBUMENTATIONS_UPDATE=1 \
    OMP_NUM_THREADS=1 \
    MKL_NUM_THREADS=1 \
    OPENBLAS_NUM_THREADS=1

# All project code lives in /app and is importable from there.
WORKDIR /app
ENV PYTHONPATH=/app

# System packages: ffmpeg, Cairo/Pango libraries and headers, a LaTeX toolchain,
# and a C/C++ build toolchain. The apt lists are removed in the same layer.
# NOTE(review): presumably these serve the video renderer - confirm against requirements.txt.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    pkg-config \
    cmake \
    libcairo2 \
    libcairo2-dev \
    libpango-1.0-0 \
    libpango1.0-dev \
    libpangocairo-1.0-0 \
    libgdk-pixbuf-2.0-0 \
    libffi-dev \
    python3-dev \
    texlive-latex-base \
    texlive-fonts-recommended \
    texlive-latex-extra \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements.txt on its own so the dependency layer is cached independently of the source.
COPY requirements.txt .
RUN pip install --upgrade pip setuptools wheel \
    && pip install -r requirements.txt

COPY . .

# Runs the prewarm script at build time (see scripts/prewarm_models.py).
RUN python scripts/prewarm_models.py

# Port 7860 matches `app_port` in the Space README.
ENV PORT=7860
EXPOSE 7860

# Clear any inherited entrypoint, then exec the health/worker script so it replaces the shell.
ENTRYPOINT []
CMD ["sh", "-c", "exec python3 -u worker_health.py"]
README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Math Solver Worker
3
+ emoji: 👷
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: docker
7
+ app_port: 7860
8
+ ---
9
+
10
+ # Math Solver Worker
11
+ This space hosts the Celery background worker for video rendering.
agents/geometry_agent.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from openai import AsyncOpenAI
5
+ from typing import Dict, Any
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+ logger = logging.getLogger(__name__)
10
+
11
+ from app.url_utils import openai_compatible_api_key, sanitize_env
12
+ from app.llm_client import get_llm_client
13
+
14
+
15
+ class GeometryAgent:
16
+ def __init__(self):
17
+ self.llm = get_llm_client()
18
+
19
+ async def generate_dsl(self, semantic_data: Dict[str, Any], previous_dsl: str = None) -> str:
20
+ logger.info("==[GeometryAgent] Generating DSL from semantic data==")
21
+ if previous_dsl:
22
+ logger.info(f"[GeometryAgent] Using previous DSL context (len={len(previous_dsl)})")
23
+
24
+ system_prompt = """
25
+ You are a Geometry DSL Generator. Convert semantic geometry data into a precise Geometry DSL program.
26
+
27
+ === MULTI-TURN CONTEXT ===
28
+ If a PREVIOUS DSL is provided, your job is to UPDATE or EXTEND it.
29
+ 1. DO NOT remove existing points unless the user explicitly asks to "redefine" or "move" them.
30
+ 2. Ensure new segments/points connect correctly to existing ones.
31
+ 3. Your output should be the ENTIRE updated DSL, not just the changes.
32
+
33
+ === DSL COMMANDS ===
34
+ POINT(A) — declare a point
35
+ POINT(A, x, y, z) — declare a point with explicit coordinates
36
+ LENGTH(AB, 5) — distance between A and B is 5 (2D/3D)
37
+ ANGLE(A, 90) — interior angle at vertex A is 90° (2D/3D)
38
+ PARALLEL(AB, CD) — segment AB is parallel to CD (2D/3D)
39
+ PERPENDICULAR(AB, CD) — segment AB is perpendicular to CD (2D/3D)
40
+ MIDPOINT(M, AB) — M is the midpoint of segment AB
41
+ SECTION(E, A, C, k) — E satisfies vector AE = k * vector AC (k is decimal)
42
+ LINE(A, B) — infinite line passing through A and B
43
+ RAY(A, B) — ray starting at A and passing through B
44
+ CIRCLE(O, 5) — circle with center O and radius 5 (2D)
45
+ SPHERE(O, 5) — sphere with center O and radius 5 (3D)
46
+ SEGMENT(M, N) — auxiliary segment MN to be drawn
47
+ POLYGON_ORDER(A, B, C, D) — the order in which vertices form the polygon boundary
48
+ TRIANGLE(ABC) — equilateral/arbitrary triangle
49
+ PYRAMID(S_ABCD) — pyramid with apex S and base ABCD
50
+ PRISM(ABC_DEF) — triangular prism
51
+
52
+ === RULES ===
53
+ 1. 3D Coordinates: Use POINT(A, x, y, z) if specific coordinates are given in the problem.
54
+ 2. Space Geometry: For pyramids/prisms, use the specialized commands.
55
+ 3. Primary Vertices: Always declare the main vertices of the shape (e.g., A, B, C, D) using POINT(X).
56
+ 4. POLYGON_ORDER: Always emit POLYGON_ORDER(...) for the main shape using ONLY these primary vertices.
57
+ 5. All Points: EVERY point mentioned (A, B, C, H, M, etc.) MUST be declared with POINT(Name) first.
58
+ 6. Altitudes/Perpendiculars: For an altitude AH to BC, use POINT(H) + PERPENDICULAR(AH, BC).
59
+ 7. Format: Output ONLY DSL lines — NO explanation, NO markdown, NO code blocks.
60
+
61
+ === SHAPE EXAMPLES ===
62
+
63
+ --- Case: Square Pyramid S.ABCD with side 10, height 15 ---
64
+ PYRAMID(S_ABCD)
65
+ POINT(A, 0, 0, 0)
66
+ POINT(B, 10, 0, 0)
67
+ POINT(C, 10, 10, 0)
68
+ POINT(D, 0, 10, 0)
69
+ POINT(S)
70
+ POINT(O)
71
+ SECTION(O, A, C, 0.5)
72
+ LENGTH(SO, 15)
73
+ PERPENDICULAR(SO, AC)
74
+ PERPENDICULAR(SO, AB)
75
+ POLYGON_ORDER(A, B, C, D)
76
+
77
+ --- Case: Right Triangle ABC at A, AB=3, AC=4, altitude AH ---
78
+ POLYGON_ORDER(A, B, C)
79
+ POINT(A)
80
+ POINT(B)
81
+ POINT(C)
82
+ POINT(H)
83
+ LENGTH(AB, 3)
84
+ LENGTH(AC, 4)
85
+ ANGLE(A, 90)
86
+ PERPENDICULAR(AH, BC)
87
+ SEGMENT(A, H)
88
+
89
+ --- Case: Rectangle ABCD with AB=5, AD=10 ---
90
+ POLYGON_ORDER(A, B, C, D)
91
+ POINT(A)
92
+ POINT(B)
93
+ POINT(C)
94
+ POINT(D)
95
+ LENGTH(AB, 5)
96
+ LENGTH(AD, 10)
97
+ PERPENDICULAR(AB, AD)
98
+ PARALLEL(AB, CD)
99
+ PARALLEL(AD, BC)
100
+
101
+ [Circle with center O radius 7]
102
+ POINT(O)
103
+ CIRCLE(O, 7)
104
+ """
105
+
106
+ user_content = f"Semantic Data: {json.dumps(semantic_data, ensure_ascii=False)}"
107
+ if previous_dsl:
108
+ user_content = f"PREVIOUS DSL:\n{previous_dsl}\n\nUPDATE WITH NEW DATA: {json.dumps(semantic_data, ensure_ascii=False)}"
109
+
110
+ logger.debug("[GeometryAgent] Calling LLM (Multi-Layer)...")
111
+ content = await self.llm.chat_completions_create(
112
+ messages=[
113
+ {"role": "system", "content": system_prompt},
114
+ {"role": "user", "content": user_content}
115
+ ]
116
+ )
117
+ dsl = content.strip() if content else ""
118
+ logger.info(f"[GeometryAgent] DSL generated ({len(dsl.splitlines())} lines).")
119
+ logger.debug(f"[GeometryAgent] DSL output:\n{dsl}")
120
+ return dsl
agents/knowledge_agent.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Dict, Any
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+
7
+ # ─── Shape rule registry ────────────────────────────────────────────────────
8
+ # Each entry: keyword list → augmentation function
9
+ # Augmentation receives (values: dict, text: str) and returns updated values dict.
10
+
11
+ class KnowledgeAgent:
12
+ """Knowledge Agent: Stores geometric theorems and common patterns to augment Parser output."""
13
+
14
+ def augment_semantic_data(self, semantic_data: Dict[str, Any]) -> Dict[str, Any]:
15
+ logger.info("==[KnowledgeAgent] Augmenting semantic data==")
16
+ text = str(semantic_data.get("input_text", "")).lower()
17
+ logger.debug(f"[KnowledgeAgent] Input text for matching: '{text[:200]}'")
18
+
19
+ shape_type = self._detect_shape(text, semantic_data.get("type", ""))
20
+ if shape_type:
21
+ semantic_data["type"] = shape_type
22
+ values = semantic_data.get("values", {})
23
+ values = self._augment_values(shape_type, values, text)
24
+ semantic_data["values"] = values
25
+ else:
26
+ logger.info("[KnowledgeAgent] No special rule matched. Returning data unchanged.")
27
+
28
+ logger.debug(f"[KnowledgeAgent] Output semantic data: {semantic_data}")
29
+ return semantic_data
30
+
31
+ # ─── Shape detection ────────────────────────────────────────────────────
32
+ def _detect_shape(self, text: str, llm_type: str) -> str | None:
33
+ """Detect shape from text keywords. LLM type provides a hint."""
34
+ checks = [
35
+ (["hình vuông", "square"], "square"),
36
+ (["hình chữ nhật", "rectangle"], "rectangle"),
37
+ (["hình thoi", "rhombus"], "rhombus"),
38
+ (["hình bình hành", "parallelogram"], "parallelogram"),
39
+ (["hình thang vuông"], "right_trapezoid"),
40
+ (["hình thang", "trapezoid", "trapezium"], "trapezoid"),
41
+ (["tam giác vuông", "right triangle"], "right_triangle"),
42
+ (["tam giác đều", "equilateral triangle", "equilateral"], "equilateral_triangle"),
43
+ (["tam giác cân", "isosceles"], "isosceles_triangle"),
44
+ (["tam giác", "triangle"], "triangle"),
45
+ (["đường tròn", "circle"], "circle"),
46
+ ]
47
+ for keywords, shape in checks:
48
+ if any(kw in text for kw in keywords):
49
+ logger.info(f"[KnowledgeAgent] Rule MATCH: '{shape}' detected (keyword match).")
50
+ return shape
51
+
52
+ # Fallback: trust LLM-detected type if it's a known type
53
+ known = {
54
+ "rectangle", "square", "rhombus", "parallelogram",
55
+ "trapezoid", "right_trapezoid", "triangle", "right_triangle",
56
+ "equilateral_triangle", "isosceles_triangle", "circle",
57
+ }
58
+ if llm_type in known:
59
+ logger.info(f"[KnowledgeAgent] Using LLM-detected type '{llm_type}'.")
60
+ return llm_type
61
+
62
+ return None
63
+
64
+ # ─── Value augmentation ──────────────────────────────────────────────────
65
+ def _augment_values(self, shape: str, values: dict, text: str) -> dict:
66
+ ab = values.get("AB")
67
+ ad = values.get("AD")
68
+ bc = values.get("BC")
69
+ cd = values.get("CD")
70
+
71
+ if shape == "rectangle":
72
+ if ab and ad:
73
+ values.setdefault("CD", ab)
74
+ values.setdefault("BC", ad)
75
+ values.setdefault("angle_A", 90)
76
+ logger.info(f"[KnowledgeAgent] Rectangle: AB=CD={ab}, AD=BC={ad}, angle_A=90°")
77
+ else:
78
+ values.setdefault("angle_A", 90)
79
+
80
+ elif shape == "square":
81
+ side = ab or ad or bc or cd or values.get("side")
82
+ if side:
83
+ values.update({"AB": side, "AD": side, "angle_A": 90})
84
+ logger.info(f"[KnowledgeAgent] Square: side={side}, angle_A=90°")
85
+ else:
86
+ values.setdefault("angle_A", 90)
87
+
88
+ elif shape == "rhombus":
89
+ side = ab or values.get("side")
90
+ if side:
91
+ values.update({"AB": side, "BC": side, "CD": side, "DA": side})
92
+ logger.info(f"[KnowledgeAgent] Rhombus: all sides={side}")
93
+
94
+ elif shape == "parallelogram":
95
+ if ab:
96
+ values.setdefault("CD", ab)
97
+ if ad:
98
+ values.setdefault("BC", ad)
99
+ logger.info(f"[KnowledgeAgent] Parallelogram: AB||CD, AD||BC")
100
+
101
+ elif shape == "trapezoid":
102
+ logger.info("[KnowledgeAgent] Trapezoid: AB||CD (bottom||top)")
103
+
104
+ elif shape == "right_trapezoid":
105
+ logger.info("[KnowledgeAgent] Right trapezoid: AB||CD, AD⊥AB")
106
+ values.setdefault("angle_A", 90)
107
+
108
+ elif shape == "equilateral_triangle":
109
+ side = ab or values.get("side")
110
+ if side:
111
+ values.update({"AB": side, "BC": side, "CA": side, "angle_A": 60})
112
+ logger.info(f"[KnowledgeAgent] Equilateral triangle: all sides={side}, angle_A=60°")
113
+
114
+ elif shape == "right_triangle":
115
+ # Try to infer which vertex is the right angle
116
+ rt_vertex = _detect_right_angle_vertex(text)
117
+ values.setdefault(f"angle_{rt_vertex}", 90)
118
+ logger.info(f"[KnowledgeAgent] Right triangle: angle_{rt_vertex}=90°")
119
+
120
+ elif shape == "isosceles_triangle":
121
+ logger.info("[KnowledgeAgent] Isosceles triangle: AB=AC (default, LLM may override)")
122
+
123
+ elif shape == "circle":
124
+ logger.info("[KnowledgeAgent] Circle detected — no side augmentation needed.")
125
+
126
+ return values
127
+
128
+
129
def _detect_right_angle_vertex(text: str) -> str:
    """Heuristic: detect which vertex is the right angle from problem text.

    Matching is case-insensitive and requires the phrase to stand on its own,
    so "vuông tại a1" is not taken as vertex A. Vertices A-D are tried in
    order; "A" is returned when no phrase matches.
    """
    import re

    lowered = text.lower()
    for vertex in ("A", "B", "C", "D"):
        v = vertex.lower()
        phrases = (f"vuông tại {v}", f"góc {v} vuông", f"right angle at {v}")
        for phrase in phrases:
            # Lookarounds keep the phrase from matching inside a longer word or label.
            if re.search(r"(?<!\w)" + re.escape(phrase) + r"(?!\w)", lowered):
                return vertex
    return "A"  # default
agents/ocr_agent.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ import asyncio
4
+ from typing import List, Dict, Any
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class ImprovedOCRAgent:
    """
    Advanced OCR Agent using a hybrid pipeline:
    1. YOLO for layout analysis (text vs formula).
    2. PaddleOCR for Vietnamese text extraction.
    3. Pix2Tex for LaTeX formula extraction.
    4. MegaLLM for semantic correction and formatting.
    """

    def __init__(self):
        """Load the LLM client and the three optional vision engines.

        Each engine is imported and built inside its own ``try`` so a missing
        or broken dependency only disables that engine (attribute set to
        ``None``) instead of failing construction.
        """
        logger.info("[ImprovedOCRAgent] Initializing engines and client...")

        from app.llm_client import get_llm_client
        self.llm = get_llm_client()
        logger.info("[ImprovedOCRAgent] Multi-Layer LLM Client initialized.")

        try:
            from agents.torch_ultralytics_compat import allow_ultralytics_weights
            from ultralytics import YOLO

            allow_ultralytics_weights()
            logger.info("[ImprovedOCRAgent] Loading YOLO...")
            self.layout_model = YOLO("yolov8n.pt")
            logger.info("[ImprovedOCRAgent] YOLO initialized.")
        except Exception as e:
            logger.error("[ImprovedOCRAgent] YOLO init failed: %s", e)
            self.layout_model = None

        try:
            from paddleocr import PaddleOCR

            logger.info("[ImprovedOCRAgent] Loading PaddleOCR...")
            self.text_model = PaddleOCR(use_angle_cls=True, lang="vi")
            logger.info("[ImprovedOCRAgent] PaddleOCR (vi) initialized.")
        except Exception as e:
            logger.error("[ImprovedOCRAgent] PaddleOCR init failed: %s", e)
            self.text_model = None

        try:
            from pix2tex.cli import LatexOCR

            logger.info("[ImprovedOCRAgent] Loading Pix2Tex...")
            self.math_model = LatexOCR()
            logger.info("[ImprovedOCRAgent] Pix2Tex initialized.")
        except Exception as e:
            logger.error("[ImprovedOCRAgent] Pix2Tex init failed: %s", e)
            self.math_model = None

    @staticmethod
    def _fragments_from_result(result: Any) -> List[Dict[str, Any]]:
        """Convert a raw OCR result into text fragments sorted top-to-bottom.

        Two result layouts are handled, distinguished by the type of
        ``result[0]``: a dict with parallel ``rec_texts`` / ``rec_polys``
        lists, or a list of ``[bbox, (text, score)]`` entries. A fragment
        whose polygon is missing or empty gets ``y = 0``.
        """
        fragments: List[Dict[str, Any]] = []
        if not result:
            return fragments

        page = result[0]
        if isinstance(page, dict):
            texts = page.get("rec_texts", [])
            polys = page.get("rec_polys", [])
            for idx, text in enumerate(texts):
                # The lists should be parallel; tolerate a shorter polygon list.
                bbox = polys[idx] if idx < len(polys) else None
                ys = [p[1] for p in bbox] if hasattr(bbox, "__iter__") else []
                y_top = int(min(ys)) if ys else 0
                fragments.append({"y": y_top, "content": text, "type": "text"})
        elif isinstance(page, list):
            for line in page:
                bbox = line[0]
                text = line[1][0]
                fragments.append({"y": bbox[0][1], "content": text, "type": "text"})

        fragments.sort(key=lambda frag: frag["y"])
        return fragments

    async def process_image(self, image_path: str) -> str:
        """OCR the image at *image_path* and return LLM-refined text.

        Returns an ``"Error: ..."`` string when the file does not exist, an
        empty string when nothing was recognised, and the unrefined OCR text
        when refinement times out (30 s) or fails.
        """
        logger.info("==[ImprovedOCRAgent] Processing: %s==", image_path)

        if not os.path.exists(image_path):
            return f"Error: File {image_path} not found."

        raw_fragments: List[Dict[str, Any]] = []

        if self.text_model:
            logger.info("[ImprovedOCRAgent] Running PaddleOCR on %s...", image_path)
            # NOTE(review): this is a synchronous call inside a coroutine, so it
            # blocks the event loop while OCR runs - confirm that is acceptable.
            result = self.text_model.ocr(image_path)
            logger.info("[ImprovedOCRAgent] PaddleOCR raw result: %s", result)

            if not result:
                logger.warning("[ImprovedOCRAgent] PaddleOCR returned no results.")
                return ""

            raw_fragments = self._fragments_from_result(result)

        combined_text = "\n".join(frag["content"] for frag in raw_fragments)

        logger.info(
            "[ImprovedOCRAgent] Raw OCR output assembled:\n---\n%s\n---", combined_text
        )

        if not combined_text.strip():
            logger.warning("[ImprovedOCRAgent] No text detected to refine.")
            return ""

        try:
            logger.info("[ImprovedOCRAgent] Sending to MegaLLM for refinement...")
            return await asyncio.wait_for(
                self.refine_with_llm(combined_text), timeout=30.0
            )
        except asyncio.TimeoutError:
            logger.error("[ImprovedOCRAgent] MegaLLM refinement timed out.")
            return combined_text
        except Exception as e:
            logger.error("[ImprovedOCRAgent] MegaLLM refinement failed: %s", e)
            return combined_text

    async def refine_with_llm(self, text: str) -> str:
        """Ask the LLM to clean up raw OCR *text*.

        Returns the LLM reply, or *text* unchanged when the call fails or the
        reply is empty, so callers always receive a string.
        """
        if not text.strip():
            return ""

        prompt = f"""Bạn là một chuyên gia số hóa tài liệu toán học.
Dưới đây là kết quả OCR thô từ một trang sách toán Tiếng Việt.
Kết quả này có thể chứa lỗi chính tả, lỗi định dạng mã LaTeX, hoặc bị ngắt quãng không logic.

Nhiệm vụ của bạn:
1. Sửa lỗi chính tả tiếng Việt.
2. Đảm bảo các công thức toán học được viết đúng định dạng LaTeX và nằm trong cặp dấu $...$.
3. Giữ nguyên cấu trúc logic của bài toán.
4. Trả về nội dung đã được làm sạch dưới dạng Markdown.

Nội dung OCR thô:
---
{text}
---

Kết quả làm sạch:"""

        try:
            refined = await self.llm.chat_completions_create(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,
            )
        except Exception as e:
            logger.error("[ImprovedOCRAgent] LLM refinement failed: %s", e)
            return text

        logger.info("[ImprovedOCRAgent] LLM refinement complete.")
        if not refined or not refined.strip():
            # An empty reply would discard the OCR text; keep the raw text instead.
            return text
        return refined

    async def process_url(self, url: str) -> str:
        """Download the image at *url* and run :meth:`process_image` on it.

        The download goes to a uniquely named temporary file, so concurrent
        calls cannot overwrite each other's image; the file is removed
        afterwards. Returns an ``"Error: ..."`` string when the URL is empty
        after sanitising or the server does not answer with HTTP 200.
        """
        import tempfile

        import httpx

        from app.url_utils import sanitize_url

        url = sanitize_url(url)
        if not url:
            return "Error: Empty image URL after cleanup."

        async with httpx.AsyncClient() as client:
            resp = await client.get(url)

        if resp.status_code != 200:
            return f"Error: Failed to fetch image from URL {url}"

        fd, temp_path = tempfile.mkstemp(suffix=".png")
        try:
            with os.fdopen(fd, "wb") as f:
                f.write(resp.content)
            return await self.process_image(temp_path)
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)
180
+
181
+
182
class OCRAgent(ImprovedOCRAgent):
    """Backward-compatible name for ``ImprovedOCRAgent``; adds no behaviour of its own."""
agents/orchestrator.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Any, Dict
4
+
5
+ from agents.geometry_agent import GeometryAgent
6
+ from agents.knowledge_agent import KnowledgeAgent
7
+ from agents.ocr_agent import OCRAgent
8
+ from agents.parser_agent import ParserAgent
9
+ from agents.renderer_agent import RendererAgent
10
+ from agents.solver_agent import SolverAgent
11
+ from app.logutil import log_step
12
+ from solver.dsl_parser import DSLParser
13
+ from solver.engine import GeometryEngine
14
+ from worker.celery_app import BROKER_URL
15
+ from worker.tasks import render_geometry_video
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ _CLIP = 2000
20
+
21
+
22
+ def _clip(val: Any, n: int = _CLIP) -> str | None:
23
+ if val is None:
24
+ return None
25
+ if isinstance(val, str):
26
+ s = val
27
+ else:
28
+ s = json.dumps(val, ensure_ascii=False, default=str)
29
+ return s if len(s) <= n else s[:n] + "…"
30
+
31
+
32
def _step_io(step: str, input_val: Any = None, output_val: Any = None) -> None:
    """Debug helper: log only a step's input/output (clipped), avoiding needlessly long dumps."""
    clipped_input = _clip(input_val)
    clipped_output = _clip(output_val)
    log_step(step, input=clipped_input, output=clipped_output)
35
+
36
+
37
+ class Orchestrator:
38
+ def __init__(self):
39
+ self.parser_agent = ParserAgent()
40
+ self.geometry_agent = GeometryAgent()
41
+ self.ocr_agent = OCRAgent()
42
+ self.knowledge_agent = KnowledgeAgent()
43
+ self.renderer_agent = RendererAgent()
44
+ self.solver_agent = SolverAgent()
45
+ self.solver_engine = GeometryEngine()
46
+ self.dsl_parser = DSLParser()
47
+
48
+ def _generate_step_description(self, semantic_json: Dict[str, Any], engine_result: Dict[str, Any]) -> str:
49
+ """Tạo mô tả từng bước vẽ dựa trên kết quả của engine."""
50
+ analysis = semantic_json.get("analysis", "")
51
+ if not analysis:
52
+ analysis = f"Giải bài toán về {semantic_json.get('type', 'hình học')}."
53
+
54
+ steps = ["\n\n**Các bước dựng hình:**"]
55
+ drawing_phases = engine_result.get("drawing_phases", [])
56
+
57
+ for phase in drawing_phases:
58
+ label = phase.get("label", f"Giai đoạn {phase['phase']}")
59
+ points = ", ".join(phase.get("points", []))
60
+ segments = ", ".join([f"{s[0]}{s[1]}" for s in phase.get("segments", [])])
61
+
62
+ step_text = f"- **{label}**:"
63
+ if points:
64
+ step_text += f" Xác định các điểm {points}."
65
+ if segments:
66
+ step_text += f" Vẽ các đoạn thẳng {segments}."
67
+ steps.append(step_text)
68
+
69
+ circles = engine_result.get("circles", [])
70
+ for c in circles:
71
+ steps.append(f"- **Đường tròn**: Vẽ đường tròn tâm {c['center']} bán kính {c['radius']}.")
72
+
73
+ return analysis + "\n".join(steps)
74
+
75
+ async def run(
76
+ self,
77
+ text: str,
78
+ image_url: str = None,
79
+ job_id: str = None,
80
+ session_id: str = None,
81
+ status_callback=None,
82
+ request_video: bool = False,
83
+ history: list = None,
84
+ ) -> Dict[str, Any]:
85
+ """
86
+ Run the full pipeline. Optional history allows context-aware solving.
87
+ """
88
+ _step_io(
89
+ "orchestrate_start",
90
+ input_val={
91
+ "job_id": job_id,
92
+ "text_len": len(text or ""),
93
+ "image_url": image_url,
94
+ "request_video": request_video,
95
+ "history_len": len(history or []),
96
+ },
97
+ output_val=None,
98
+ )
99
+
100
+ if status_callback:
101
+ await status_callback("processing")
102
+
103
+ # 1. Extract context from history (if any)
104
+ previous_context = None
105
+ if history:
106
+ # Look for the last assistant message with geometry data
107
+ for msg in reversed(history):
108
+ if msg.get("role") == "assistant" and msg.get("metadata", {}).get("geometry_dsl"):
109
+ previous_context = {
110
+ "geometry_dsl": msg["metadata"]["geometry_dsl"],
111
+ "coordinates": msg["metadata"].get("coordinates", {}),
112
+ "analysis": msg.get("content", ""),
113
+ }
114
+ break
115
+
116
+ if previous_context:
117
+ _step_io("context_found", input_val=None, output_val={"dsl_len": len(previous_context["geometry_dsl"])})
118
+
119
+ # 2. Gather input text (OCR or direct)
120
+ input_text = text
121
+ if image_url:
122
+ input_text = await self.ocr_agent.process_url(image_url)
123
+ _step_io("step1_ocr", input_val=image_url, output_val=input_text)
124
+ else:
125
+ _step_io("step1_ocr", input_val="(no image)", output_val=text)
126
+
127
+ feedback = None
128
+ MAX_RETRIES = 2
129
+
130
+ for attempt in range(MAX_RETRIES + 1):
131
+ _step_io(
132
+ "attempt",
133
+ input_val=f"{attempt + 1}/{MAX_RETRIES + 1}",
134
+ output_val=None,
135
+ )
136
+ if status_callback:
137
+ await status_callback("solving")
138
+
139
+ # Parser with context
140
+ _step_io("step2_parse", input_val=f"{input_text[:50]}...", output_val=None)
141
+ semantic_json = await self.parser_agent.process(input_text, feedback=feedback, context=previous_context)
142
+ semantic_json["input_text"] = input_text
143
+ _step_io("step2_parse", input_val=None, output_val=semantic_json)
144
+
145
+ # Knowledge augmentation
146
+ _step_io("step3_knowledge", input_val=semantic_json, output_val=None)
147
+ semantic_json = self.knowledge_agent.augment_semantic_data(semantic_json)
148
+ _step_io("step3_knowledge", input_val=None, output_val=semantic_json)
149
+
150
+ # Geometry DSL with context (passing previous DSL to guide generation)
151
+ _step_io("step4_geometry_dsl", input_val=semantic_json, output_val=None)
152
+ dsl_code = await self.geometry_agent.generate_dsl(
153
+ semantic_json,
154
+ previous_dsl=previous_context["geometry_dsl"] if previous_context else None
155
+ )
156
+ _step_io("step4_geometry_dsl", input_val=None, output_val=dsl_code)
157
+
158
+ _step_io("step5_dsl_parse", input_val=dsl_code, output_val=None)
159
+ points, constraints, is_3d = self.dsl_parser.parse(dsl_code)
160
+ _step_io(
161
+ "step5_dsl_parse",
162
+ input_val=None,
163
+ output_val={
164
+ "points": len(points),
165
+ "constraints": len(constraints),
166
+ "is_3d": is_3d,
167
+ },
168
+ )
169
+
170
+ _step_io("step6_solve", input_val=f"{len(points)} pts / {len(constraints)} cons (is_3d={is_3d})", output_val=None)
171
+ import anyio
172
+ engine_result = await anyio.to_thread.run_sync(self.solver_engine.solve, points, constraints, is_3d)
173
+
174
+ if engine_result:
175
+ coordinates = engine_result.get("coordinates")
176
+ _step_io("step6_solve", input_val=None, output_val=coordinates)
177
+ break
178
+
179
+ feedback = "Geometry solver failed to find a valid solution for the given constraints. Parallelism or lengths might be inconsistent."
180
+ _step_io(
181
+ "step6_solve",
182
+ input_val=f"attempt {attempt + 1}",
183
+ output_val=feedback,
184
+ )
185
+ if attempt == MAX_RETRIES:
186
+ _step_io(
187
+ "orchestrate_abort",
188
+ input_val=None,
189
+ output_val="solver_exhausted_retries",
190
+ )
191
+ return {
192
+ "error": "Solver failed after multiple attempts.",
193
+ "last_dsl": dsl_code,
194
+ }
195
+
196
+ status = "success"
197
+ if request_video:
198
+ try:
199
+ result_payload = {
200
+ "geometry_dsl": dsl_code,
201
+ "coordinates": coordinates,
202
+ "polygon_order": engine_result.get("polygon_order", []),
203
+ "drawing_phases": engine_result.get("drawing_phases", []),
204
+ "circles": engine_result.get("circles", []),
205
+ "lines": engine_result.get("lines", []),
206
+ "rays": engine_result.get("rays", []),
207
+ "semantic": semantic_json,
208
+ "semantic_analysis": semantic_json.get("analysis") or semantic_json.get("input_text", ""),
209
+ "session_id": session_id,
210
+ }
211
+ task = render_geometry_video.delay(job_id, result_payload)
212
+ status = "rendering_queued"
213
+ _step_io(
214
+ "step7_video",
215
+ input_val={"job_id": job_id, "broker": BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL},
216
+ output_val={"task_id": str(task.id), "status": status},
217
+ )
218
+ except Exception as e:
219
+ logger.exception("Celery queue failed for job %s", job_id)
220
+ _step_io("step7_video", input_val=job_id, output_val=str(e))
221
+ status = "success"
222
+ else:
223
+ _step_io("step7_video", input_val=request_video, output_val="skipped")
224
+
225
+ _step_io("orchestrate_done", input_val=job_id, output_val=status)
226
+
227
+ # 8. Solution calculation (New in v5.1)
228
+ solution = None
229
+ if engine_result:
230
+ _step_io("step8_solve_math", input_val=semantic_json.get("target_question"), output_val=None)
231
+ solution = await self.solver_agent.solve(semantic_json, engine_result)
232
+ _step_io("step8_solve_math", input_val=None, output_val=solution.get("answer"))
233
+
234
+ final_analysis = self._generate_step_description(semantic_json, engine_result)
235
+
236
+ return {
237
+ "status": status,
238
+ "job_id": job_id,
239
+ "geometry_dsl": dsl_code,
240
+ "coordinates": coordinates,
241
+ "polygon_order": engine_result.get("polygon_order", []),
242
+ "circles": engine_result.get("circles", []),
243
+ "lines": engine_result.get("lines", []),
244
+ "rays": engine_result.get("rays", []),
245
+ "drawing_phases": engine_result.get("drawing_phases", []),
246
+ "semantic": semantic_json,
247
+ "semantic_analysis": final_analysis,
248
+ "solution": solution,
249
+ }
agents/parser_agent.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import logging
4
+ from openai import AsyncOpenAI
5
+ from typing import Dict, Any
6
+ from dotenv import load_dotenv
7
+
8
+ load_dotenv()
9
+ logger = logging.getLogger(__name__)
10
+
11
+ from app.url_utils import openai_compatible_api_key, sanitize_env
12
+
13
+
14
+ from app.llm_client import get_llm_client
15
+
16
+
17
+ class ParserAgent:
18
+ def __init__(self):
19
+ self.llm = get_llm_client()
20
+
21
+ async def process(self, text: str, feedback: str = None, context: Dict[str, Any] = None) -> Dict[str, Any]:
22
+ logger.info(f"==[ParserAgent] Processing input (len={len(text)})==")
23
+ if feedback:
24
+ logger.warning(f"[ParserAgent] Feedback from previous attempt: {feedback}")
25
+ if context:
26
+ logger.info(f"[ParserAgent] Using previous context (dsl_len={len(context.get('geometry_dsl', ''))})")
27
+
28
+ system_prompt = """
29
+ You are a Geometry Parser Agent. Extract geometric entities and constraints from Vietnamese/LaTeX math problem text.
30
+
31
+ === CONTEXT AWARENESS ===
32
+ If previous context is provided, it means this is a follow-up request.
33
+ - Combine old entities with new ones.
34
+ - Update 'analysis' to reflect the entire problem state.
35
+
36
+ Output ONLY a JSON object with this EXACT structure (no extra keys, no markdown):
37
+ {
38
+ "entities": ["Point A", "Point B", ...],
39
+ "type": "pyramid|prism|sphere|rectangle|triangle|circle|parallelogram|trapezoid|square|rhombus|general",
40
+ "values": {"AB": 5, "SO": 15, "radius": 3},
41
+ "target_question": "Câu hỏi cụ thể cần giải (ví dụ: 'Tính diện tích tam giác ABC'). NẾU KHÔNG CÓ CÂU HỎI THÌ ĐỂ null.",
42
+ "analysis": "Tóm tắt ngắn gọn toàn bộ bài toán sau khi đã cập nhật các yêu cầu mới bằng tiếng Việt."
43
+ }
44
+ Rules:
45
+ - "analysis" MUST be a meaningful and UP-TO-DATE summary of the problem in Vietnamese.
46
+ - "target_question" must be concise.
47
+ - Include midpoints, auxiliary points in "entities" if mentioned.
48
+ - If feedback is provided, correct your previous output accordingly.
49
+ """
50
+
51
+ user_content = f"Text: {text}"
52
+ if context:
53
+ user_content = f"PREVIOUS ANALYSIS: {context.get('analysis')}\nNEW REQUEST: {text}"
54
+
55
+ if feedback:
56
+ user_content += f"\nFeedback from previous attempt: {feedback}. Please correct the constraints."
57
+
58
+ logger.debug("[ParserAgent] Calling LLM (Multi-Layer)...")
59
+ raw = await self.llm.chat_completions_create(
60
+ messages=[
61
+ {"role": "system", "content": system_prompt},
62
+ {"role": "user", "content": user_content}
63
+ ],
64
+ response_format={"type": "json_object"}
65
+ )
66
+
67
+ # Pre-process raw string: extract the JSON block if present
68
+ import re
69
+ clean_raw = raw.strip()
70
+ # Handle potential markdown code blocks
71
+ if clean_raw.startswith("```"):
72
+ import re
73
+ match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
74
+ if match:
75
+ clean_raw = match.group(1).strip()
76
+
77
+ try:
78
+ result = json.loads(clean_raw)
79
+ except json.JSONDecodeError as e:
80
+ logger.error(f"[ParserAgent] JSON Parse Error: {e}. Attempting regex fallback...")
81
+ import re
82
+ json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
83
+ if json_match:
84
+ try:
85
+ # Handle single quotes if present (common LLM failure)
86
+ json_str = json_match.group(1)
87
+ if "'" in json_str and '"' not in json_str:
88
+ json_str = json_str.replace("'", '"')
89
+ result = json.loads(json_str)
90
+ except:
91
+ result = None
92
+ else:
93
+ result = None
94
+
95
+ if not result:
96
+ # Fallback for critical failure
97
+ result = {
98
+ "entities": [],
99
+ "type": "general",
100
+ "values": {},
101
+ "target_question": None,
102
+ "analysis": text
103
+ }
104
+ logger.info(f"[ParserAgent] LLM response received.")
105
+ logger.debug(f"[ParserAgent] Parsed JSON: {json.dumps(result, ensure_ascii=False, indent=2)}")
106
+ return result
agents/renderer_agent.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import glob
4
+ import string
5
+ from typing import Dict, Any, List
6
+
7
+
8
class RendererAgent:
    """
    Renderer Agent — generates Manim scripts from geometry data.

    Drawing happens in phases:
      Phase 1: Main polygon (base shape with correct vertex order)
      Phase 2: Auxiliary points and segments (midpoints, derived segments)
      Phase 3: Labels for all points
    """

    @staticmethod
    def _xyz(pos):
        """Return ``pos`` rounded to 4 decimals and padded with zeros to exactly 3 components."""
        rounded = [round(v, 4) for v in pos[:3]]
        return tuple(rounded + [0] * (3 - len(rounded)))

    def generate_manim_script(self, data: Dict[str, Any]) -> str:
        """Build the source code of a Manim ``GeometryScene`` for the given geometry.

        Keys read from ``data`` (all optional): ``coordinates`` (point id -> position),
        ``polygon_order``, ``circles``, ``lines``, ``rays``, ``drawing_phases`` and
        ``semantic`` (only its ``type`` entry is used).

        Returns:
            The generated script as a single newline-joined string.
        """
        coords: Dict[str, List[float]] = data.get("coordinates", {})
        polygon_order: List[str] = data.get("polygon_order", [])
        circles_meta: List[Dict] = data.get("circles", [])
        lines_meta: List[List[str]] = data.get("lines", [])
        rays_meta: List[List[str]] = data.get("rays", [])
        drawing_phases: List[Dict] = data.get("drawing_phases", [])
        semantic: Dict[str, Any] = data.get("semantic", {})
        # A present-but-None "type" must not crash on .lower().
        shape_type = str(semantic.get("type") or "").lower()

        # ── Detect 3D Context ────────────────────────────────────────────────
        is_3d = shape_type in ["pyramid", "prism", "sphere"] or any(
            len(pos) >= 3 and abs(pos[2]) > 0.001 for pos in coords.values()
        )

        # ── Fallback: infer polygon_order from single uppercase letters ──────
        # The length check matters: "AB" is a *substring* of ascii_uppercase, so a
        # plain `in` test would wrongly treat derived points like "AB" as vertices.
        if not polygon_order:
            polygon_order = sorted(
                pid for pid in coords
                if len(pid) == 1 and pid in string.ascii_uppercase
            )

        # Separate base points from derived (multi-char or lowercase)
        base_ids = [pid for pid in polygon_order if pid in coords]
        derived_ids = [pid for pid in coords if pid not in polygon_order]

        scene_base = "ThreeDScene" if is_3d else "MovingCameraScene"
        lines = [
            "from manim import *",
            "",
            f"class GeometryScene({scene_base}):",
            "    def construct(self):",
        ]

        if is_3d:
            lines.append("        # 3D Setup")
            lines.append("        self.set_camera_orientation(phi=75*DEGREES, theta=-45*DEGREES)")
            lines.append("        axes = ThreeDAxes(axis_config={'stroke_width': 1})")
            lines.append("        axes.set_opacity(0.3)")
            lines.append("        self.add(axes)")
            lines.append("        self.begin_ambient_camera_rotation(rate=0.1)")
            lines.append("")

        # ── Declare all dots and labels ───────────────────────────────────────
        dot_class = "Dot3D" if is_3d else "Dot"
        for pid, pos in coords.items():
            x, y, z = self._xyz(pos)
            lines.append(f"        p_{pid} = {dot_class}(point=[{x}, {y}, {z}], color=WHITE, radius=0.08)")

            if is_3d:
                lines.append(
                    f"        l_{pid} = Text('{pid}', font_size=20, color=WHITE)"
                    f".move_to(p_{pid}.get_center() + [0.2, 0.2, 0.2])"
                )
                # Ensure labels follow camera in 3D (fixed orientation)
                lines.append(f"        self.add_fixed_orientation_mobjects(l_{pid})")
            else:
                lines.append(
                    f"        l_{pid} = Text('{pid}', font_size=22, color=WHITE)"
                    f".next_to(p_{pid}, UR, buff=0.15)"
                )

        # ── 3D Shape Special: Pyramid/Prism Faces ────────────────────────────
        if is_3d and shape_type == "pyramid" and len(base_ids) >= 3:
            # Find apex (usually 'S')
            apex_id = "S" if "S" in coords else derived_ids[0] if derived_ids else None
            if apex_id:
                # Draw base face
                base_pts = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
                lines.append(f"        base_face = Polygon({base_pts}, color=BLUE, fill_opacity=0.1)")
                lines.append("        self.play(Create(base_face), run_time=1.0)")

                # Draw side faces
                for i, p1 in enumerate(base_ids):
                    p2 = base_ids[(i + 1) % len(base_ids)]
                    face_pts = f"p_{apex_id}.get_center(), p_{p1}.get_center(), p_{p2}.get_center()"
                    lines.append(f"        side_{i} = Polygon({face_pts}, color=BLUE, stroke_width=1, fill_opacity=0.05)")
                    lines.append(f"        self.play(Create(side_{i}), run_time=0.5)")

        # ── Circles ──────────────────────────────────────────────────────────
        # Remember which variables were really declared so the animation phase
        # never references a circle that was skipped here.
        circle_names: List[str] = []
        for i, c in enumerate(circles_meta):
            center = c.get("center")
            r = c.get("radius")
            if center in coords and r is not None:
                cx, cy, cz = self._xyz(coords[center])
                lines.append(
                    f"        circle_{i} = Circle(radius={r}, color=BLUE)"
                    f".move_to([{cx}, {cy}, {cz}])"
                )
                circle_names.append(f"circle_{i}")

        # ── Infinite Lines & Rays ────────────────────────────────────────────
        # (Standard Line works for 3D coordinates in Manim)
        line_names: List[str] = []
        for i, pair in enumerate(lines_meta):
            if len(pair) != 2:
                continue
            p1, p2 = pair
            if p1 in coords and p2 in coords:
                lines.append(
                    f"        line_ext_{i} = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=GRAY_D, stroke_width=2)"
                    f".scale(20)"
                )
                line_names.append(f"line_ext_{i}")

        ray_names: List[str] = []
        for i, pair in enumerate(rays_meta):
            if len(pair) != 2:
                continue
            p1, p2 = pair
            if p1 in coords and p2 in coords:
                lines.append(
                    f"        ray_{i} = Line(p_{p1}.get_center(), p_{p1}.get_center() + 15 * (p_{p2}.get_center() - p_{p1}.get_center()),"
                    f" color=GRAY_C, stroke_width=2)"
                )
                ray_names.append(f"ray_{i}")

        # ── Camera auto-fit group (Only for 2D) ──────────────────────────────
        if not is_3d:
            all_names_str = ", ".join([f"p_{pid}" for pid in coords])
            lines.append(f"        _all = VGroup({all_names_str})")
            lines.append("        self.camera.frame.set_width(max(_all.width * 2.0, 8))")
            lines.append("        self.camera.frame.move_to(_all)")
        lines.append("")

        # ── Phase 1: Base polygon ─────────────────────────────────────────────
        if len(base_ids) >= 3:
            pts_str = ", ".join([f"p_{pid}.get_center()" for pid in base_ids])
            lines.append(f"        poly = Polygon({pts_str}, color=BLUE, fill_color=BLUE, fill_opacity=0.15)")
            lines.append("        self.play(Create(poly), run_time=1.5)")
        elif len(base_ids) == 2:
            p1, p2 = base_ids
            lines.append(f"        base_line = Line(p_{p1}.get_center(), p_{p2}.get_center(), color=BLUE)")
            lines.append("        self.play(Create(base_line), run_time=1.0)")

        # Draw base points
        if base_ids:
            base_dots_str = ", ".join([f"p_{pid}" for pid in base_ids])
            lines.append(f"        self.play(FadeIn(VGroup({base_dots_str})), run_time=0.5)")
        lines.append("        self.wait(0.5)")

        # ── Phase 2: Auxiliary points and segments ────────────────────────────
        if derived_ids:
            derived_dots_str = ", ".join([f"p_{pid}" for pid in derived_ids])
            lines.append(f"        self.play(FadeIn(VGroup({derived_dots_str})), run_time=0.8)")

        # Segments from drawing_phases
        segment_lines = []
        for phase in drawing_phases:
            if phase.get("phase") != 2:
                continue
            for seg in phase.get("segments", []):
                if len(seg) == 2 and seg[0] in coords and seg[1] in coords:
                    p1, p2 = seg[0], seg[1]
                    seg_var = f"seg_{p1}_{p2}"
                    lines.append(
                        f"        {seg_var} = Line(p_{p1}.get_center(), p_{p2}.get_center(),"
                        f" color=YELLOW)"
                    )
                    segment_lines.append(seg_var)

        if segment_lines:
            segs_str = ", ".join([f"Create({sv})" for sv in segment_lines])
            lines.append(f"        self.play({segs_str}, run_time=1.2)")

        if derived_ids or segment_lines:
            lines.append("        self.wait(0.5)")

        # ── Phase 3: All labels ───────────────────────────────────────────────
        all_labels_str = ", ".join([f"l_{pid}" for pid in coords])
        lines.append(f"        self.play(FadeIn(VGroup({all_labels_str})), run_time=0.8)")

        # ── Circles phase ─────────────────────────────────────────────────────
        for name in circle_names:
            lines.append(f"        self.play(Create({name}), run_time=1.5)")

        # ── Lines & Rays phase ────────────────────────────────────────────────
        lr_anims = [f"Create({name})" for name in line_names + ray_names]
        if lr_anims:
            lines.append(f"        self.play({', '.join(lr_anims)}, run_time=1.5)")

        lines.append("        self.wait(2)")

        return "\n".join(lines)

    def run_manim(self, script_content: str, job_id: str) -> str:
        """Write ``script_content`` to ``<job_id>.py``, render it with Manim, return the video path.

        When the ``MOCK_VIDEO`` environment variable equals ``"true"`` Manim is not
        invoked and a small placeholder file ``videos/<job_id>.mp4`` is written instead.

        Returns:
            Path of the first matching ``.mp4`` found, or ``""`` if rendering failed
            or no output file could be located. The temporary script is always removed.
        """
        import glob
        import logging
        import os
        import subprocess

        # This module defines no module-level logger, so create one here.
        logger = logging.getLogger(__name__)

        script_file = f"{job_id}.py"
        with open(script_file, "w", encoding="utf-8") as f:
            f.write(script_content)

        try:
            if os.getenv("MOCK_VIDEO") == "true":
                logger.info(f"MOCK_VIDEO is true. Skipping Manim for job {job_id}")
                # Create a dummy file if needed, or just return a path that exists
                dummy_path = f"videos/{job_id}.mp4"
                os.makedirs("videos", exist_ok=True)
                with open(dummy_path, "wb") as f:
                    f.write(b"dummy video content")
                return dummy_path

            print(f"Running Manim for job {job_id}...")
            result = subprocess.run(
                ["manim", "-ql", "--media_dir", ".", "-o", f"{job_id}.mp4", script_file, "GeometryScene"],
                capture_output=True,
                text=True,
            )
            print(f"Manim STDOUT: {result.stdout}")
            print(f"Manim STDERR: {result.stderr}")

            for pattern in [f"**/videos/**/{job_id}.mp4", f"**/{job_id}*.mp4"]:
                found = glob.glob(pattern, recursive=True)
                if found:
                    print(f"Manim Success: Found {found[0]}")
                    return found[0]

            print(f"Manim file not found for job {job_id}")
            return ""
        except Exception as e:
            print(f"Manim Execution Error: {e}")
            return ""
        finally:
            if os.path.exists(script_file):
                os.remove(script_file)
agents/solver_agent.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import sympy as sp
4
+ from typing import Dict, Any, List
5
+ from app.llm_client import get_llm_client
6
+
7
logger = logging.getLogger(__name__)


class SolverAgent:
    """Agent that asks the LLM for a step-by-step solution to the parsed target question."""

    def __init__(self):
        self.llm = get_llm_client()

    @staticmethod
    def _parse_solution(raw: str) -> Dict[str, Any]:
        """Extract the solution object from a raw LLM reply.

        Accepts plain JSON, JSON wrapped in a markdown code fence, or JSON embedded
        in surrounding text. Missing keys are filled with neutral defaults.

        Raises:
            json.JSONDecodeError: if no parseable JSON is found.
            ValueError: if the parsed JSON is not an object.
        """
        import re

        clean_raw = raw.strip()
        # Handle potential markdown code blocks if the response_format wasn't strictly honored
        if clean_raw.startswith("```"):
            match = re.search(r"```(?:json)?\s*(.*?)\s*```", clean_raw, re.DOTALL)
            if match:
                clean_raw = match.group(1).strip()

        try:
            solution = json.loads(clean_raw)
        except json.JSONDecodeError:
            # Last resort: try to find anything between { and }
            json_match = re.search(r'(\{.*\})', clean_raw, re.DOTALL)
            if not json_match:
                raise
            solution = json.loads(json_match.group(1))

        # A JSON list/string/number is valid JSON but not a usable solution.
        if not isinstance(solution, dict):
            raise ValueError(f"Expected a JSON object, got {type(solution).__name__}")

        solution.setdefault("answer", None)
        solution.setdefault("steps", [])
        solution.setdefault("symbolic_expression", None)
        return solution

    async def solve(self, semantic_data: Dict[str, Any], engine_result: Dict[str, Any]) -> Dict[str, Any]:
        """
        Solves the geometric problem based on coordinates and the target question.
        Returns a 'solution' dictionary with answer, steps, and symbolic_expression.

        If ``semantic_data`` has no ``target_question`` an empty solution is returned
        without calling the LLM. Any failure while calling the LLM or parsing its reply
        is logged and converted into a fixed fallback solution instead of raising.
        """
        target_question = semantic_data.get("target_question")
        if not target_question:
            # If no question, just return an empty solution structure
            return {
                "answer": None,
                "steps": [],
                "symbolic_expression": None
            }

        logger.info(f"==[SolverAgent] Solving for: '{target_question}'==")

        input_text = semantic_data.get("input_text", "")
        coordinates = engine_result.get("coordinates", {})

        # We provide the coordinates and semantic context to the LLM to help it reason.
        # The LLM is tasked with generating the solution structure directly.

        system_prompt = """
You are a Geometry Solver Agent. Your goal is to provide a step-by-step solution for a specific geometric question.

=== DATA PROVIDED ===
1. Target Question: The specific question to answer.
2. Geometry Data: Entities and values extracted from the problem.
3. Coordinates: Calculated coordinates for all points.

=== REQUIREMENTS ===
- Provide the solution in the SAME LANGUAGE as the user's input.
- Use SymPy concepts if appropriate.
- Steps should be clear, concise, and logical.
- The final answer should be numerically or symbolically accurate based on the coordinates and geometric properties.
- For geometric proofs (e.g., "Is AB perpendicular to AC?"), explain the reasoning based on the data.

Output ONLY a JSON object with this structure:
{
  "answer": "Chuỗi văn bản kết quả cuối cùng (kèm đơn vị nếu có)",
  "steps": [
    "Bước 1: ...",
    "Bước 2: ...",
    ...
  ],
  "symbolic_expression": "Biểu thức toán học rút gọn (LaTeX format optional)"
}
"""

        # default=str keeps prompt building from raising on values that are not
        # JSON-native; the text is only read by the LLM.
        user_content = f"""
INPUT_TEXT: {input_text}
TARGET_QUESTION: {target_question}
SEMANTIC_DATA: {json.dumps(semantic_data, ensure_ascii=False, default=str)}
COORDINATES: {json.dumps(coordinates, default=str)}
"""

        logger.debug("[SolverAgent] Requesting solution from LLM...")
        raw = None
        try:
            raw = await self.llm.chat_completions_create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_content}
                ],
                response_format={"type": "json_object"}
            )
            solution = self._parse_solution(raw)
            logger.info("[SolverAgent] Solution generated successfully.")
            return solution
        except Exception as e:
            logger.error(f"[SolverAgent] Error generating solution: {e}")
            logger.debug(f"[SolverAgent] Raw LLM output was: \n{raw if raw is not None else 'N/A'}")
            return {
                "answer": "Không thể tính toán lời giải tại thời điểm này.",
                "steps": ["Đã xảy ra lỗi trong quá trình xử lý lời giải."],
                "symbolic_expression": None
            }
agents/torch_ultralytics_compat.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch 2.6+ defaults weights_only=True; Ultralytics YOLO .pt checkpoints unpickle full nn graphs (trusted official weights)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import functools
6
+
7
_torch_load_patched = False


def allow_ultralytics_weights() -> None:
    """
    Make ``torch.load`` default to ``weights_only=False`` (applied at most once).

    Official yolov8n.pt is a trusted checkpoint. PyTorch 2.6+ safe unpickling would
    require allowlisting many torch.nn globals; loading with weights_only=False matches
    Ultralytics upstream behavior for local .pt files.

    An explicit ``weights_only`` passed by a caller is left untouched. Any failure
    while importing or patching torch is ignored and the flag stays unset.
    """
    global _torch_load_patched
    if not _torch_load_patched:
        try:
            import torch

            original_load = torch.load

            @functools.wraps(original_load)
            def _load(*args, **kwargs):
                # Only fill in the default; never override the caller's choice.
                kwargs.setdefault("weights_only", False)
                return original_load(*args, **kwargs)

            torch.load = _load
            _torch_load_patched = True
        except Exception:
            pass
app/dependencies.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import HTTPException, Header
2
+
3
+ from app.supabase_client import get_supabase, get_supabase_for_user_jwt
4
+
5
+
6
async def get_current_user_id(authorization: str = Header(...)):
    """
    Resolve the caller's user id from the ``Authorization`` header via Supabase.

    Expected Header: Authorization: Bearer <token>

    When the environment variable ALLOW_TEST_BYPASS is "true", a header of the form
    ``Test <value>`` short-circuits authentication and ``<value>`` is returned as-is.

    Raises:
        HTTPException(401): header missing/malformed, token rejected, or the
            Supabase lookup raised.
    """
    import os

    bypass_enabled = os.getenv("ALLOW_TEST_BYPASS") == "true"
    if bypass_enabled and authorization.startswith("Test "):
        return authorization.split(" ")[1]

    is_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not is_bearer:
        raise HTTPException(
            status_code=401,
            detail="Authorization header missing or invalid. Use 'Bearer <token>'",
        )

    jwt = authorization.split(" ")[1]
    client = get_supabase()

    try:
        lookup = client.auth.get_user(jwt)
        user = lookup.user if lookup else None
        if not user:
            raise HTTPException(status_code=401, detail="Invalid session or token.")
        return user.id
    except HTTPException:
        # Our own 401 above must pass through unchanged.
        raise
    except Exception as e:
        raise HTTPException(status_code=401, detail="Authentication failed: " + str(e))
34
+
35
+
36
async def get_authenticated_supabase(authorization: str = Header(...)):
    """
    Supabase client that carries the user's JWT (anon key + Authorization header).
    Use for routes that should respect Row Level Security; pair with app logic as needed.

    The token is first validated with ``auth.get_user``; only then is the
    per-user client built.

    Raises:
        HTTPException(401): header missing/malformed or token validation failed.
        HTTPException(503): ``get_supabase_for_user_jwt`` raised ``RuntimeError``.
    """
    is_bearer = bool(authorization) and authorization.startswith("Bearer ")
    if not is_bearer:
        raise HTTPException(
            status_code=401,
            detail="Authorization header missing or invalid. Use 'Bearer <token>'",
        )

    jwt = authorization.split(" ")[1]
    client = get_supabase()

    try:
        lookup = client.auth.get_user(jwt)
        user = lookup.user if lookup else None
        if not user:
            raise HTTPException(status_code=401, detail="Invalid session or token.")
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=401, detail="Authentication failed: " + str(e))

    try:
        user_client = get_supabase_for_user_jwt(jwt)
    except RuntimeError as e:
        raise HTTPException(status_code=503, detail=str(e))
    return user_client
app/errors.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Map exceptions to short, user-visible messages (avoid leaking HTML bodies from 404 proxies)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
def _looks_like_html(text: str) -> bool:
    """Return True if the start of ``text`` looks like an HTML document.

    Only the first 500 characters (after stripping leading whitespace) are
    inspected, case-insensitively; an ``<html`` tag counts if it appears within
    the first 200 of those.
    """
    head = text.lstrip()[:500].lower()
    if head.startswith(("<!doctype", "<html")):
        return True
    return "<html" in head[:200]
13
+
14
+
15
def format_error_for_user(exc: BaseException) -> str:
    """
    Turn an exception into a short message that is safe to show in chat/UI.

    httpx status/request errors map to fixed hints (the response body is never
    shown), HTML-looking messages are replaced by a generic explanation, and any
    other message is returned as-is, cut to 800 characters. Full detail stays in
    server logs via logger.exception.
    """
    # httpx: wrong URL often returns 404 HTML; don't show body
    try:
        import httpx

        if isinstance(exc, httpx.HTTPStatusError):
            request = exc.request
            status_code = exc.response.status_code
            host = ""
            try:
                if request and request.url:
                    host = str(request.url.host)
            except Exception:
                pass
            logger.warning(
                "HTTPStatusError %s for %s (response not shown to user)",
                status_code,
                host or "?",
            )
            return (
                "Kiểm tra URL API, khóa bí mật và biến môi trường (OpenRouter/Supabase/Redis)."
            )

        if isinstance(exc, httpx.RequestError):
            return "Không kết nối được tới dịch vụ ngoài (mạng hoặc URL sai)."
    except ImportError:
        # httpx not installed: fall through to the generic handling.
        pass

    message = str(exc).strip()
    if not message:
        return "Đã xảy ra lỗi không xác định."

    if _looks_like_html(message):
        logger.warning("Suppressed HTML error body from user-facing message")
        return (
            "Dịch vụ trả về trang lỗi (thường là URL API sai hoặc endpoint không tồn tại — HTTP 404). "
            "Kiểm tra OPENROUTER_MODEL và khóa API trên server."
        )

    return message[:800] + "…" if len(message) > 800 else message
app/llm_client.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import asyncio
4
+ import logging
5
+ from openai import AsyncOpenAI
6
+ from typing import List, Dict, Any, Optional
7
+ from app.url_utils import openai_compatible_api_key, sanitize_env
8
+
9
logger = logging.getLogger(__name__)


class MultiLayerLLMClient:
    """OpenRouter chat client that rotates across up to five API keys.

    Keys are read from OPENROUTER_API_KEY_1..5, falling back to the legacy
    OPENROUTER_API_KEY. One ``AsyncOpenAI`` client is created per key.
    """

    def __init__(self):
        # 1. OpenRouter (Primary with Rotation)
        self.openrouter_model = os.getenv("OPENROUTER_MODEL", "google/gemini-2.0-flash-001")
        self.keys = []
        for i in range(1, 6):
            key = os.getenv(f"OPENROUTER_API_KEY_{i}")
            if key:
                self.keys.append(key)

        if not self.keys:
            # Fallback to single key if exists (legacy)
            single_key = os.getenv("OPENROUTER_API_KEY")
            if single_key:
                self.keys.append(single_key)

        self.clients = [
            AsyncOpenAI(
                api_key=openai_compatible_api_key(k),
                base_url="https://openrouter.ai/api/v1",
                timeout=120.0,
                default_headers={
                    "HTTP-Referer": "https://mathsolver.ai",
                    "X-Title": "MathSolver Backend",
                }
            ) for k in self.keys
        ]
        self.current_index = 0

    async def chat_completions_create(
        self,
        messages: List[Dict[str, str]],
        response_format: Optional[Dict[str, str]] = None,
        **kwargs
    ) -> str:
        """
        Rotates through OpenRouter API keys on every attempt (success or failure).
        Tries up to 2 retries (total 3 attempts), with 0s delay.

        Args:
            messages: Chat messages forwarded to the completion call.
            response_format: Optional response format; omitted from the request
                entirely when ``None``.
            **kwargs: Extra keyword arguments forwarded to the completion call.

        Returns:
            The text content of the first choice.

        Raises:
            ValueError: if no API keys are configured.
            Exception: the error of the last attempt when every attempt failed.
        """
        MAX_RETRIES = 2
        RETRY_DELAY = 0  # seconds

        if not self.clients:
            logger.error("[LLM] No OpenRouter API keys found.")
            raise ValueError("No API keys configured.")

        # Only send response_format when the caller supplied one; an explicit
        # None would otherwise be forwarded to the SDK as a real value.
        request_kwargs = dict(kwargs)
        if response_format is not None:
            request_kwargs["response_format"] = response_format

        for attempt in range(1, MAX_RETRIES + 2):  # Up to 3 attempts total
            client = self.clients[self.current_index]
            key_id = self.current_index + 1

            try:
                logger.info(f"[LLM] Attempt {attempt}/{MAX_RETRIES + 1} using Key #{key_id} ({self.openrouter_model})...")
                response = await client.chat.completions.create(
                    model=self.openrouter_model,
                    messages=messages,
                    **request_kwargs
                )

                if not response or not getattr(response, "choices", None):
                    raise ValueError("Invalid response structure from OpenRouter")

                content = response.choices[0].message.content
                if content:
                    logger.info(f"[LLM] SUCCESS on attempt {attempt} (Key #{key_id}).")
                    # Always rotate to the next key after a success, ready for the next request
                    self.current_index = (self.current_index + 1) % len(self.clients)
                    return content

                raise ValueError("Empty content from OpenRouter")

            except Exception as e:
                err_msg = f"{type(e).__name__}: {str(e)}"
                logger.warning(f"[LLM] FAILED on attempt {attempt} (Key #{key_id}): {err_msg}")

                # Rotate on failure too, so the next attempt uses a different key
                self.current_index = (self.current_index + 1) % len(self.clients)

                if attempt <= MAX_RETRIES:
                    # Log the key captured before the await: the shared index may
                    # have been moved by another coroutine in the meantime.
                    logger.info(f"[LLM] Switching from Key #{key_id} to #{self.current_index + 1}. Retrying in {RETRY_DELAY}s...")
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    logger.error(f"[LLM] FINAL FAILURE after {attempt} attempts.")
                    raise
96
+
97
+ # Global instance for easy reuse (singleton-ish)
98
_llm_client = None


def get_llm_client() -> MultiLayerLLMClient:
    """Return the shared ``MultiLayerLLMClient``, creating it on first use."""
    global _llm_client
    client = _llm_client
    if client is None:
        client = MultiLayerLLMClient()
        _llm_client = client
    return client
app/logging_setup.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Logging theo một biến LOG_LEVEL: debug | info | warning | error."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import os
7
+ from typing import Final
8
+
9
+ _SETUP_DONE = False
10
+
11
+ PIPELINE_LOGGER_NAME: Final = "app.pipeline"
12
+ CACHE_LOGGER_NAME: Final = "app.cache"
13
+ STEPS_LOGGER_NAME: Final = "app.steps"
14
+ ACCESS_LOGGER_NAME: Final = "app.access"
15
+
16
+
17
+ def _normalize_level() -> str:
18
+ raw = os.getenv("LOG_LEVEL", "info").strip().lower()
19
+ if raw in ("debug", "info", "warning", "error"):
20
+ return raw
21
+ return "info"
22
+
23
+
24
+ def setup_application_logging() -> None:
25
+ """Idempotent; gọi khi khởi động process (uvicorn, celery, worker_health)."""
26
+ global _SETUP_DONE
27
+ if _SETUP_DONE:
28
+ return
29
+ _SETUP_DONE = True
30
+
31
+ mode = _normalize_level()
32
+
33
+ level_map = {
34
+ "debug": logging.DEBUG,
35
+ "info": logging.INFO,
36
+ "warning": logging.WARNING,
37
+ "error": logging.ERROR,
38
+ }
39
+ root_level = level_map[mode]
40
+
41
+ fmt_named = "%(asctime)s | %(levelname)-8s | %(name)s | %(message)s"
42
+ fmt_short = "%(asctime)s | %(levelname)-8s | %(message)s"
43
+
44
+ logging.basicConfig(
45
+ level=root_level,
46
+ format=fmt_named if mode == "debug" else fmt_short,
47
+ datefmt="%H:%M:%S",
48
+ force=True,
49
+ )
50
+
51
+ logging.getLogger("httpx").setLevel(logging.WARNING)
52
+ logging.getLogger("httpcore").setLevel(logging.WARNING)
53
+ logging.getLogger("openai").setLevel(logging.WARNING)
54
+ logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
55
+ logging.getLogger("uvicorn.error").setLevel(logging.INFO)
56
+ # HTTP/2 stack (httpx/httpcore) — khi LOG_LEVEL=debug root=DEBUG sẽ tràn log hpack; không cần cho debug app
57
+ for _name in ("hpack", "h2", "hyperframe", "urllib3"):
58
+ logging.getLogger(_name).setLevel(logging.WARNING)
59
+
60
+ if mode == "debug":
61
+ logging.getLogger("agents").setLevel(logging.DEBUG)
62
+ logging.getLogger("solver").setLevel(logging.DEBUG)
63
+ logging.getLogger("app").setLevel(logging.DEBUG)
64
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.DEBUG)
65
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.DEBUG)
66
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.INFO)
67
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
68
+ logging.getLogger("app.main").setLevel(logging.INFO)
69
+ logging.getLogger("worker").setLevel(logging.INFO)
70
+ elif mode == "info":
71
+ # Chỉ HTTP access (app.access) + startup; ẩn chi tiết agents/orchestrator/pipeline SUCCESS
72
+ logging.getLogger("agents").setLevel(logging.INFO)
73
+ logging.getLogger("solver").setLevel(logging.WARNING)
74
+ logging.getLogger("app").setLevel(logging.INFO)
75
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
76
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
77
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
78
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.INFO)
79
+ logging.getLogger("app.main").setLevel(logging.INFO)
80
+ logging.getLogger("worker").setLevel(logging.WARNING)
81
+ elif mode == "warning":
82
+ logging.getLogger("agents").setLevel(logging.WARNING)
83
+ logging.getLogger("solver").setLevel(logging.WARNING)
84
+ logging.getLogger("app.routers").setLevel(logging.WARNING)
85
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.WARNING)
86
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.WARNING)
87
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.WARNING)
88
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.WARNING)
89
+ logging.getLogger("app.main").setLevel(logging.WARNING)
90
+ logging.getLogger("worker").setLevel(logging.WARNING)
91
+ else: # error
92
+ logging.getLogger("agents").setLevel(logging.ERROR)
93
+ logging.getLogger("solver").setLevel(logging.ERROR)
94
+ logging.getLogger("app.routers").setLevel(logging.ERROR)
95
+ logging.getLogger(CACHE_LOGGER_NAME).setLevel(logging.ERROR)
96
+ logging.getLogger(STEPS_LOGGER_NAME).setLevel(logging.ERROR)
97
+ logging.getLogger(PIPELINE_LOGGER_NAME).setLevel(logging.ERROR)
98
+ logging.getLogger(ACCESS_LOGGER_NAME).setLevel(logging.ERROR)
99
+ logging.getLogger("app.main").setLevel(logging.ERROR)
100
+ logging.getLogger("worker").setLevel(logging.ERROR)
101
+
102
+ logging.getLogger(__name__).debug(
103
+ "LOG_LEVEL=%s root=%s", mode, logging.getLevelName(root_level)
104
+ )
105
+
106
+
107
+ def get_log_level() -> str:
108
+ return _normalize_level()
109
+
110
+
111
+ def is_debug_level() -> bool:
112
+ return _normalize_level() == "debug"
app/logutil.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """log_step (debug), pipeline (debug), access log ở middleware."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ import os
8
+ from typing import Any
9
+
10
+ from app.logging_setup import PIPELINE_LOGGER_NAME, STEPS_LOGGER_NAME
11
+
12
+ _pipeline = logging.getLogger(PIPELINE_LOGGER_NAME)
13
+ _steps = logging.getLogger(STEPS_LOGGER_NAME)
14
+
15
+
16
+ def is_debug_mode() -> bool:
17
+ """Chi tiết từng bước chỉ khi LOG_LEVEL=debug."""
18
+ return os.getenv("LOG_LEVEL", "info").strip().lower() == "debug"
19
+
20
+
21
+ def _truncate(val: Any, max_len: int = 2000) -> Any:
22
+ if val is None:
23
+ return None
24
+ if isinstance(val, (int, float, bool)):
25
+ return val
26
+ s = str(val)
27
+ if len(s) > max_len:
28
+ return s[:max_len] + f"... (+{len(s) - max_len} chars)"
29
+ return s
30
+
31
+
32
+ def log_step(step: str, **fields: Any) -> None:
33
+ """Chỉ khi LOG_LEVEL=debug: DB / cache / orchestrator."""
34
+ if not is_debug_mode():
35
+ return
36
+ safe = {k: _truncate(v) for k, v in fields.items()}
37
+ try:
38
+ payload = json.dumps(safe, ensure_ascii=False, default=str)
39
+ except Exception:
40
+ payload = str(safe)
41
+ _steps.debug("[step:%s] %s", step, payload)
42
+
43
+
44
+ def log_pipeline_success(operation: str, **fields: Any) -> None:
45
+ """Chỉ hiện khi debug (pipeline SUCCESS không dùng ở info — đã có app.access)."""
46
+ if not is_debug_mode():
47
+ return
48
+ safe = {k: _truncate(v, 500) for k, v in fields.items()}
49
+ _pipeline.info(
50
+ "SUCCESS %s %s",
51
+ operation,
52
+ json.dumps(safe, ensure_ascii=False, default=str),
53
+ )
54
+
55
+
56
+ def log_pipeline_failure(operation: str, error: str | None = None, **fields: Any) -> None:
57
+ """Thất bại pipeline: luôn dùng WARNING để vẫn thấy khi LOG_LEVEL=warning."""
58
+ if is_debug_mode():
59
+ safe = {k: _truncate(v, 500) for k, v in fields.items()}
60
+ _pipeline.warning(
61
+ "FAIL %s err=%s %s",
62
+ operation,
63
+ _truncate(error, 300),
64
+ json.dumps(safe, ensure_ascii=False, default=str),
65
+ )
66
+ else:
67
+ _pipeline.warning("FAIL %s", operation)
app/main.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import time
6
+ import uuid
7
+ import warnings
8
+
9
+ from dotenv import load_dotenv
10
+ from fastapi import FastAPI, File, HTTPException, UploadFile
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+ from starlette.requests import Request
13
+
14
+ load_dotenv()
15
+
16
+ from app.runtime_env import apply_runtime_env_defaults
17
+
18
+ apply_runtime_env_defaults()
19
+
20
+ os.environ["NO_ALBUMENTATIONS_UPDATE"] = "1"
21
+ warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
22
+ warnings.filterwarnings("ignore", category=UserWarning, module="albumentations")
23
+
24
+ from app.logging_setup import ACCESS_LOGGER_NAME, get_log_level, setup_application_logging
25
+
26
+ setup_application_logging()
27
+
28
+ # Routers (after logging)
29
+ from app.routers import auth, sessions, solve
30
+ from agents.ocr_agent import OCRAgent
31
+ from app.routers.solve import get_orchestrator
32
+ from app.supabase_client import get_supabase
33
+ from app.websocket_manager import register_websocket_routes
34
+
35
+ logger = logging.getLogger("app.main")
36
+ _access = logging.getLogger(ACCESS_LOGGER_NAME)
37
+
38
+ app = FastAPI(title="Visual Math Solver API v4.0")
39
+
40
+
41
@app.middleware("http")
async def access_log_middleware(request: Request, call_next):
    """Write one access-log record per request, filtered by LOG_LEVEL.

    info/debug: every request at INFO. warning: 5xx at ERROR, other 4xx at
    WARNING. error: any 4xx/5xx at ERROR (without timing). The response is
    returned unchanged.
    """
    started = time.perf_counter()
    response = await call_next(request)
    elapsed_ms = (time.perf_counter() - started) * 1000

    level_name = get_log_level()
    summary = (request.method, request.url.path, response.status_code)
    status = summary[2]
    timed_format = "%s %s -> %s (%.0fms)"

    if level_name == "error":
        if status >= 400:
            _access.error("%s %s -> %s", *summary)
    elif level_name == "warning":
        if status >= 500:
            _access.error(timed_format, *summary, elapsed_ms)
        elif status >= 400:
            _access.warning(timed_format, *summary, elapsed_ms)
    elif level_name in ("debug", "info"):
        _access.info(timed_format, *summary, elapsed_ms)

    return response
64
+
65
+
66
+ from worker.celery_app import BROKER_URL
67
+
68
+ _broker_tail = BROKER_URL.split("@")[-1] if "@" in BROKER_URL else BROKER_URL
69
+ if get_log_level() in ("debug", "info"):
70
+ logger.info("App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail)
71
+ else:
72
+ logger.warning(
73
+ "App starting LOG_LEVEL=%s | Redis: %s", get_log_level(), _broker_tail
74
+ )
75
+
76
+ app.add_middleware(
77
+ CORSMiddleware,
78
+ allow_origins=["*"],
79
+ allow_credentials=True,
80
+ allow_methods=["*"],
81
+ allow_headers=["*"],
82
+ )
83
+
84
+ app.include_router(auth.router)
85
+ app.include_router(sessions.router)
86
+ app.include_router(solve.router)
87
+
88
+ register_websocket_routes(app)
89
+
90
+
91
def get_ocr_agent() -> OCRAgent:
    """Return the orchestrator's OCR agent so the model is not loaded a second time."""
    orchestrator = get_orchestrator()
    return orchestrator.ocr_agent
94
+
95
+
96
+ supabase_client = get_supabase()
97
+
98
+
99
@app.get("/")
def read_root():
    """Root endpoint: return a fixed message stating that the API is running."""
    payload = {"message": "Visual Math Solver API v4.0 is running"}
    return payload
102
+
103
+
104
@app.post("/api/v1/ocr")
async def upload_ocr(file: UploadFile = File(...)):
    """Legacy OCR endpoint (retained for now as it's stateless).

    Saves the upload to a uniquely named temporary ``.png`` in the working
    directory, runs the shared OCR agent on it and returns ``{"text": ...}``.
    The temporary file is removed even when reading, writing or OCR fails.
    """
    temp_path = f"temp_{uuid.uuid4()}.png"
    try:
        # Reading and writing happen inside the try so a failure part-way
        # through cannot leave a stray temp file behind.
        payload = await file.read()
        with open(temp_path, "wb") as buffer:
            buffer.write(payload)

        text = await get_ocr_agent().process_image(temp_path)
        return {"text": text}
    finally:
        if os.path.exists(temp_path):
            os.remove(temp_path)
117
+
118
+
119
@app.get("/api/v1/solve/{job_id}")
async def get_job_status(job_id: str):
    """Return the ``jobs`` row whose id equals ``job_id`` (polling fallback when WS fails).

    Raises:
        HTTPException(404): no row matched.
    """
    query = supabase_client.table("jobs").select("*").eq("id", job_id)
    rows = query.execute().data
    if rows:
        return rows[0]
    raise HTTPException(status_code=404, detail="Job not found")
app/models/schemas.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, EmailStr, field_validator

from app.url_utils import sanitize_url
7
+
8
# --- Auth Schemas ---
class UserProfile(BaseModel):
    # Profile data: id, optional display name and avatar URL, creation time.
    id: uuid.UUID
    display_name: Optional[str] = None
    avatar_url: Optional[str] = None
    created_at: datetime


class User(BaseModel):
    # Minimal user identity: id plus a validated e-mail address.
    id: uuid.UUID
    email: EmailStr
18
+
19
# --- Session Schemas ---
class SessionBase(BaseModel):
    # Field shared by every session payload; the default title is the
    # placeholder given to new sessions.
    title: str = "Bài toán mới"


class SessionCreate(SessionBase):
    # Creation payload; adds nothing to SessionBase.
    pass


class Session(SessionBase):
    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # from_attributes lets the model be built from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)

    id: uuid.UUID
    user_id: uuid.UUID
    created_at: datetime
    updated_at: datetime
34
+
35
# --- Message Schemas ---
class MessageBase(BaseModel):
    # Fields shared by every message payload.
    role: str
    type: str = "text"
    content: str
    metadata: Dict[str, Any] = {}


class MessageCreate(MessageBase):
    # Creation payload: the base fields plus the owning session.
    session_id: uuid.UUID


class Message(MessageBase):
    # Pydantic v2 configuration; replaces the deprecated inner ``class Config``.
    # from_attributes lets the model be built from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)

    id: uuid.UUID
    session_id: uuid.UUID
    created_at: datetime
52
+
53
# --- Solve Job Schemas ---
class SolveRequest(BaseModel):
    # Request body for the solve endpoint.
    text: str
    image_url: Optional[str] = None
    request_video: bool = False

    @field_validator("image_url", mode="before")
    @classmethod
    def _clean_image_url(cls, v):
        # Runs before type validation; a missing URL stays None, anything
        # else is passed through sanitize_url.
        if v is None:
            return None
        return sanitize_url(v)


class SolveResponse(BaseModel):
    # Job identifier and its status string.
    job_id: str
    status: str
app/routers/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from . import auth, sessions, solve
app/routers/auth.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException
2
+ from app.dependencies import get_current_user_id
3
+ from app.supabase_client import get_supabase
4
+ from app.models.schemas import UserProfile
5
+ import uuid
6
+
7
+ router = APIRouter(prefix="/api/v1/auth", tags=["Auth"])
8
+
9
@router.get("/me")
async def get_me(user_id=Depends(get_current_user_id)):
    """Retrieve the current user's profile."""
    query = get_supabase().table("profiles").select("*").eq("id", user_id)
    rows = query.execute().data
    if rows:
        return rows[0]
    raise HTTPException(status_code=404, detail="Profile not found.")
17
+
18
@router.patch("/me")
async def update_me(data: dict, user_id=Depends(get_current_user_id)):
    """Update the current user's profile.

    Only ``display_name`` and ``avatar_url`` are client-editable. Any other key
    in *data* is ignored, so a caller cannot overwrite ``id`` or timestamps.

    Raises:
        HTTPException: 400 when *data* contains no editable field, 404 when
            no profile row matches the current user.
    """
    editable = ("display_name", "avatar_url")
    changes = {key: data[key] for key in editable if key in data}
    if not changes:
        raise HTTPException(
            status_code=400, detail="No updatable profile fields provided."
        )

    supabase = get_supabase()
    res = supabase.table("profiles").update(changes).eq("id", user_id).execute()
    # An update that matched nothing returns no rows; report it instead of
    # failing with an IndexError.
    if not res.data:
        raise HTTPException(status_code=404, detail="Profile not found.")
    return res.data[0]
app/routers/sessions.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import List
4
+
5
+ from fastapi import APIRouter, Depends, HTTPException
6
+
7
+ from app.dependencies import get_current_user_id
8
+ from app.logutil import log_step
9
+ from app.session_cache import (
10
+ get_sessions_list_cached,
11
+ invalidate_for_user,
12
+ invalidate_session_owner,
13
+ session_owned_by_user,
14
+ )
15
+ from app.supabase_client import get_supabase
16
+
17
+ router = APIRouter(prefix="/api/v1/sessions", tags=["Sessions"])
18
+
19
+
20
@router.get("", response_model=List[dict])
async def list_sessions(user_id=Depends(get_current_user_id)):
    """List the user's chat sessions, most recently updated first."""
    client = get_supabase()

    def load_sessions() -> list:
        # Passed to the cache helper as the loader for this user's list.
        query = client.table("sessions").select("*").eq("user_id", user_id)
        result = query.order("updated_at", desc=True).execute()
        log_step("db_select", table="sessions", op="list", user_id=str(user_id))
        return result.data

    return get_sessions_list_cached(str(user_id), load_sessions)
37
+
38
+
39
@router.post("", response_model=dict)
async def create_session(user_id=Depends(get_current_user_id)):
    """Create a new chat session and return the inserted row."""
    new_row = {"user_id": user_id, "title": "Bài toán mới"}
    inserted = get_supabase().table("sessions").insert(new_row).execute()
    log_step("db_insert", table="sessions", op="create")
    # The cached session list for this user is now stale.
    invalidate_for_user(str(user_id))
    return inserted.data[0]
49
+
50
+
51
@router.get("/{session_id}/messages", response_model=List[dict])
async def get_session_messages(session_id: str, user_id=Depends(get_current_user_id)):
    """Get the full chat history of a session, oldest message first."""
    db = get_supabase()

    def is_owner() -> bool:
        # A row comes back only when the session exists and belongs to the user.
        lookup = db.table("sessions").select("id").eq("id", session_id)
        match = lookup.eq("user_id", user_id).execute()
        log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
        return bool(match.data)

    owned = session_owned_by_user(session_id, str(user_id), is_owner)
    if not owned:
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )

    history_query = db.table("messages").select("*").eq("session_id", session_id)
    history = history_query.order("created_at", desc=False).execute()
    log_step("db_select", table="messages", op="list", session_id=session_id)
    return history.data
81
+
82
+
83
@router.delete("/{session_id}")
async def delete_session(session_id: str, user_id=Depends(get_current_user_id)):
    """Delete a chat session together with its jobs and messages."""
    db = get_supabase()
    owner_id = str(user_id)

    def is_owner() -> bool:
        lookup = db.table("sessions").select("id").eq("id", session_id)
        return bool(lookup.eq("user_id", user_id).execute().data)

    if not session_owned_by_user(session_id, owner_id, is_owner):
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )

    # jobs.session_id FK must be cleared before sessions row
    jobs_delete = db.table("jobs").delete().eq("session_id", session_id)
    jobs_delete.eq("user_id", user_id).execute()
    log_step("db_delete", table="jobs", op="by_session", session_id=session_id)

    db.table("messages").delete().eq("session_id", session_id).execute()
    log_step("db_delete", table="messages", op="by_session", session_id=session_id)

    session_delete = db.table("sessions").delete().eq("id", session_id)
    session_delete.eq("user_id", user_id).execute()
    log_step("db_delete", table="sessions", session_id=session_id)

    # Drop both cache entries that could still reference the deleted session.
    invalidate_for_user(owner_id)
    invalidate_session_owner(session_id, owner_id)
    return {"status": "ok", "deleted_id": session_id}
119
+
120
+
121
@router.patch("/{session_id}/title")
async def update_session_title(title: str, session_id: str, user_id=Depends(get_current_user_id)):
    """Rename a chat session owned by the current user.

    Returns:
        The updated session row.

    Raises:
        HTTPException: 404 when no session with this id belongs to the user.
    """
    supabase = get_supabase()
    res = (
        supabase.table("sessions")
        .update({"title": title})
        .eq("id", session_id)
        .eq("user_id", user_id)
        .execute()
    )
    log_step("db_update", table="sessions", op="title", session_id=session_id)
    # The filter matches nothing for a missing or foreign session; answer 404
    # instead of failing with an IndexError on res.data[0].
    if not res.data:
        raise HTTPException(status_code=404, detail="Session not found.")
    invalidate_for_user(str(user_id))
    return res.data[0]
135
+
136
+
137
@router.get("/{session_id}/assets", response_model=List[dict])
async def get_session_assets(session_id: str, user_id=Depends(get_current_user_id)):
    """Get the versioned assets of a session, highest version first."""
    db = get_supabase()

    def is_owner() -> bool:
        lookup = db.table("sessions").select("id").eq("id", session_id)
        return bool(lookup.eq("user_id", user_id).execute().data)

    owned = session_owned_by_user(session_id, str(user_id), is_owner)
    if not owned:
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )

    assets_query = db.table("session_assets").select("*").eq("session_id", session_id)
    assets = assets_query.order("version", desc=True).execute()
    log_step("db_select", table="session_assets", op="list", session_id=session_id)
    return assets.data
app/routers/solve.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import uuid
5
+
6
+ from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
7
+
8
+ from agents.orchestrator import Orchestrator
9
+ from app.dependencies import get_current_user_id
10
+ from app.errors import format_error_for_user
11
+ from app.logutil import log_pipeline_failure, log_pipeline_success, log_step
12
+ from app.models.schemas import SolveRequest, SolveResponse
13
+ from app.session_cache import invalidate_for_user, session_owned_by_user
14
+ from app.supabase_client import get_supabase
15
+
16
+ logger = logging.getLogger(__name__)
17
+ router = APIRouter(prefix="/api/v1/sessions", tags=["Solve"])
18
+
19
# Eager init: all agents and models load at import time (also run in Docker build via scripts/prewarm_models.py).
ORCHESTRATOR = Orchestrator()


def get_orchestrator() -> Orchestrator:
    """Return the module-level Orchestrator instance created at import time."""
    return ORCHESTRATOR
25
+
26
+
27
@router.post("/{session_id}/solve", response_model=SolveResponse)
async def solve_problem(
    session_id: str,
    request: SolveRequest,
    background_tasks: BackgroundTasks,
    user_id=Depends(get_current_user_id),
):
    """
    Submit a geometry problem in a session.
    Saves the question to the history and starts the solving process.
    """
    supabase = get_supabase()
    uid = str(user_id)

    def owns() -> bool:
        # A row comes back only when the session exists and belongs to the user.
        res = (
            supabase.table("sessions")
            .select("id")
            .eq("id", session_id)
            .eq("user_id", user_id)
            .execute()
        )
        log_step("db_select", table="sessions", op="owner_check", session_id=session_id)
        return bool(res.data)

    if not session_owned_by_user(session_id, uid, owns):
        log_pipeline_failure("solve_request", error="forbidden", session_id=session_id)
        raise HTTPException(
            status_code=403, detail="Forbidden: You do not own this session."
        )

    # Limit each session to 5 user queries (skipped when ALLOW_TEST_BYPASS=true).
    msg_count_res = (
        supabase.table("messages")
        .select("id", count="exact")
        .eq("session_id", session_id)
        .eq("role", "user")
        .execute()
    )
    current_count = msg_count_res.count if msg_count_res.count is not None else 0
    import os
    if current_count >= 5 and os.getenv("ALLOW_TEST_BYPASS") != "true":
        raise HTTPException(
            status_code=400,
            detail="Bạn đã đạt giới hạn 5 câu hỏi cho phiên này. (Session limit reached: 5/5)"
        )

    # Record the user's question in the session history.
    supabase.table("messages").insert(
        {
            "session_id": session_id,
            "role": "user",
            "type": "text",
            "content": request.text,
            "metadata": {"image_url": request.image_url} if request.image_url else {},
        }
    ).execute()
    log_step("db_insert", table="messages", op="user_message", session_id=session_id)

    # Create the job row first so its id can be returned to the client.
    job_id = str(uuid.uuid4())
    supabase.table("jobs").insert(
        {
            "id": job_id,
            "user_id": user_id,
            "session_id": session_id,
            "status": "processing",
            "input_text": request.text,
        }
    ).execute()
    log_step("db_insert", table="jobs", job_id=job_id)

    # The solve itself is handed to process_session_job as a background task.
    background_tasks.add_task(process_session_job, job_id, session_id, request, user_id)

    # A session still carrying the placeholder title is renamed after the
    # question text (truncated to 50 characters).
    title_check = supabase.table("sessions").select("title").eq("id", session_id).execute()
    if title_check.data and title_check.data[0]["title"] == "Bài toán mới":
        new_title = request.text[:50] + ("..." if len(request.text) > 50 else "")
        supabase.table("sessions").update({"title": new_title}).eq("id", session_id).execute()
        log_step("db_update", table="sessions", op="title_from_first_message")
        invalidate_for_user(uid)

    log_pipeline_success("solve_accepted", job_id=job_id, session_id=session_id)
    return SolveResponse(job_id=job_id, status="processing")
108
+
109
+
110
async def process_session_job(
    job_id: str, session_id: str, request: SolveRequest, user_id: str
):
    """Background solve pipeline; updates both the jobs and messages tables.

    Loads the session history, runs the orchestrator, stores the result on the
    job row, appends an assistant message (unless the status is
    ``rendering_queued``) and notifies WebSocket listeners.

    Any exception is logged and converted into an error status. Persisting
    that error is best-effort, so listeners are still notified when the
    database writes themselves fail.
    """
    # Imported here rather than at module level.
    # NOTE(review): presumably to avoid a circular import; confirm.
    from app.websocket_manager import notify_status

    async def status_update(status: str):
        await notify_status(job_id, {"status": status})

    supabase = get_supabase()
    try:
        # Fetch full history for the session
        history_res = (
            supabase.table("messages")
            .select("*")
            .eq("session_id", session_id)
            .order("created_at", desc=False)
            .execute()
        )
        history = history_res.data if history_res.data else []

        result = await get_orchestrator().run(
            request.text,
            request.image_url,
            job_id=job_id,
            session_id=session_id,
            status_callback=status_update,
            request_video=request.request_video,
            history=history,
        )

        # A result carrying an "error" key always counts as a failure.
        status = result.get("status", "error") if "error" not in result else "error"

        supabase.table("jobs").update({"status": status, "result": result}).eq(
            "id", job_id
        ).execute()
        log_step("db_update", table="jobs", job_id=job_id, status=status)

        # No assistant message is written while the job is queued for rendering.
        if status != "rendering_queued":
            supabase.table("messages").insert(
                {
                    "session_id": session_id,
                    "role": "assistant",
                    "type": "analysis" if "error" not in result else "error",
                    "content": (
                        result.get("semantic_analysis", "Đã có lỗi xảy ra.")
                        if "error" not in result
                        else result["error"]
                    ),
                    "metadata": {
                        "job_id": job_id,
                        "coordinates": result.get("coordinates"),
                        "geometry_dsl": result.get("geometry_dsl"),
                        "polygon_order": result.get("polygon_order", []),
                        "drawing_phases": result.get("drawing_phases", []),
                        "circles": result.get("circles", []),
                        "lines": result.get("lines", []),
                        "rays": result.get("rays", []),
                        "video_url": result.get("video_url"),
                        "solution": result.get("solution"),
                    },
                }
            ).execute()
            log_step("db_insert", table="messages", op="assistant", job_id=job_id)

        await notify_status(job_id, {"status": status, "result": result})

        if "error" in result:
            log_pipeline_failure(
                "solve_job", error=result.get("error"), job_id=job_id, session_id=session_id
            )
        else:
            log_pipeline_success(
                "solve_job", job_id=job_id, session_id=session_id, status=status
            )

    except Exception as e:
        logger.exception("Error processing session job %s", job_id)
        safe = format_error_for_user(e)

        # Best-effort persistence: the database may be the very thing that
        # failed, and the client must still be told about the error below.
        try:
            supabase.table("jobs").update(
                {"status": "error", "result": {"error": safe}}
            ).eq("id", job_id).execute()

            supabase.table("messages").insert(
                {
                    "session_id": session_id,
                    "role": "assistant",
                    "type": "error",
                    "content": f"Lỗi hệ thống: {safe}",
                    "metadata": {"job_id": job_id},
                }
            ).execute()
        except Exception:
            logger.exception("Could not persist error state for job %s", job_id)

        await notify_status(job_id, {"status": "error", "message": safe})
        log_pipeline_failure("solve_job", error=safe, job_id=job_id, session_id=session_id)
app/runtime_env.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Default process env vars (Paddle/OpenMP). Call as early as possible after load_dotenv."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+
8
def apply_runtime_env_defaults() -> None:
    """Pin the OpenMP / MKL / OpenBLAS thread counts to 1 for this process."""
    # Paddle respects OMP_NUM_THREADS at import; setdefault loses if platform
    # already set 2+, hence the unconditional assignment.
    for thread_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS"):
        os.environ[thread_var] = "1"
app/session_cache.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """TTL in-memory cache để giảm truy vấn Supabase lặp lại (list session, quyền sở hữu session)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Callable
6
+
7
+ from cachetools import TTLCache
8
+
9
+ from app.logutil import log_step
10
+
11
# Both caches expire entries after 45 seconds.
_session_list: TTLCache[str, list[Any]] = TTLCache(maxsize=512, ttl=45)
_session_owner: TTLCache[tuple[str, str], bool] = TTLCache(maxsize=4096, ttl=45)


def invalidate_for_user(user_id: str) -> None:
    """Drop the cached session list of a user (after create / delete / rename / solve retitle)."""
    _session_list.pop(user_id, None)
    log_step("cache_invalidate", target="session_list", user_id=user_id)


def invalidate_session_owner(session_id: str, user_id: str) -> None:
    """Drop the cached ownership answer for one (session, user) pair."""
    _session_owner.pop((session_id, user_id), None)
    log_step("cache_invalidate", target="session_owner", session_id=session_id, user_id=user_id)


def get_sessions_list_cached(user_id: str, fetch: Callable[[], list[Any]]) -> list[Any]:
    """Return the user's session list from the cache, loading it on a miss.

    Args:
        user_id: Cache key.
        fetch: Zero-argument loader called only on a miss; its result is cached.
    """
    # One EAFP lookup: with a TTL cache, ``key in cache`` followed by
    # ``cache[key]`` can raise KeyError if the entry expires in between.
    try:
        cached = _session_list[user_id]
    except KeyError:
        log_step("cache_miss", kind="session_list", user_id=user_id)
        data = fetch()
        _session_list[user_id] = data
        return data
    log_step("cache_hit", kind="session_list", user_id=user_id)
    return cached


def session_owned_by_user(
    session_id: str,
    user_id: str,
    fetch: Callable[[], bool],
) -> bool:
    """Return whether *user_id* owns *session_id*, caching the answer.

    Args:
        session_id: Session being checked.
        user_id: User being checked.
        fetch: Zero-argument ownership check called only on a miss; both True
            and False answers are cached.
    """
    key = (session_id, user_id)
    # Same single-lookup pattern as get_sessions_list_cached.
    try:
        cached = _session_owner[key]
    except KeyError:
        log_step("cache_miss", kind="session_owner", session_id=session_id)
        ok = fetch()
        _session_owner[key] = ok
        return ok
    log_step("cache_hit", kind="session_owner", session_id=session_id)
    return cached
app/supabase_client.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from supabase import Client, ClientOptions, create_client
3
+ from supabase_auth import SyncMemoryStorage
4
+ from dotenv import load_dotenv
5
+
6
+ load_dotenv()
7
+
8
+ from app.url_utils import sanitize_env
9
+
10
+
11
def get_supabase() -> Client:
    """Build a Supabase client from SUPABASE_URL and the service-role key.

    The key is read from SUPABASE_SERVICE_ROLE_KEY, falling back to SUPABASE_KEY.

    Raises:
        RuntimeError: when the URL or the key is missing or empty.
    """
    url = sanitize_env(os.getenv("SUPABASE_URL"))
    raw_key = os.getenv("SUPABASE_SERVICE_ROLE_KEY") or os.getenv("SUPABASE_KEY")
    key = sanitize_env(raw_key)
    if url and key:
        return create_client(url, key)
    raise RuntimeError(
        "SUPABASE_URL and SUPABASE_SERVICE_ROLE_KEY (or SUPABASE_KEY) must be set"
    )
20
+
21
+
22
def get_supabase_for_user_jwt(access_token: str) -> Client:
    """
    Build a Supabase client that sends the user's JWT as its Authorization header.

    Uses SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY), not the service
    role key.

    Raises:
        RuntimeError: when the URL or the anon key is missing or empty.
    """
    url = sanitize_env(os.getenv("SUPABASE_URL"))
    raw_anon = os.getenv("SUPABASE_ANON_KEY") or os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY")
    anon = sanitize_env(raw_anon)
    if not (url and anon):
        raise RuntimeError(
            "SUPABASE_URL and SUPABASE_ANON_KEY (or NEXT_PUBLIC_SUPABASE_ANON_KEY) must be set "
            "for user-scoped Supabase access"
        )
    # Start from the library's default headers and override only Authorization.
    headers = dict(ClientOptions(storage=SyncMemoryStorage()).headers)
    headers["Authorization"] = f"Bearer {access_token}"
    opts = ClientOptions(storage=SyncMemoryStorage(), headers=headers)
    return create_client(url, anon, opts)
app/url_utils.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Normalize URLs / env strings (HF secrets and copy-paste often include trailing newlines)."""
2
+
3
+
4
+ def sanitize_url(value: str | None) -> str | None:
5
+ if value is None:
6
+ return None
7
+ s = value.strip().replace("\r", "").replace("\n", "").replace("\t", "")
8
+ return s or None
9
+
10
+
11
+ def sanitize_env(value: str | None) -> str | None:
12
+ """Strip whitespace and line breaks from environment-backed strings."""
13
+ return sanitize_url(value)
14
+
15
+
16
+ # OpenAI SDK (>=1.x) requires a non-empty api_key at client construction (Docker build / prewarm has no secrets).
17
+ _OPENAI_API_KEY_BUILD_PLACEHOLDER = "build-placeholder-openrouter-not-for-production"
18
+
19
+
20
+ def openai_compatible_api_key(raw: str | None) -> str:
21
+ """Return sanitized API key, or a placeholder so AsyncOpenAI() can be constructed without env at build time."""
22
+ k = sanitize_env(raw)
23
+ return k if k else _OPENAI_API_KEY_BUILD_PLACEHOLDER
app/websocket_manager.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """WebSocket connection registry and job status notifications (avoid circular imports with main)."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from typing import Dict, List
7
+
8
+ from fastapi import WebSocket, WebSocketDisconnect
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+ active_connections: Dict[str, List[WebSocket]] = {}
13
+
14
+
15
async def notify_status(job_id: str, data: dict) -> None:
    """Send *data* as JSON to every WebSocket registered for *job_id*.

    A failed send is logged and does not stop delivery to the other sockets.
    """
    listeners = active_connections.get(job_id)
    if listeners is None:
        return
    # Iterate over a snapshot so the registry may change while awaiting sends.
    for ws in tuple(listeners):
        try:
            await ws.send_json(data)
        except Exception as e:
            logger.error("WS error sending to %s: %s", job_id, e)
23
+
24
+
25
def register_websocket_routes(app) -> None:
    """Attach websocket endpoint to the FastAPI app."""

    @app.websocket("/ws/{job_id}")
    async def websocket_endpoint(websocket: WebSocket, job_id: str) -> None:
        # Register the socket under its job id, then wait until the receive
        # loop ends for any reason.
        await websocket.accept()
        active_connections.setdefault(job_id, []).append(websocket)
        try:
            # Incoming text is read and discarded; the loop exists to notice
            # when the peer goes away.
            while True:
                await websocket.receive_text()
        except WebSocketDisconnect:
            pass
        finally:
            # Unregister on every exit path, not only a clean disconnect, so a
            # failed receive cannot leave a dead socket in the registry.
            sockets = active_connections.get(job_id, [])
            if websocket in sockets:
                sockets.remove(websocket)
            if not sockets:
                active_connections.pop(job_id, None)
clean_ports.sh ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Script to kill all project-related processes for a clean restart

echo "🧹 Cleaning up project processes..."

# Kill whatever is listening on the project ports: 8000 (Backend),
# 3000 (Frontend) and 11020.
PORTS="8000 3000 11020"
for PORT in $PORTS; do
    PIDS=$(lsof -ti ":$PORT")
    if [ -n "$PIDS" ]; then
        echo "Killing processes on port $PORT: $PIDS"
        # $PIDS is left unquoted on purpose: it may hold several PIDs.
        kill -9 $PIDS 2>/dev/null
    fi
done

# Kill by process name
echo "Killing any remaining Celery, Uvicorn, or Manim processes..."
pkill -9 -f "celery" 2>/dev/null
pkill -9 -f "uvicorn" 2>/dev/null
pkill -9 -f "manim" 2>/dev/null

echo "✅ Done. You can now restart your Backend, Worker, and Frontend."
migrations/v4_migration.sql ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ -- ============================================================
2
+ -- MATHSOLVER v4.0 - Migration Script (Multi-Session & History)
3
+ -- ============================================================
4
+
5
+ -- 1. Profiles Table (Extends Supabase Auth)
6
+ CREATE TABLE IF NOT EXISTS public.profiles (
7
+ id UUID PRIMARY KEY REFERENCES auth.users(id) ON DELETE CASCADE,
8
+ display_name TEXT,
9
+ avatar_url TEXT,
10
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
11
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
12
+ );
13
+
14
-- Function to handle new user signup and auto-create profile.
-- It runs as SECURITY DEFINER, so search_path is pinned to empty; every object
-- the body touches is schema-qualified, so name resolution cannot be hijacked.
CREATE OR REPLACE FUNCTION public.handle_new_user()
RETURNS TRIGGER AS $$
BEGIN
    INSERT INTO public.profiles (id, display_name, avatar_url)
    VALUES (
        NEW.id,
        -- Fall back to the e-mail address when no full name was supplied.
        COALESCE(NEW.raw_user_meta_data->>'full_name', NEW.email),
        NEW.raw_user_meta_data->>'avatar_url'
    );
    RETURN NEW;
END;
$$ LANGUAGE plpgsql SECURITY DEFINER SET search_path = '';
27
+
28
+ -- Trigger for profile creation
29
+ DROP TRIGGER IF EXISTS on_auth_user_created ON auth.users;
30
+ CREATE TRIGGER on_auth_user_created
31
+ AFTER INSERT ON auth.users
32
+ FOR EACH ROW EXECUTE FUNCTION public.handle_new_user();
33
+
34
+ -- 2. Sessions Table
35
+ CREATE TABLE IF NOT EXISTS public.sessions (
36
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
37
+ user_id UUID NOT NULL REFERENCES auth.users(id) ON DELETE CASCADE,
38
+ title TEXT DEFAULT 'Bài toán mới',
39
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
40
+ updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
41
+ );
42
+
43
+ -- Index for sessions
44
+ CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON public.sessions(user_id);
45
+ CREATE INDEX IF NOT EXISTS idx_sessions_updated_at ON public.sessions(updated_at DESC);
46
+
47
+ -- 3. Messages Table
48
+ CREATE TABLE IF NOT EXISTS public.messages (
49
+ id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
50
+ session_id UUID NOT NULL REFERENCES public.sessions(id) ON DELETE CASCADE,
51
+ role TEXT NOT NULL CHECK (role IN ('user', 'assistant', 'system')),
52
+ type TEXT NOT NULL DEFAULT 'text',
53
+ content TEXT NOT NULL,
54
+ metadata JSONB DEFAULT '{}'::jsonb,
55
+ created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
56
+ );
57
+
58
+ -- Index for messages
59
+ CREATE INDEX IF NOT EXISTS idx_messages_session_id ON public.messages(session_id);
60
+ CREATE INDEX IF NOT EXISTS idx_messages_created_at ON public.messages(session_id, created_at);
61
+
62
+ -- 4. Update Jobs Table
63
+ ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS user_id UUID REFERENCES auth.users(id);
64
+ ALTER TABLE public.jobs ADD COLUMN IF NOT EXISTS session_id UUID REFERENCES public.sessions(id);
65
+
66
-- 5. Row Level Security (RLS)
ALTER TABLE public.profiles ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.sessions ENABLE ROW LEVEL SECURITY;
ALTER TABLE public.messages ENABLE ROW LEVEL SECURITY;
-- The "Users view own jobs" policy has no effect unless RLS is enabled on
-- public.jobs as well; without it, GRANT ALL exposes every job row.
ALTER TABLE public.jobs ENABLE ROW LEVEL SECURITY;
70
+
71
+ -- Policies for public.profiles
72
+ DROP POLICY IF EXISTS "Users view own profile" ON public.profiles;
73
+ CREATE POLICY "Users view own profile" ON public.profiles FOR SELECT USING (auth.uid() = id);
74
+ DROP POLICY IF EXISTS "Users update own profile" ON public.profiles;
75
+ CREATE POLICY "Users update own profile" ON public.profiles FOR UPDATE USING (auth.uid() = id);
76
+
77
+ -- Policies for public.sessions
78
+ DROP POLICY IF EXISTS "Users manage own sessions" ON public.sessions;
79
+ CREATE POLICY "Users manage own sessions" ON public.sessions FOR ALL USING (auth.uid() = user_id);
80
+
81
+ -- Policies for public.messages
82
+ DROP POLICY IF EXISTS "Users view own messages" ON public.messages;
83
+ CREATE POLICY "Users view own messages" ON public.messages FOR ALL USING (
84
+ session_id IN (SELECT id FROM public.sessions WHERE user_id = auth.uid())
85
+ );
86
+
87
+ -- Policies for public.jobs
88
+ DROP POLICY IF EXISTS "Users view own jobs" ON public.jobs;
89
+ CREATE POLICY "Users view own jobs" ON public.jobs FOR ALL USING (auth.uid() = user_id OR user_id IS NULL);
90
+
91
+ -- Grant permissions to public/authenticated
92
+ GRANT ALL ON public.profiles TO authenticated;
93
+ GRANT ALL ON public.sessions TO authenticated;
94
+ GRANT ALL ON public.messages TO authenticated;
95
+ GRANT ALL ON public.jobs TO authenticated;
requirements.txt ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Target: Python 3.11 (see Dockerfile). Used by: FastAPI API, Celery worker, Manim render, OCR/vision stack.
2
+ # Install: pip install -r requirements.txt
3
+
4
+ # --- HTTP API ---
5
+ cachetools>=5.3
6
+ fastapi>=0.115,<1
7
+ uvicorn[standard]>=0.30
8
+ python-multipart>=0.0.9
9
+ python-dotenv>=1.0
10
+ pydantic[email]>=2.4
11
+ email-validator>=2
12
+
13
+ # --- Auth / data / queue ---
14
+ openai>=1.40
15
+ supabase>=2.0
16
+ celery>=5.3
17
+ redis>=5
18
+ httpx>=0.27
19
+ websockets>=12
20
+
21
+ # --- Math & symbolic solver ---
22
+ sympy>=1.12
23
+ numpy>=1.26,<2
24
+ scipy>=1.11
25
+ opencv-python-headless>=4.8,<4.10
26
+
27
+ # --- Video (GeometryScene via CLI) ---
28
+ manim>=0.18,<0.20
29
+
30
+ # --- OCR & vision (orchestrator / legacy /ocr) ---
31
+ pix2tex>=0.1.4
32
+ paddleocr==2.7.3
33
+ paddlepaddle==2.6.2
34
+ ultralytics==8.2.2
run_api_test.sh ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# API end-to-end test runner: starts the backend, prepares test data, runs
# pytest against it and exits with pytest's exit code.

LOG_FILE="api_test_results.log"
SERVER_PID=""

# Stop the backend on every exit path (normal exit, early failure, Ctrl-C),
# so an interrupted run never leaves a server holding port 8000.
cleanup() {
    if [ -n "$SERVER_PID" ]; then
        kill "$SERVER_PID" 2>/dev/null
    fi
}
trap cleanup EXIT

echo "=== Starting API E2E Test Suite ($(date)) ===" > "$LOG_FILE"

# 1. Start BE Server in background
echo "[INFO] Starting Backend Server..." | tee -a "$LOG_FILE"
export ALLOW_TEST_BYPASS=true
export LOG_LEVEL=info
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
SERVER_PID=$!

# 2. Wait for server to be ready
echo "[INFO] Waiting for server (PID: $SERVER_PID) on port 8000..." | tee -a "$LOG_FILE"
MAX_RETRIES=15
READY=0
for i in $(seq 1 $MAX_RETRIES); do
    if curl -s http://localhost:8000/ > /dev/null; then
        READY=1
        break
    fi
    sleep 2
done

if [ "$READY" -eq 0 ]; then
    echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a "$LOG_FILE"
    exit 1
fi
echo "[INFO] Server is READY." | tee -a "$LOG_FILE"

# 3. Prepare Test Data
echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
echo "$PREP_OUTPUT" >> "$LOG_FILE"

# The prepare script reports the ids on lines of the form RESULT:<NAME>=<value>.
export TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
export TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)

if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then
    echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
    exit 1
fi

echo "[INFO] Test Data: User=$TEST_USER_ID, Session=$TEST_SESSION_ID" | tee -a "$LOG_FILE"

# 4. Run Pytest
echo "[INFO] Running API E2E Tests..." | tee -a "$LOG_FILE"
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_real_e2e.py -s >> "$LOG_FILE" 2>&1
TEST_EXIT_CODE=$?

# 5. Cleanup
echo "[INFO] Shutting down Server..." | tee -a "$LOG_FILE"
kill "$SERVER_PID" 2>/dev/null

echo "==========================================" | tee -a "$LOG_FILE"
if [ "$TEST_EXIT_CODE" -eq 0 ]; then
    echo "FINAL RESULT: ✅ ALL API TESTS PASSED" | tee -a "$LOG_FILE"
else
    echo "FINAL RESULT: ❌ SOME API TESTS FAILED (Code: $TEST_EXIT_CODE)" | tee -a "$LOG_FILE"
fi
echo "==========================================" | tee -a "$LOG_FILE"

exit $TEST_EXIT_CODE
run_full_api_test.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
#
# Full API suite runner.
#
# Boots the backend with eager Celery and mocked video rendering, seeds a
# fresh test session, runs tests/test_api_full_suite.py, then renders a
# Markdown report. The script exits with pytest's exit code; it exits 1 early
# when the server never becomes reachable or test data cannot be prepared.

# Configuration and Cleanup
LOG_FILE="full_api_suite.log"
REPORT_FILE="full_api_test_report.md"
JSON_RESULTS="temp_suite_results.json"
SERVER_PID=""

echo "=== Starting Full API Suite Test ($(date)) ===" > "$LOG_FILE"

# Cleanup on exit. The trap fires on every exit path (including the early
# `exit 1` branches below), so the background server is never left running.
cleanup() {
    echo "[INFO] Cleaning up processes..."
    if [ -n "$SERVER_PID" ]; then
        kill "$SERVER_PID" 2>/dev/null
    fi
    sleep 1
}
trap cleanup EXIT

# 1. Start Server in EAGER MODE + MOCK VIDEO (no Redis/Worker needed)
echo "[INFO] Starting Backend Server (EAGER + MOCK_VIDEO)..." | tee -a "$LOG_FILE"
export ALLOW_TEST_BYPASS=true
export LOG_LEVEL=info
export CELERY_TASK_ALWAYS_EAGER=true
export CELERY_RESULT_BACKEND=rpc://
export MOCK_VIDEO=true
PYTHONPATH=. venv/bin/python -m uvicorn app.main:app --port 8000 > server_debug.log 2>&1 &
SERVER_PID=$!

# 2. Wait for server (up to 20 probes, 2 seconds apart)
echo "[INFO] Waiting for server (PID: $SERVER_PID)..." | tee -a "$LOG_FILE"
READY=0
for i in {1..20}; do
    if curl -s http://localhost:8000/ > /dev/null; then
        READY=1
        break
    fi
    sleep 2
done

# Without this guard the suite would run against a dead server and report
# misleading connection failures.
if [ "$READY" -eq 0 ]; then
    echo "[ERROR] Server failed to start in time. Check server_debug.log" | tee -a "$LOG_FILE"
    exit 1
fi
echo "[INFO] Server is READY." | tee -a "$LOG_FILE"

# 3. Prepare Test Data
echo "[INFO] Preparing fresh test data..." | tee -a "$LOG_FILE"
PREP_OUTPUT=$(PYTHONPATH=. venv/bin/python scripts/prepare_api_test.py)
echo "$PREP_OUTPUT" >> "$LOG_FILE"

# The helper prints RESULT:<KEY>=<value> lines; extract the value part.
TEST_USER_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:USER_ID=" | cut -d'=' -f2)
TEST_SESSION_ID=$(echo "$PREP_OUTPUT" | grep "RESULT:SESSION_ID=" | cut -d'=' -f2)
export TEST_USER_ID TEST_SESSION_ID

if [ -z "$TEST_USER_ID" ] || [ -z "$TEST_SESSION_ID" ]; then
    echo "[ERROR] Failed to prepare test data." | tee -a "$LOG_FILE"
    exit 1
fi

# 4. Run Pytest Suite
echo "[INFO] Executing Full API Suite..." | tee -a "$LOG_FILE"
PYTHONPATH=. venv/bin/python -m pytest tests/test_api_full_suite.py -s >> "$LOG_FILE" 2>&1
TEST_EXIT_CODE=$?

# 5. Shut down server (report generation below does not need it)
echo "[INFO] Shutting down processes..." | tee -a "$LOG_FILE"
kill "$SERVER_PID" 2>/dev/null

# 6. Generate Markdown Report
echo "[INFO] Generating Markdown Report..." | tee -a "$LOG_FILE"
PYTHONPATH=. venv/bin/python scripts/generate_report.py "$JSON_RESULTS" "$REPORT_FILE"

echo "==========================================" | tee -a "$LOG_FILE"
echo "DONE. Check $REPORT_FILE for results." | tee -a "$LOG_FILE"
echo "==========================================" | tee -a "$LOG_FILE"

exit $TEST_EXIT_CODE
scripts/backend_test_suite.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ import time
3
+ import json
4
+ import sys
5
+
6
+ BASE_URL = "http://localhost:8000/api/v1"
7
+
8
+ TEST_CASES = [
9
+ {
10
+ "name": "Equilateral Triangle",
11
+ "payload": {"text": "Vẽ tam giác đều cạnh 5.", "request_video": True}
12
+ },
13
+ {
14
+ "name": "Right Triangle (3-4-5)",
15
+ "payload": {"text": "Cho tam giác ABC vuông tại A có AB=3, AC=4. Tính BC.", "request_video": True}
16
+ },
17
+ {
18
+ "name": "Isosceles Triangle",
19
+ "payload": {"text": "Cho tam giác ABC cân tại A có AB=5, BC=6.", "request_video": False}
20
+ },
21
+ {
22
+ "name": "Square",
23
+ "payload": {"text": "Vẽ hình vuông ABCD cạnh 4.", "request_video": True}
24
+ },
25
+ {
26
+ "name": "Invalid Input",
27
+ "payload": {"text": "abcxyz", "request_video": False}
28
+ }
29
+ ]
30
+
31
def run_test(test_case):
    """Submit one solve job and poll it until it reaches a terminal state.

    Args:
        test_case: Dict with a ``name`` (label used in the output) and a
            ``payload`` (JSON body posted to ``{BASE_URL}/solve``).

    Returns:
        ``True`` when the job ends with status ``success``; ``False`` for a
        rejected request, a solver error, an unexpected status, a timeout or
        any exception raised along the way.
    """
    # Statuses that mean "keep polling"; anything else is terminal.
    pending_statuses = {"processing", "solving", "rendering_queued", "rendering"}
    max_attempts = 40
    poll_interval_s = 5
    # Bound every HTTP call so a hung server cannot block the suite forever.
    http_timeout_s = 30

    print(f"\n[TEST] Running: {test_case['name']}...")
    try:
        start_time = time.time()
        # Create job
        response = requests.post(
            f"{BASE_URL}/solve", json=test_case['payload'], timeout=http_timeout_s
        )
        if response.status_code != 200:
            print(f" [FAIL] Initial request failed: {response.text}")
            return False

        job_id = response.json().get("job_id")
        print(f" [INFO] Job ID: {job_id}")
        if not job_id:
            # Polling "/solve/None" would only produce confusing errors.
            print(" [FAIL] Response did not contain a job_id.")
            return False

        # Poll for completion
        for attempt in range(1, max_attempts + 1):
            time.sleep(poll_interval_s)
            res = requests.get(f"{BASE_URL}/solve/{job_id}", timeout=http_timeout_s)
            data = res.json()
            status = data.get("status")
            # "result" may be absent or null while the job is still running.
            result = data.get("result") or {}
            print(f" [INFO] Status: {status} (Attempt {attempt})")

            if status == "success":
                duration = time.time() - start_time
                print(f" [SUCCESS] Completed in {duration:.2f}s")
                if test_case['payload'].get('request_video'):
                    video_url = result.get("video_url")
                    if video_url:
                        print(f" [INFO] Video URL: {video_url}")
                    else:
                        print(" [WARNING] Video requested but no URL found in result.")
                return True

            if status == "error":
                print(f" [FAIL] Solver error: {result.get('error')}")
                return False

            if status not in pending_statuses:
                # Report unknown terminal states instead of silently
                # falling through and returning None.
                print(f" [FAIL] Unexpected status: {status}")
                return False

        print(" [FAIL] Timeout reached.")
        return False

    except Exception as e:
        # Top-level boundary of a CLI test runner: report and mark as failed.
        print(f" [ERROR] Exception: {str(e)}")
        return False
76
+
77
if __name__ == "__main__":
    # Run every case in declaration order, then print a summary and exit
    # with 0 only when all cases passed.
    print("=== MathSolver Backend Test Suite ===")
    results = [(tc['name'], run_test(tc)) for tc in TEST_CASES]

    print("\n" + "="*40)
    print("FINAL REPORT:")
    for name, success in results:
        print(f"- {name}: {'PASS' if success else 'FAIL'}")

    if all(success for _, success in results):
        print("\nALL TESTS PASSED! 🎉")
        sys.exit(0)

    print("\nSOME TESTS FAILED. ❌")
    sys.exit(1)
scripts/generate_report.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ from datetime import datetime
5
+
6
def _as_text(value):
    """Return *value* as a string, rendering a missing value (None) as '-'."""
    return '-' if value is None else str(value)


def _table_cell(text):
    """Make *text* safe for one Markdown table cell (no newlines, escaped pipes)."""
    return text.replace('\r', ' ').replace('\n', ' ').replace('|', '\\|')


def _format_elapsed(value):
    """Format an elapsed-seconds value with two decimals; non-numeric values become 0.00."""
    try:
        return f"{float(value):.2f}"
    except (TypeError, ValueError):
        return "0.00"


def generate_report(json_path, report_path):
    """Render a Markdown report from a JSON list of test results.

    Each result is expected to be a dict with optional keys ``id``, ``query``,
    ``success``, ``elapsed``, ``error`` and ``result`` (itself a dict that may
    hold ``semantic_analysis``, ``geometry_dsl`` and ``solution``).

    The whole report is built in memory before the output file is opened, so
    a malformed record can no longer leave a half-written report behind.

    Args:
        json_path: Path of the JSON results file to read.
        report_path: Path of the Markdown file to write.

    Returns:
        ``True`` when the report was written, ``False`` when the input could
        not be read or parsed, had the wrong shape, or the output could not
        be written. An error message is printed in the failure cases.
    """
    try:
        # Explicit UTF-8: the report holds Vietnamese text and emoji, which
        # the platform default encoding may not be able to represent.
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        print(f'Error generating report: {e}')
        return False

    if not isinstance(data, list) or not all(isinstance(r, dict) for r in data):
        print('Error generating report: expected a JSON list of result objects')
        return False

    overall = "✅ PASS" if all(r.get("success", False) for r in data) else "❌ FAIL"
    out = [
        '# Báo cáo Kiểm thử API Toàn diện (Full Suite API Report)\n\n',
        f'**Thời gian chạy:** {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\n',
        f'**Kết quả chung:** {overall}\n\n',
        '| ID | Câu hỏi (Query) | Trạng thái | Thời gian (s) | Kết quả / Lỗi |\n',
        '| :--- | :--- | :--- | :--- | :--- |\n',
    ]

    # Summary table: one row per result.
    for r in data:
        status = "✅ PASS" if r.get("success") else "❌ FAIL"
        elapsed = _format_elapsed(r.get('elapsed', 0))
        query = _as_text(r.get('query'))

        # Extract analysis or error
        res = r.get('result', {})
        if not isinstance(res, dict):
            res = {}

        if r.get("success"):
            analysis = _as_text(res.get('semantic_analysis'))
        else:
            analysis = f"**Lỗi:** {_as_text(r.get('error'))}"

        # Truncate long analysis for table
        short_analysis = (analysis[:100] + '...') if len(analysis) > 100 else analysis

        out.append(
            f'| {_table_cell(_as_text(r.get("id")))} | {_table_cell(query)} | {status} '
            f'| {elapsed} | {_table_cell(short_analysis)} |\n'
        )

    # Detail section: full output for successful results only.
    out.append('\n---\n**Chi tiết Output (DSL & Analysis):**\n')
    for r in data:
        if not r.get('success'):
            continue
        res = r.get('result', {})
        if not isinstance(res, dict):
            continue

        out.append(f"\n### Case {_as_text(r.get('id'))}: {r.get('query')}\n")
        out.append(f"**Semantic Analysis:**\n{_as_text(res.get('semantic_analysis'))}\n\n")
        out.append(f"**Geometry DSL:**\n```\n{_as_text(res.get('geometry_dsl'))}\n```\n")

        # v5.1 Solution info
        sol = res.get('solution')
        if sol and isinstance(sol, dict):
            out.append("**Solution (v5.1):**\n")
            out.append(f"- **Answer:** {sol.get('answer', 'N/A')}\n")
            out.append("- **Steps:**\n")
            steps = sol.get('steps', [])
            if isinstance(steps, (list, tuple)) and steps:
                out.extend(f" - {step}\n" for step in steps)
            else:
                out.append(" - (Không có bước giải cụ thể)\n")

            if sol.get('symbolic_expression'):
                out.append(f"- **Symbolic:** `{sol.get('symbolic_expression')}`\n")
            out.append("\n")

    try:
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(''.join(out))
    except OSError as e:
        print(f'Error generating report: {e}')
        return False

    print(f'Report generated: {report_path}')
    return True


if __name__ == "__main__":
    if len(sys.argv) < 3:
        print("Usage: python generate_report.py <json_results> <report_output>")
        sys.exit(1)
    # Surface failures through the exit code so calling scripts can detect them.
    sys.exit(0 if generate_report(sys.argv[1], sys.argv[2]) else 1)
scripts/prepare_api_test.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ import sys
3
+ import os
4
+
5
+ # Add parent dir to path to import app modules
6
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7
+
8
+ from app.supabase_client import get_supabase
9
+
10
def prepare(user_id=None):
    """Create a fresh test session row and print the IDs for the caller.

    The IDs are emitted as ``RESULT:USER_ID=...`` and ``RESULT:SESSION_ID=...``
    lines on stdout.

    Args:
        user_id: ID of the user that will own the session. Defaults to the
            built-in test user; the in-code note says an existing user is
            required to avoid a foreign key violation on ``sessions.user_id``.
    """
    supabase = get_supabase()
    # Use existing valid user to avoid foreign key violation on sessions.user_id
    if user_id is None:
        user_id = "8cd3adb0-7964-4575-949c-d0cadcd8b679"
    session_id = str(uuid.uuid4())

    print(f"Using existing test user: {user_id}")

    print(f"Creating fresh test session: {session_id}")
    # Insert session
    supabase.table("sessions").insert({
        "id": session_id,
        "user_id": user_id,
        "title": f"Fresh API Test {session_id[:8]}"
    }).execute()

    # Return IDs for the test script
    print(f"RESULT:USER_ID={user_id}")
    print(f"RESULT:SESSION_ID={session_id}")


if __name__ == "__main__":
    # Optional first CLI argument overrides the default test user.
    prepare(sys.argv[1] if len(sys.argv) > 1 else None)
scripts/prewarm_models.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Download and load all heavy models during Docker build (YOLO, PaddleOCR, Pix2Tex, agents).
4
+ Fails the image build if initialization fails.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import logging
10
+ import os
11
+ import sys
12
+
13
+ # Ensure imports work when run as `python scripts/prewarm_models.py` from WORKDIR
14
+ _APP_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
15
+ if _APP_ROOT not in sys.path:
16
+ sys.path.insert(0, _APP_ROOT)
17
+
18
+ os.chdir(_APP_ROOT)
19
+
20
+ from dotenv import load_dotenv
21
+
22
+ load_dotenv()
23
+
24
+ from app.runtime_env import apply_runtime_env_defaults
25
+
26
+ apply_runtime_env_defaults()
27
+
28
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s %(name)s | %(message)s")
29
+
30
+ logger = logging.getLogger("prewarm")
31
+
32
+
33
def main() -> None:
    """Build an ``Orchestrator`` once so its setup work runs ahead of time.

    Any exception raised during construction propagates and ends the script
    with a non-zero exit status.
    """
    # Imported here rather than at module top so it runs after the
    # environment setup performed above.
    from agents.orchestrator import Orchestrator

    logger.info("Constructing Orchestrator (full agent + model load)...")
    Orchestrator()
    logger.info("Prewarm finished successfully.")


if __name__ == "__main__":
    main()
scripts/test_engine_direct.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ import json
4
+ import logging
5
+ import sys
6
+
7
+ # Add root directory to path to import app and agents
8
+ sys.path.append("/Volumes/WorkSpace/Project/MathSolver/backend")
9
+
10
+ # Configure logging to stdout
11
+ logging.basicConfig(level=logging.DEBUG)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ from agents.orchestrator import Orchestrator
15
+
16
async def main():
    """Run one fixed problem through the Orchestrator and print the result.

    Status updates from the orchestrator are echoed to stdout. When the run
    raises, the traceback is printed and the process exits with status 1 so
    calling scripts can detect the failure.
    """
    orch = Orchestrator()
    text = "Vẽ tam giác đều cạnh 5."
    job_id = "test_direct_equilateral"

    print(f"\n--- Testing Orchestrator Direct: {text} ---")

    async def status_cb(status):
        print(f" [STATUS] {status}")

    try:
        result = await orch.run(text, job_id=job_id, status_callback=status_cb, request_video=False)
        print("\n--- Final Result ---")
        print(json.dumps(result, indent=2))
    except Exception:
        # Top-level boundary of a debug script: show the failure, then exit
        # non-zero instead of finishing as if the run had succeeded.
        import traceback

        print("\n--- ERROR ---")
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    asyncio.run(main())
setup.sh ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# MathSolver v3.1 Setup Script for macOS
#
# Installs Homebrew system libraries, optionally runs the python.org SSL
# certificate installer, creates a virtual environment under ./backend and
# installs the Python requirements. Every step that later steps depend on
# aborts the script on failure, so "Setup Complete" is only printed when the
# environment was actually built.

echo "🚀 Starting Environment Setup..."

# 1. System Dependencies (Homebrew)
if command -v brew >/dev/null 2>&1; then
    echo "📦 Installing system dependencies via Homebrew..."
    brew install pango pkg-config glib librsvg
else
    echo "⚠️ Homebrew not found. Please install it first: https://brew.sh/"
    exit 1
fi

# 2. Python SSL Certificates
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
CERT_FILE="/Applications/Python ${PYTHON_VERSION}/Install Certificates.command"

if [ -f "$CERT_FILE" ]; then
    echo "🔐 Installing Python SSL certificates..."
    sh "$CERT_FILE"
else
    echo "ℹ️ SSL certificate installer not found at $CERT_FILE. Skipping..."
fi

# 3. Virtual Environment
echo "🐍 Setting up Python Virtual Environment..."
# Stop here if the directory is missing: otherwise the venv and every pip
# install below would silently land in the current directory.
cd backend || { echo "❌ Directory 'backend' not found relative to $(pwd)."; exit 1; }
python3 -m venv venv || { echo "❌ Failed to create the virtual environment."; exit 1; }
# Without an active venv the pip commands below would target the system Python.
source venv/bin/activate || { echo "❌ Failed to activate the virtual environment."; exit 1; }

# 4. Pip packages
echo "📦 Installing Python packages..."
pip install --upgrade pip
pip install -r requirements.txt || { echo "❌ Failed to install Python packages."; exit 1; }

# 5. Fix ManimPango (Crucial for macOS arm64)
echo "🛠️ Rebuilding ManimPango from source to ensure library linking..."
pip install --no-cache-dir --force-reinstall --no-binary manimpango manimpango \
    || { echo "❌ Failed to rebuild ManimPango."; exit 1; }

echo "✅ Setup Complete!"
echo "To start the backend, run: source venv/bin/activate && uvicorn app.main:app --reload"
solver/dsl_parser.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import logging
3
+ from typing import List, Tuple, Dict, Any
4
+ from .models import Point, Constraint
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class DSLParser:
    """Parse the line-based geometry DSL into points and constraints.

    Each non-empty line holds one directive, e.g. ``POINT(A)``,
    ``LENGTH(AB, 5)``, ``MIDPOINT(M, AB)`` or ``PYRAMID(S_ABCD)``. Lines that
    start with ``//`` or ``#`` are comments. Lines that match no directive
    are logged with a warning and skipped, so one malformed line never aborts
    the whole parse.
    """

    # Numeric literals. The patterns only match text ``float()`` accepts, so
    # input such as "1.2.3" or "-" is reported as an unrecognized line
    # instead of raising ValueError in the middle of parsing.
    _UNSIGNED = r'(?:\d+\.?\d*|\.\d+)'
    _SIGNED = r'-?' + _UNSIGNED

    # POINT(A) or POINT(A, 0, 0) or POINT(A, 0, 0, 5)
    _POINT_RE = re.compile(
        r'POINT\((\w+)(?:,\s*(' + _SIGNED + r'),\s*(' + _SIGNED + r')'
        r'(?:,\s*(' + _SIGNED + r'))?)?\)'
    )
    # LENGTH(AB, 5)
    _LENGTH_RE = re.compile(r'LENGTH\((\w+),\s*(' + _UNSIGNED + r')\)')
    # ANGLE(A, 90) or ANGLE(A, 90deg)
    _ANGLE_RE = re.compile(r'ANGLE\((\w+),\s*(' + _UNSIGNED + r')(?:deg)?\)')
    # PARALLEL(AB, CD) / PERPENDICULAR(AB, CD)
    _RELATION_RE = re.compile(r'(PARALLEL|PERPENDICULAR)\((\w+),\s*(\w+)\)')
    # MIDPOINT(M, AB)
    _MIDPOINT_RE = re.compile(r'MIDPOINT\((\w+),\s*(\w+)\)')
    # SECTION(E, A, C, 0.66)
    _SECTION_RE = re.compile(
        r'SECTION\((\w+),\s*(\w+),\s*(\w+),\s*(' + _SIGNED + r')\)'
    )
    # CIRCLE(O, r) / SPHERE(O, r)
    _ROUND_RE = re.compile(r'(CIRCLE|SPHERE)\((\w+),\s*(' + _UNSIGNED + r')\)')
    # POLYGON_ORDER(A, B, C, D)
    _POLYGON_ORDER_RE = re.compile(r'POLYGON_ORDER\(([^)]+)\)')
    # SEGMENT(M, N) / LINE(A, B) / RAY(A, B)
    _POINT_PAIR_RE = re.compile(r'(SEGMENT|LINE|RAY)\((\w+),\s*(\w+)\)')
    # TRIANGLE(ABC) / PYRAMID(S_ABCD) / PRISM(ABC_DEF)
    _SOLID_RE = re.compile(r'(TRIANGLE|PYRAMID|PRISM)\(([^)]+)\)')

    # Separators used in the debug log line of each directive.
    _RELATION_SYMBOLS = {'PARALLEL': '||', 'PERPENDICULAR': '_|_'}
    _POINT_PAIR_SYMBOLS = {'SEGMENT': '—', 'LINE': '-', 'RAY': '->'}

    @staticmethod
    def _split_names(spec: str) -> List[str]:
        """Split a point-list spec into point IDs.

        ``"ABC"`` yields one ID per character; ``"A, B, C"`` is split on
        commas, so separators are never mistaken for point names.
        """
        if ',' in spec:
            return [part.strip() for part in spec.split(',') if part.strip()]
        return [ch for ch in spec if not ch.isspace()]

    def parse(self, text: str) -> Tuple[List[Point], List[Constraint], bool]:
        """Parse DSL text into points and constraints. Stateless per call.

        Args:
            text: The DSL source, one directive per line.

        Returns:
            A ``(points, constraints, is_3d)`` tuple. ``constraints`` also
            carries synthetic metadata entries appended at the end
            (``polygon_order`` or ``explicit_points``, ``lines_metadata``,
            ``rays_metadata``). ``is_3d`` is true when a point has a z
            coordinate or a PYRAMID/PRISM/SPHERE directive was seen.
        """
        points: Dict[str, Point] = {}
        explicit_point_ids: List[str] = []
        constraints: List[Constraint] = []
        polygon_order: List[str] = []
        lines_ext: List[List[str]] = []
        rays: List[List[str]] = []
        is_3d = False

        def declare(pid: str) -> None:
            """Register an implicitly referenced point without coordinates."""
            if pid not in points:
                points[pid] = Point(id=pid)

        def add(kind: str, targets: List[str], value=0) -> None:
            constraints.append(Constraint(type=kind, targets=targets, value=value))

        logger.info("==[DSLParser] Parsing DSL input==")
        logger.debug(f"[DSLParser] Raw DSL:\n{text}")

        for raw_line in (text or "").strip().split('\n'):
            line = raw_line.strip()
            if not line or line.startswith(('//', '#')):
                continue

            m = self._POINT_RE.match(line)
            if m:
                name = m.group(1)
                x, y, z = (
                    float(g) if g is not None else None for g in m.group(2, 3, 4)
                )
                if z is not None:
                    is_3d = True
                points[name] = Point(id=name, x=x, y=y, z=z)
                if name not in explicit_point_ids:
                    explicit_point_ids.append(name)
                logger.debug(f"[DSLParser] + POINT: {name} ({x}, {y}, {z})")
                continue

            m = self._LENGTH_RE.match(line)
            if m:
                pts, value = list(m.group(1)), float(m.group(2))
                add('length', pts, value)
                logger.debug(f"[DSLParser] + LENGTH: {pts} = {value}")
                continue

            m = self._ANGLE_RE.match(line)
            if m:
                target, value = m.group(1), float(m.group(2))
                add('angle', [target], value)
                logger.debug(f"[DSLParser] + ANGLE: vertex={target}, degrees={value}")
                continue

            m = self._RELATION_RE.match(line)
            if m:
                kind, seg1, seg2 = m.group(1, 2, 3)
                add(kind.lower(), list(seg1) + list(seg2))
                logger.debug(
                    f"[DSLParser] + {kind}: {seg1} {self._RELATION_SYMBOLS[kind]} {seg2}"
                )
                continue

            m = self._MIDPOINT_RE.match(line)
            if m:
                mid, seg = m.group(1), m.group(2)
                declare(mid)
                add('midpoint', [mid] + list(seg))
                logger.debug(f"[DSLParser] + MIDPOINT: {mid} = mid({seg})")
                continue

            m = self._SECTION_RE.match(line)
            if m:
                target, p1, p2 = m.group(1, 2, 3)
                k = float(m.group(4))
                declare(target)
                add('section', [target, p1, p2], k)
                logger.debug(f"[DSLParser] + SECTION: {target} = {p1} + {k}({p2}-{p1})")
                continue

            m = self._ROUND_RE.match(line)
            if m:
                kind, center = m.group(1), m.group(2)
                radius = float(m.group(3))
                if kind == 'SPHERE':
                    is_3d = True
                declare(center)
                add(kind.lower(), [center], radius)
                logger.debug(f"[DSLParser] + {kind}: center={center}, r={radius}")
                continue

            m = self._POLYGON_ORDER_RE.match(line)
            if m:
                # Order in which points are joined when drawing the polygon.
                polygon_order = [p.strip() for p in m.group(1).split(',')]
                logger.debug(f"[DSLParser] + POLYGON_ORDER: {polygon_order}")
                continue

            m = self._POINT_PAIR_RE.match(line)
            if m:
                kind, p1, p2 = m.group(1, 2, 3)
                if kind == 'LINE':
                    lines_ext.append([p1, p2])
                elif kind == 'RAY':
                    rays.append([p1, p2])
                add(kind.lower(), [p1, p2])
                logger.debug(
                    f"[DSLParser] + {kind}: {p1}{self._POINT_PAIR_SYMBOLS[kind]}{p2}"
                )
                continue

            m = self._SOLID_RE.match(line)
            if m:
                kind, targets = m.group(1), m.group(2).strip()
                if kind == "TRIANGLE":
                    if not polygon_order:
                        polygon_order = self._split_names(targets)
                else:
                    is_3d = True
                    if "_" in targets:
                        # Exactly one "_" must separate the two name groups;
                        # anything else is malformed and gets no edges rather
                        # than crashing on tuple unpacking.
                        halves = [part.strip() for part in targets.split("_")]
                        if len(halves) != 2 or not all(halves):
                            logger.warning(
                                f"[DSLParser] ? Malformed {kind} spec: '{targets}'"
                            )
                        elif kind == "PYRAMID":
                            # S_ABCD -> S is apex, ABCD is base
                            apex = halves[0]
                            base = self._split_names(halves[1])
                            for p in base:
                                add('segment', [apex, p])
                            if not polygon_order:
                                polygon_order = base
                        else:
                            # PRISM: ABC_DEF -> join matching vertices of
                            # the two bases.
                            pairs = zip(
                                self._split_names(halves[0]),
                                self._split_names(halves[1]),
                            )
                            for p1, p2 in pairs:
                                add('segment', [p1, p2])
                logger.debug(f"[DSLParser] + {kind}: {targets}")
                continue

            logger.warning(f"[DSLParser] ? Unrecognized DSL line: '{line}'")

        logger.info(f"[DSLParser] Parsed {len(points)} points, {len(constraints)} constraints.")

        # Safety sweep: every single-character target referenced by a
        # constraint must exist as a point. Longer targets (e.g. a
        # multi-character ANGLE target) are left untouched.
        for c in constraints:
            for pid in c.targets:
                if isinstance(pid, str) and len(pid) == 1 and pid not in points:
                    points[pid] = Point(id=pid)
                    logger.debug(f"[DSLParser] ! Auto-declared missing point from constraint: {pid}")

        # Attach metadata to a synthetic constraint for downstream use
        if polygon_order:
            add('polygon_order', polygon_order)
        elif explicit_point_ids:
            # No real order was given: carry the explicitly declared points.
            add('explicit_points', explicit_point_ids)

        # Add auxiliary metadata for lines and rays
        if lines_ext:
            add('lines_metadata', [",".join(pair) for pair in lines_ext])
        if rays:
            add('rays_metadata', [",".join(pair) for pair in rays])

        return list(points.values()), constraints, is_3d
solver/engine.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sympy as sp
2
+ import numpy as np
3
+ import logging
4
+ import string
5
+ from typing import List, Dict, Any
6
+ from .models import Point, Constraint
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+
11
+ class GeometryEngine:
12
+ def solve(self, points: List[Point], constraints: List[Constraint], is_3d: bool = False) -> Dict[str, Any] | None:
13
+ if not points:
14
+ logger.error("[GeometryEngine] No points to solve.")
15
+ return None
16
+
17
+ logger.info(f"==[GeometryEngine] Starting solve with {len(points)} points, {len(constraints)} constraints (is_3d={is_3d})==")
18
+
19
+ # ── Separate metadata constraints from real ones ──────────────────────
20
+ polygon_order: List[str] = []
21
+ circles_meta: List[Dict] = []
22
+ segments_meta: List[List[str]] = []
23
+ real_constraints: List[Constraint] = []
24
+
25
+ for c in constraints:
26
+ if c.type == 'polygon_order':
27
+ polygon_order = list(c.targets)
28
+ elif c.type == 'explicit_points' and not polygon_order:
29
+ polygon_order = list(c.targets)
30
+ elif c.type == 'circle':
31
+ circles_meta.append({"center": c.targets[0], "radius": float(c.value)})
32
+ real_constraints.append(c)
33
+ elif c.type == 'segment':
34
+ segments_meta.append(list(c.targets))
35
+ # don't add to equations — pure drawing annotation
36
+ elif c.type == 'lines_metadata':
37
+ lines_meta_list = [t.split(',') for t in c.targets]
38
+ real_constraints.append(c) # for passing to builder? or just keep here
39
+ elif c.type == 'rays_metadata':
40
+ rays_meta_list = [t.split(',') for t in c.targets]
41
+ real_constraints.append(c)
42
+ else:
43
+ real_constraints.append(c)
44
+
45
+ # ── Setup symbols ─────────────────────────────────────────────────────
46
+ point_vars: Dict[str, tuple] = {}
47
+ equations = []
48
+
49
+ # Convert to list for stable indexing and to handle both Dict and List inputs
50
+ pt_list = list(points.values()) if isinstance(points, dict) else points
51
+
52
+ for p in pt_list:
53
+ x = sp.Symbol(f"{p.id}_x")
54
+ y = sp.Symbol(f"{p.id}_y")
55
+ z = sp.Symbol(f"{p.id}_z")
56
+ point_vars[p.id] = (x, y, z)
57
+ logger.debug(f"[GeometryEngine] Symbol: ({p.id}_x, {p.id}_y, {p.id}_z)")
58
+
59
+ # If 2D problem, pin all z to 0 immediately
60
+ if not is_3d:
61
+ equations.append(z)
62
+
63
+ # ── Anchor logic to fix translation + rotation DOF ────────────────────
64
+ # Skip anchoring if points already have explicit coordinates that fix DOFs
65
+
66
+ if len(pt_list) > 0:
67
+ p1 = pt_list[0]
68
+ # Translation: fix p1 at (0,0) or (0,0,0)
69
+ if p1.x is None: equations.append(point_vars[p1.id][0]); logger.debug(f"Anchor {p1.id}_x=0")
70
+ if p1.y is None: equations.append(point_vars[p1.id][1]); logger.debug(f"Anchor {p1.id}_y=0")
71
+ if is_3d and p1.z is None:
72
+ equations.append(point_vars[p1.id][2]); logger.debug(f"Anchor {p1.id}_z=0")
73
+
74
+ if len(pt_list) > 1:
75
+ p2 = pt_list[1]
76
+ # Rotation: fix p2 on X-axis (y=0)
77
+ if p2.y is None: equations.append(point_vars[p2.id][1]); logger.debug(f"Anchor {p2.id}_y=0")
78
+ if is_3d and p2.z is None:
79
+ equations.append(point_vars[p2.id][2]); logger.debug(f"Anchor {p2.id}_z=0")
80
+
81
+ if is_3d and len(pt_list) > 2:
82
+ p3 = pt_list[2]
83
+ # Planar rotation: fix p3 on XY-plane (z=0)
84
+ if p3.z is None: equations.append(point_vars[p3.id][2]); logger.debug(f"Anchor {p3.id}_z=0")
85
+
86
+ # ── Build equations from explicit point coordinates ──────────────────
87
+ for p in pt_list:
88
+ if p.x is not None:
89
+ equations.append(point_vars[p.id][0] - p.x)
90
+ if p.y is not None:
91
+ equations.append(point_vars[p.id][1] - p.y)
92
+ if p.z is not None:
93
+ equations.append(point_vars[p.id][2] - p.z)
94
+
95
+ # ── Build equations from constraints ──────────────────────────────────
96
+ for c in real_constraints:
97
+ logger.debug(f"[GeometryEngine] Processing constraint: type={c.type}, targets={c.targets}, value={c.value}")
98
+
99
+ if c.type == 'length' and len(c.targets) == 2:
100
+ p1, p2 = c.targets
101
+ if p1 not in point_vars or p2 not in point_vars:
102
+ logger.warning(f"[GeometryEngine] Skip length: {c.targets} not in symbols.")
103
+ continue
104
+ v1, v2 = point_vars[p1], point_vars[p2]
105
+ # 3D distance
106
+ eq = (v2[0]-v1[0])**2 + (v2[1]-v1[1])**2 + (v2[2]-v1[2])**2 - float(c.value)**2
107
+ equations.append(eq)
108
+ logger.debug(f"[GeometryEngine] -> Length eq (3D): |{p1}{p2}|² = {c.value}²")
109
+
110
+ elif c.type == 'angle' and len(c.targets) >= 1:
111
+ # In 3D, 'angle' usually refers to the angle between two vectors (e.g., ∠BAC)
112
+ v_name = c.targets[0]
113
+ if v_name not in point_vars:
114
+ continue
115
+ # For simplicity, we assume the next two points in targets or fallback to first 2 others
116
+ if len(c.targets) >= 3:
117
+ p1_name, p2_name = c.targets[1], c.targets[2]
118
+ else:
119
+ other_pts = [p.id for p in pt_list if p.id != v_name][:2]
120
+ if len(other_pts) < 2: continue
121
+ p1_name, p2_name = other_pts
122
+
123
+ pV = point_vars[v_name]
124
+ p1_vars = point_vars[p1_name]
125
+ p2_vars = point_vars[p2_name]
126
+
127
+ # Vectors V1 and V2
128
+ v1 = [p1_vars[i] - pV[i] for i in range(3)]
129
+ v2 = [p2_vars[i] - pV[i] for i in range(3)]
130
+
131
+ # Dot product relation: v1.v2 = |v1||v2| cos(theta)
132
+ # But we use the tangent relation or square it to avoid sqrt if possible
133
+ # If 90 deg: dot product = 0
134
+ if abs(float(c.value) - 90.0) < 1e-9:
135
+ eq = sum(v1[i]*v2[i] for i in range(3))
136
+ logger.debug(f"[GeometryEngine] -> Angle eq at {v_name} (90° dot=0)")
137
+ else:
138
+ # Generic angle using law of cosines (squared)
139
+ cos_val = np.cos(np.deg2rad(float(c.value)))
140
+ d1_sq = sum(v1[i]**2 for i in range(3))
141
+ d2_sq = sum(v2[i]**2 for i in range(3))
142
+ dot = sum(v1[i]*v2[i] for i in range(3))
143
+ eq = dot**2 - (cos_val**2) * d1_sq * d2_sq
144
+ # Note: this allows theta and 180-theta.
145
+ # Better: dot - cos(theta) * sqrt(d1_sq * d2_sq) = 0, but that has sqrt.
146
+ logger.debug(f"[GeometryEngine] -> Angle eq at {v_name} ({c.value}° cos² relation)")
147
+ equations.append(eq)
148
+
149
+ elif c.type == 'parallel' and len(c.targets) == 4:
150
+ pA, pB, pC, pD = c.targets
151
+ if any(t not in point_vars for t in [pA, pB, pC, pD]): continue
152
+ va, vb, vc, vd = point_vars[pA], point_vars[pB], point_vars[pC], point_vars[pD]
153
+ # AB || CD means vector(AB) = lambda * vector(CD)
154
+ # In 3D, cross product = 0. (b-a) x (d-c) = 0
155
+ v1 = [vb[i]-va[i] for i in range(3)]
156
+ v2 = [vd[i]-vc[i] for i in range(3)]
157
+ # Cross product components:
158
+ equations.append(v1[1]*v2[2] - v1[2]*v2[1])
159
+ equations.append(v1[2]*v2[0] - v1[0]*v2[2])
160
+ equations.append(v1[0]*v2[1] - v1[1]*v2[0])
161
+ logger.debug(f"[GeometryEngine] -> Parallel eq (3D cross=0): {pA}{pB} || {pC}{pD}")
162
+
163
+ elif c.type == 'perpendicular' and len(c.targets) == 4:
164
+ pA, pB, pC, pD = c.targets
165
+ if any(t not in point_vars for t in [pA, pB, pC, pD]): continue
166
+ va, vb, vc, vd = point_vars[pA], point_vars[pB], point_vars[pC], point_vars[pD]
167
+ # Dot product = 0
168
+ dot = sum((vb[i]-va[i])*(vd[i]-vc[i]) for i in range(3))
169
+ equations.append(dot)
170
+ logger.debug(f"[GeometryEngine] -> Perpendicular eq (3D dot=0): {pA}{pB} ⊥ {pC}{pD}")
171
+
172
+ elif c.type == 'midpoint' and len(c.targets) == 3:
173
+ pM, pA, pB = c.targets
174
+ if any(t not in point_vars for t in [pM, pA, pB]): continue
175
+ vM, vA, vB = point_vars[pM], point_vars[pA], point_vars[pB]
176
+ for i in range(3):
177
+ equations.append(2*vM[i] - vA[i] - vB[i])
178
+ logger.debug(f"[GeometryEngine] -> Midpoint eq (3D): {pM} = mid({pA},{pB})")
179
+
180
+ elif c.type == 'section' and len(c.targets) == 3:
181
+ pE, pA, pC = c.targets
182
+ if any(t not in point_vars for t in [pE, pA, pC]): continue
183
+ vE, vA, vC = point_vars[pE], point_vars[pA], point_vars[pC]
184
+ k = float(c.value)
185
+ for i in range(3):
186
+ equations.append(vE[i] - (vA[i] + k * (vC[i] - vA[i])))
187
+ logger.debug(f"[GeometryEngine] -> Section eq (3D): {pE} = {pA} + {k}({pC}-{pA})")
188
+
189
+ elif c.type == 'circle':
190
+ # Circle doesn't add position constraints for center (already a point)
191
+ logger.debug(f"[GeometryEngine] -> Circle: center={c.targets[0]}, r={c.value} (meta only)")
192
+
193
+ all_vars = []
194
+ for v in point_vars.values():
195
+ all_vars.extend(v)
196
+
197
+ n_eqs = len(equations)
198
+ n_vars = len(all_vars)
199
+ logger.info(f"[GeometryEngine] Built {n_eqs} equations for {n_vars} unknowns.")
200
+
201
+ # ── Strategy 1: SymPy symbolic ───────────────────────────────────────
202
+ coords = self._try_symbolic(equations, all_vars, point_vars)
203
+
204
+ # Extract lines/rays from constraints for builder
205
+ lines_ext = []
206
+ rays_ext = []
207
+ for c in constraints:
208
+ if c.type == 'lines_metadata':
209
+ lines_ext = [t.split(',') for t in c.targets]
210
+ if c.type == 'rays_metadata':
211
+ rays_ext = [t.split(',') for t in c.targets]
212
+
213
+ if coords:
214
+ return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
215
+
216
+ # ── Strategy 2: Numerical nsolve ─────────────────────────────────────
217
+ if n_eqs == n_vars:
218
+ coords = self._try_nsolve(equations, all_vars, point_vars, n_vars)
219
+ if coords:
220
+ return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
221
+
222
+ # ── Strategy 3: Scipy least-squares ─────────────────────────────────
223
+ coords = self._try_lsq(equations, all_vars, point_vars, n_vars)
224
+ if coords:
225
+ return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
226
+
227
+ # ── Strategy 4: Differential evolution ──────────────────────────────
228
+ coords = self._try_global(equations, all_vars, point_vars, n_vars)
229
+ if coords:
230
+ return self._build_result(coords, polygon_order, circles_meta, segments_meta, lines_ext, rays_ext, pt_list)
231
+
232
+ logger.error("[GeometryEngine] All strategies exhausted.")
233
+ return None
234
+
235
+ # ─── Solving strategies ──────────────────────────────────────────────────
236
+
237
def _try_symbolic(self, equations, all_vars, point_vars):
    """Strategy 1: solve the constraint system exactly with ``sp.solve``.

    Args:
        equations: SymPy expressions, each treated as ``expr == 0``.
        all_vars: Flat list of every coordinate symbol.
        point_vars: Mapping ``point_id -> (vx, vy, vz)`` of coordinate symbols.

    Returns:
        ``{point_id: [x, y, z]}`` built from the first solution whose values
        all convert to plain floats, or ``None`` when the solve is skipped,
        raises, or yields no usable solution. Symbols absent from a solution
        default to ``0.0``.
    """
    # Optimization: SymPy's symbolic solver becomes extremely slow for many variables.
    # For 3D problems (usually 12-18+ variables), we prefer using numerical methods directly.
    if len(all_vars) > 10:
        logger.info(f"[GeometryEngine] Strategy 1: Skipping symbolic solve due to high variable count ({len(all_vars)}).")
        return None

    try:
        solutions = sp.solve(equations, all_vars, dict=True)
    except Exception as e:
        logger.warning(f"[GeometryEngine] Strategy 1 threw exception: {e}. Trying numerical...")
        return None

    if not solutions:
        logger.warning("[GeometryEngine] Strategy 1 returned no solution. Trying numerical...")
        return None

    # sp.solve may return several roots. Previously only the first one was
    # tried, so a complex or parametric first root discarded every later real one.
    for res in solutions:
        try:
            coords = {
                pid: [float(res.get(vx, 0.0)), float(res.get(vy, 0.0)), float(res.get(vz, 0.0))]
                for pid, (vx, vy, vz) in point_vars.items()
            }
        except (TypeError, ValueError) as e:
            # float() rejects values that are not real numbers (complex roots,
            # expressions that still contain free symbols).
            logger.debug(f"[GeometryEngine] Symbolic candidate rejected: {e}")
            continue
        logger.info("[GeometryEngine] Strategy 1 (SymPy symbolic): SUCCESS.")
        logger.debug(f"[GeometryEngine] Symbolic solution: {res}")
        return coords

    logger.warning("[GeometryEngine] Strategy 1 found no real numeric solution. Trying numerical...")
    return None
257
+
258
def _try_nsolve(self, equations, all_vars, point_vars, n_vars):
    """Strategy 2: numerical root finding with ``sp.nsolve`` from random starts.

    Makes up to 15 attempts. The random start vector is drawn uniformly from
    ``[-scale, scale]`` where ``scale`` is 10 for attempts 1-5, 100 for
    attempts 6-10 and 1 for attempts 11-15.

    Args:
        equations: SymPy expressions, each treated as ``expr == 0``.
        all_vars: Flat list of every coordinate symbol.
        point_vars: Mapping ``point_id -> (vx, vy, vz)`` of coordinate symbols.
        n_vars: Number of unknowns; only used in the log message.

    Returns:
        ``{point_id: [x, y, z]}`` from the first attempt that succeeds, or
        ``None`` if every attempt raises.
    """
    import random

    MAX_NSOLVE_ATTEMPTS = 15
    logger.info(f"[GeometryEngine] Strategy 2 (nsolve): square system ({n_vars}x{n_vars}). Trying {MAX_NSOLVE_ATTEMPTS} random starts...")

    for attempt_no in range(1, MAX_NSOLVE_ATTEMPTS + 1):
        # Use varying scales for the random guesses to handle different problem sizes
        if attempt_no <= 5:
            scale = 10
        elif attempt_no <= 10:
            scale = 100
        else:
            scale = 1

        try:
            guesses = [random.uniform(-scale, scale) for _ in all_vars]
            sol_vals = sp.nsolve(equations, all_vars, guesses, tol=1e-6, maxsteps=1000)

            values = {}
            for symbol, raw in zip(all_vars, sol_vals):
                values[symbol] = float(raw)

            logger.info(f"[GeometryEngine] Strategy 2 (nsolve): SUCCESS on attempt {attempt_no}.")

            solved = {}
            for pid, (vx, vy, vz) in point_vars.items():
                solved[pid] = [
                    float(values.get(vx, 0.0)),
                    float(values.get(vy, 0.0)),
                    float(values.get(vz, 0.0)),
                ]
            return solved
        except Exception as e:
            # A failed start is expected; move on to the next random guess.
            logger.debug(f"[GeometryEngine] nsolve attempt {attempt_no} failed: {e}")

    return None
275
+
276
def _try_lsq(self, equations, all_vars, point_vars, n_vars):
    """Strategy 3: minimise the sum of squared residuals with L-BFGS-B.

    Runs up to 12 restarts: the first from an all-ones vector, the next three
    from uniform samples in [-10, 10], the rest from uniform samples in
    [-100, 100]. Stops early once the best objective drops below 1e-6.

    Args:
        equations: SymPy expressions, each treated as ``expr == 0``.
        all_vars: Flat list of every coordinate symbol.
        point_vars: Mapping ``point_id -> (vx, vy, vz)`` of coordinate symbols.
        n_vars: Length of the start vector.

    Returns:
        ``{point_id: [x, y, z]}`` when the best objective is below 1e-4,
        otherwise ``None`` (also ``None`` if anything raises).
    """
    logger.info("[GeometryEngine] Strategy 3 (scipy least-squares): minimizing residuals...")
    try:
        from scipy.optimize import minimize

        residual_fns = [sp.lambdify(all_vars, eq, 'numpy') for eq in equations]

        def objective(x):
            # Sum of squared residuals over all equations.
            total = 0
            for fn in residual_fns:
                total = total + float(fn(*x))**2
            return total

        best_res, best_val = None, float('inf')
        # Increase restarts for better coverage of local minima
        for restart in range(12):
            if restart == 0:
                x0 = [1.0]*n_vars
            else:
                spread = 10 if restart < 4 else 100
                x0 = [np.random.uniform(-spread, spread) for _ in range(n_vars)]

            candidate = minimize(objective, x0, method='L-BFGS-B')
            if candidate.fun < best_val:
                best_val, best_res = candidate.fun, candidate
                if best_val < 1e-6:
                    break

        TOLERANCE = 1e-4
        logger.info(f"[GeometryEngine] Strategy 3: best residual = {best_val:.2e} (tol={TOLERANCE})")
        if not best_val < TOLERANCE:
            logger.warning(f"[GeometryEngine] Strategy 3 failed: residual {best_val:.2e} > {TOLERANCE}")
        else:
            values = {}
            for symbol, raw in zip(all_vars, best_res.x):
                values[symbol] = float(raw)
            logger.info("[GeometryEngine] Strategy 3 (least-squares): SUCCESS.")
            solved = {}
            for pid, (vx, vy, vz) in point_vars.items():
                solved[pid] = [float(values.get(vx, 0)), float(values.get(vy, 0)), float(values.get(vz, 0))]
            return solved
    except Exception as e:
        logger.error(f"[GeometryEngine] Strategy 3 threw exception: {e}")
    return None
313
+
314
def _try_global(self, equations, all_vars, point_vars, n_vars):
    """Strategy 4: global search with SciPy differential evolution.

    Minimises the sum of squared residuals with every unknown bounded to
    [-20, 20], so solutions with a coordinate outside that box cannot be
    found by this strategy.

    Args:
        equations: SymPy expressions, each treated as ``expr == 0``.
        all_vars: Flat list of every coordinate symbol.
        point_vars: Mapping ``point_id -> (vx, vy, vz)`` of coordinate symbols.
        n_vars: Number of unknowns (one bound pair per unknown).

    Returns:
        ``{point_id: [x, y, z]}`` when the best objective is below 1e-3,
        otherwise ``None`` (also ``None`` if anything raises).
    """
    logger.info("[GeometryEngine] Strategy 4 (Differential Evolution): global search...")
    try:
        from scipy.optimize import differential_evolution
        bounds = [(-20, 20)] * n_vars
        eq_funcs = [sp.lambdify(all_vars, eq, 'numpy') for eq in equations]

        def obj(x):
            s = 0.0
            for f in eq_funcs:
                try:
                    s += float(f(*x))**2
                except Exception:
                    # Penalise points where a residual cannot be evaluated.
                    # `Exception` (not a bare except) lets KeyboardInterrupt
                    # and SystemExit abort the long-running search.
                    s += 1e6
            return s

        result = differential_evolution(obj, bounds, maxiter=500, popsize=15, mutation=(0.5, 1), recombination=0.7)
        TOLERANCE = 1e-3
        logger.info(f"[GeometryEngine] Strategy 4: best residual = {result.fun:.2e} (tol={TOLERANCE})")
        if result.fun < TOLERANCE:
            res = {var: float(val) for var, val in zip(all_vars, result.x)}
            logger.info("[GeometryEngine] Strategy 4 (global opt): SUCCESS.")
            return {pid: [float(res.get(vx, 0)), float(res.get(vy, 0)), float(res.get(vz, 0))]
                    for pid, (vx, vy, vz) in point_vars.items()}
        # Mirror Strategy 3: make the miss visible in the logs.
        logger.warning(f"[GeometryEngine] Strategy 4 failed: residual {result.fun:.2e} >= {TOLERANCE}")
    except Exception as e:
        logger.error(f"[GeometryEngine] Strategy 4 threw exception: {e}")
    return None
341
+
342
+ # ─── Result builder ──────────────────────────────────────────────────────
343
+
344
+ def _build_result(
345
+ self,
346
+ coords: Dict[str, List[float]],
347
+ polygon_order: List[str],
348
+ circles_meta: List[Dict],
349
+ segments_meta: List[List[str]],
350
+ lines_meta: List[List[str]],
351
+ rays_meta: List[List[str]],
352
+ pt_list: List[Point],
353
+ ) -> Dict[str, Any]:
354
+ """
355
+ Build structured result including drawing phases for the renderer.
356
+
357
+ drawing_phases:
358
+ Phase 1 — Base shape (main polygon)
359
+ Phase 2 — Auxiliary/derived points and segments
360
+ """
361
+ all_ids = [p.id for p in pt_list]
362
+
363
+ # 1. Infer/clean polygon_order
364
+ if not polygon_order:
365
+ # Fallback: use all declared point IDs sorted by conventional uppercase order.
366
+ # This is far safer than only looking for A/B/C/D.
367
+ base_pts = sorted(
368
+ all_ids,
369
+ key=lambda p: (string.ascii_uppercase.index(p) if p in string.ascii_uppercase else 100, p)
370
+ )
371
+ polygon_order = base_pts
372
+
373
+ base_ids = [pid for pid in polygon_order if pid in all_ids]
374
+ derived_ids = [pid for pid in all_ids if pid not in polygon_order]
375
+
376
+ # 2. Collect unique segments to avoid redundancy (AB == BA)
377
+ drawn_segments = set()
378
+
379
+ def add_segment(p1, p2, target_list):
380
+ if p1 == p2:
381
+ return
382
+ s = frozenset([p1, p2])
383
+ if s not in drawn_segments:
384
+ drawn_segments.add(s)
385
+ target_list.append([p1, p2])
386
+
387
+ # Phase 1: Main polygon boundary
388
+ phase1_segments = []
389
+ if len(base_ids) >= 2:
390
+ # Connect in sequence: A-B, B-C, etc.
391
+ for i in range(len(base_ids) - 1):
392
+ add_segment(base_ids[i], base_ids[i+1], phase1_segments)
393
+
394
+ # ONLY close the loop if we have 3 or more points (a real polygon)
395
+ if len(base_ids) > 2:
396
+ add_segment(base_ids[-1], base_ids[0], phase1_segments)
397
+
398
+ # Phase 2: Auxiliary segments from DSL
399
+ phase2_segments = []
400
+ for p1, p2 in segments_meta:
401
+ add_segment(p1, p2, phase2_segments)
402
+
403
+ drawing_phases = [
404
+ {
405
+ "phase": 1,
406
+ "label": "Hình cơ bản",
407
+ "points": base_ids,
408
+ "segments": phase1_segments,
409
+ }
410
+ ]
411
+ if derived_ids or phase2_segments:
412
+ drawing_phases.append({
413
+ "phase": 2,
414
+ "label": "Điểm và đoạn phụ",
415
+ "points": derived_ids,
416
+ "segments": phase2_segments,
417
+ })
418
+
419
+ return {
420
+ "coordinates": coords,
421
+ "polygon_order": polygon_order,
422
+ "circles": circles_meta,
423
+ "lines": lines_meta,
424
+ "rays": rays_meta,
425
+ "drawing_phases": drawing_phases,
426
+ }
solver/models.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Dict, Union, Optional
3
+
4
class Point(BaseModel):
    """A named geometry point with optional 3D coordinates."""

    # Identifier used to reference the point (e.g. "A", "M1").
    id: str
    # Each coordinate defaults to None when not supplied.
    # NOTE(review): None presumably marks a coordinate the solver must find — confirm.
    x: Optional[float] = None
    y: Optional[float] = None
    z: Optional[float] = None
9
+
10
class Constraint(BaseModel):
    """A single typed constraint (or metadata record) over named points."""

    type: str  # 'length', 'angle', 'parallel', etc.
    # Point ids (or other string tokens) the constraint applies to.
    targets: List[str]
    # Numeric or textual payload; required, as no default is declared.
    value: Union[float, str]
tests/test_3d_solver.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from solver.dsl_parser import DSLParser
3
+ from solver.engine import GeometryEngine
4
+ from solver.models import Point, Constraint
5
+
6
def test_solve_square_pyramid():
    """
    Test solving for a square pyramid S.ABCD.
    Base ABCD is a square with side 10.
    Height SO = 15, where O is the center of ABCD.
    """
    dsl = """
    POINT(A, 0, 0, 0)
    POINT(B, 10, 0, 0)
    POINT(C, 10, 10, 0)
    POINT(D, 0, 10, 0)
    POINT(S)
    POINT(O)
    MIDPOINT(M1, AB)
    MIDPOINT(M2, AC)
    SECTION(O, A, C, 0.5)
    LENGTH(SO, 15)
    PERPENDICULAR(SO, AC)
    PERPENDICULAR(SO, AB)
    PYRAMID(S_ABCD)
    """

    def near(expected):
        # The engine may fall back to numerical solving, so exact equality
        # (and approx's default abs=1e-12 around 0.0) is too strict.
        return pytest.approx(expected, abs=1e-6)

    points, constraints = DSLParser().parse(dsl)
    result = GeometryEngine().solve(points, constraints)

    assert result is not None
    coords = result["coordinates"]

    # Check base points
    assert coords["A"] == near([0.0, 0.0, 0.0])
    assert coords["B"] == near([10.0, 0.0, 0.0])
    assert coords["C"] == near([10.0, 10.0, 0.0])
    assert coords["D"] == near([0.0, 10.0, 0.0])

    # Check center O (should be (5, 5, 0))
    assert coords["O"] == near([5.0, 5.0, 0.0])

    # Check apex S (should be (5, 5, 15) or (5, 5, -15))
    assert coords["S"][0] == near(5.0)
    assert coords["S"][1] == near(5.0)
    assert abs(coords["S"][2]) == near(15.0)
51
+
52
def test_solve_prism():
    """
    Triangular prism ABC_DEF.
    Base ABC is right triangle at A. AB=3, AC=4.
    Height AD=10.
    """
    dsl = """
    POINT(A, 0, 0, 0)
    POINT(B, 3, 0, 0)
    POINT(C, 0, 4, 0)
    POINT(D)
    POINT(E)
    POINT(F)
    LENGTH(AD, 10)
    PERPENDICULAR(AD, AB)
    PERPENDICULAR(AD, AC)
    PRISM(ABC_DEF)
    """

    def near(expected):
        # Numerical fallbacks cannot be expected to land within approx's
        # default abs=1e-12 of an exact zero.
        return pytest.approx(expected, abs=1e-6)

    points, constraints = DSLParser().parse(dsl)
    result = GeometryEngine().solve(points, constraints)

    assert result is not None
    coords = result["coordinates"]

    # D should be (0, 0, 10) or its mirror image (0, 0, -10)
    assert coords["D"][0] == near(0.0)
    assert coords["D"][1] == near(0.0)
    assert abs(coords["D"][2]) == near(10.0)
83
+
84
+ if __name__ == "__main__":
85
+ pytest.main([__file__])
tests/test_advanced_geometry.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import asyncio
3
+ import logging
4
+ from solver.dsl_parser import DSLParser
5
+ from solver.engine import GeometryEngine
6
+
7
+ logging.basicConfig(level=logging.DEBUG)
8
+
9
@pytest.mark.asyncio
async def test_section_internal():
    """SECTION with k=0.6667 must place E at that fraction of segment AC."""
    import math

    print("\n--- Test: Section Point (Internal AE=2/3 AC) ---")
    dsl = """
    POINT(A)
    POINT(B)
    POINT(C)
    LENGTH(AB, 6)
    LENGTH(BC, 6)
    ANGLE(B, 90)
    SECTION(E, A, C, 0.6667)
    """
    parser = DSLParser()
    engine = GeometryEngine()

    pts, constraints = parser.parse(dsl)
    result = engine.solve(pts, constraints)

    # A failed solve used to print a message and let the test pass.
    assert result, "Solve failed"

    coords = result['coordinates']
    print(f" A: {coords['A']}")
    print(f" C: {coords['C']}")
    print(f" E: {coords['E']}")

    def dist(p1, p2):
        # Use every component so a non-zero z is not silently ignored.
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(p1, p2)))

    # Verify AE = 0.6667 * AC
    d_ac = dist(coords['A'], coords['C'])
    d_ae = dist(coords['A'], coords['E'])
    assert d_ac > 0, "A and C coincide"
    ratio = d_ae / d_ac
    print(f" Calculated Ratio AE/AC: {ratio:.4f} (Expected: 0.6667)")
    assert abs(ratio - 0.6667) < 1e-4
44
+
45
@pytest.mark.asyncio
async def test_section_external():
    """SECTION with k=2.0 must place E beyond C so that AE = 2 * AC."""
    import math

    print("\n--- Test: Section Point (External AE=2*AC) ---")
    dsl = """
    POINT(A)
    POINT(C)
    LENGTH(AC, 5)
    SECTION(E, A, C, 2.0)
    """
    parser = DSLParser()
    engine = GeometryEngine()

    pts, constraints = parser.parse(dsl)
    result = engine.solve(pts, constraints)

    # A failed solve used to print a message and let the test pass.
    assert result, "Solve failed"

    coords = result['coordinates']
    print(f" A: {coords['A']}")
    print(f" C: {coords['C']}")
    print(f" E: {coords['E']}")

    def dist(p1, p2):
        # Use every component so a non-zero z is not silently ignored.
        return math.sqrt(sum((a - b) ** 2 for a, b in zip(p1, p2)))

    d_ac = dist(coords['A'], coords['C'])
    d_ae = dist(coords['A'], coords['E'])
    assert d_ac > 0, "A and C coincide"
    print(f" AE: {d_ae}, AC: {d_ac}, Ratio: {d_ae/d_ac}")
    assert abs(d_ae/d_ac - 2.0) < 1e-4
74
+
75
@pytest.mark.asyncio
async def test_line_ray_metadata():
    """LINE and RAY declarations must surface in the result's 'lines' and 'rays'."""
    print("\n--- Test: Line and Ray Metadata ---")
    dsl = """
    POINT(A)
    POINT(B)
    LINE(A, B)
    RAY(A, B)
    """
    parser = DSLParser()
    engine = GeometryEngine()

    pts, constraints = parser.parse(dsl)
    result = engine.solve(pts, constraints)

    # A failed solve used to print a message and let the test pass.
    assert result, "Solve failed"

    print(f" Lines: {result.get('lines')}")
    print(f" Rays: {result.get('rays')}")
    assert ['A', 'B'] in result.get('lines', [])
    assert ['A', 'B'] in result.get('rays', [])
    print(" ✅ Metadata present")
98
+
99
+ if __name__ == "__main__":
100
+ asyncio.run(test_section_internal())
101
+ asyncio.run(test_section_external())
102
+ asyncio.run(test_line_ray_metadata())
tests/test_api_full_suite.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import httpx
3
+ import time
4
+ import asyncio
5
+ import pytest
6
+ import logging
7
+ import json
8
+
9
+ # Configuration
10
+ BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000")
11
+ USER_ID = os.getenv("TEST_USER_ID")
12
+ SESSION_ID = os.getenv("TEST_SESSION_ID")
13
+
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
def _query(qid, text, expect_pts, expect_phases, **extra):
    """Build one query record: id, prompt text, expected point ids, minimum phase count."""
    record = {
        "id": qid,
        "text": text,
        "expect_pts": list(expect_pts),
        "expect_phases": expect_phases,
    }
    record.update(extra)
    return record


# Natural-language geometry prompts sent to the solve endpoint, one record each.
QUERIES = [
    _query("Q1", "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10", "ABCD", 1),
    _query("Q2", "Tam giác ABC có AB=6, BC=8, AC=10", "ABC", 1),
    _query(
        "Q3",
        "Cho hình chữ nhật ABCD có AB bằng 10 và AD bằng 20. Vẽ điểm M là trung điểm của AB và N là trung điểm của AD.",
        "ABCDMN",
        2,
    ),
    _query("Q4", "Cho hình thang ABCD vuông tại A và D. AB=4, CD=8, AD=5.", "ABCD", 1),
    _query("Q5", "Cho hình vuông ABCD có cạnh bằng 6.", "ABCD", 1),
    _query("Q6", "Cho tam giác ABC vuông tại A. AB=3, AC=4. Vẽ đường cao AH.", "ABCH", 2),
    _query("Q7", "Cho hình thoi ABCD có cạnh bằng 5 và góc A bằng 60 độ.", "ABCD", 1),
    _query("Q8", "Cho đường tròn tâm O bán kính bằng 7.", "O", 1),
    _query(
        "Q9",
        "Cho hình bình hành ABCD có AB=8, AD=6. Gọi E là trung điểm của CD. Vẽ đoạn thẳng AE.",
        "ABCDE",
        2,
    ),
    _query("Q10-Step1", "Cho hình chữ nhật ABCD có AB=10, AD=5.", "ABCD", 1),
    _query(
        "Q11-Video",
        "Cho tam giác ABC đều cạnh 5. Vẽ đường tròn ngoại tiếp tam giác.",
        "ABC",
        2,
        request_video=True,
    ),
    _query(
        "Q12-3D",
        "Cho hình chóp S.ABCD có đáy ABCD là hình vuông cạnh 10, đường cao SO=15 với O là tâm đáy.",
        "SABCDO",
        2,
    ),
]

# Second turn of the multi-turn Q10 flow; runs in the session created for Q10-Step1.
Q10_FOLLOW_UP = _query("Q10-Step2", "Vẽ thêm đường chéo AC.", "ABCD", 2)

# Result records accumulated over the run.
test_stats = []
101
+
102
async def run_single_api_query(client, q, headers, *, max_attempts=45, poll_interval=4):
    """Submit one query to the solve API, poll the job, and validate the result.

    Args:
        client: Async HTTP client exposing ``post``/``get`` coroutines whose
            responses provide ``status_code``, ``text`` and ``json()``.
        q: Query record with ``id``, ``text``, ``expect_pts`` and
            ``expect_phases``; optional ``request_video``, ``isolate``
            (default True: create a fresh session) and ``session_id``.
        headers: Headers sent with every request.
        max_attempts: Maximum number of status polls (default 45).
        poll_interval: Seconds to wait before each poll (default 4).

    Returns:
        A dict with ``id``, ``query`` and ``success``. Failures carry
        ``error`` (plus ``elapsed``/``result`` for validation failures);
        successes carry ``elapsed``, ``job_id`` and ``result``. Exceptions are
        reported as a failure record rather than raised.
    """
    print(f"\n🚀 [RUNNING] {q['id']}: {q['text']}")
    start_time = time.time()

    def failure(error, **extra):
        # Single place that shapes every failure record.
        return {"id": q["id"], "query": q["text"], "success": False, "error": error, **extra}

    # 1. Submit Request
    payload = {
        "text": q["text"],
        "request_video": q.get("request_video", False)
    }

    try:
        if q.get("isolate", True):
            # Create a fresh session for isolation
            session_resp = await client.post("/api/v1/sessions", headers=headers)
            if session_resp.status_code != 200:
                return failure(f"Session creation failed: {session_resp.text}")
            session_id = session_resp.json()["id"]
        else:
            session_id = q.get("session_id", SESSION_ID)

        res = await client.post(f"/api/v1/sessions/{session_id}/solve", json=payload, headers=headers)
        if res.status_code != 200:
            print(f" ❌ FAILED: Status {res.status_code} - {res.text}")
            return failure(f"HTTP {res.status_code}: {res.text}")

        job_id = res.json()["job_id"]
        print(f" ✅ Job Created: {job_id}")

        # 2. Polling result
        result_data = None
        for i in range(max_attempts):
            await asyncio.sleep(poll_interval)
            res = await client.get(f"/api/v1/solve/{job_id}", headers=headers)
            data = res.json()
            status = data.get("status")
            print(f" - Polling ({i+1}): {status}")

            if status == "success":
                # A null result is treated as empty so validation reports
                # what is missing instead of raising AttributeError.
                result_data = data.get("result") or {}
                break
            if status == "error":
                # `result` may be null on failed jobs; do not let that mask
                # the failure behind an AttributeError.
                error = (data.get("result") or {}).get("error") or "Unknown error"
                print(f" ❌ ERROR: {error}")
                return failure(error)
        else:
            # Loop exhausted without reaching a terminal status.
            print(" ❌ TIMEOUT")
            return failure("Timeout")

        # 3. Strict Validation
        elapsed = time.time() - start_time
        errors = []

        # Validation: Coordinates
        coords = result_data.get("coordinates", {})
        for pt in q["expect_pts"]:
            if pt not in coords:
                errors.append(f"Missing point {pt}")

        # Validation: Non-zero coords (generic check)
        # Only fail if there are MULTIPLE points and all are at origin.
        # A single point (like a circle center) at origin is perfectly valid.
        if coords and len(coords) > 1 and all(v == [0, 0, 0] for v in coords.values()):
            errors.append("All points are at [0,0,0]")

        # Validation: Drawing Phases (at least the expected number)
        phases = result_data.get("drawing_phases", [])
        if len(phases) < q["expect_phases"]:
            errors.append(f"Expected {q['expect_phases']} phases, got {len(phases)}")

        # Validation: Video URL if requested
        if q.get("request_video") and not result_data.get("video_url"):
            # We allow video fail if it's environment issue, but log it
            print(" ⚠️ Video requested but no URL found (Expected in some test envs)")

        if errors:
            print(f" ❌ VALIDATION FAILED: {', '.join(errors)}")
            return failure("; ".join(errors), elapsed=elapsed, result=result_data)

        print(f" ✅ PASS ({elapsed:.2f}s)")
        return {"id": q['id'], "query": q["text"], "success": True, "elapsed": elapsed, "job_id": job_id, "result": result_data}

    except Exception as e:
        print(f" ❌ EXCEPTION: {str(e)}")
        return failure(str(e))
188
+
189
@pytest.mark.asyncio
async def test_full_api_suite():
    """Run every query against the live API and dump the outcomes to JSON.

    Standard queries each run in their own session. Q10 is a two-turn flow
    that shares one session so the follow-up can build on the first turn.
    Outcomes are appended to ``test_stats`` and written to
    ``temp_suite_results.json``.

    NOTE(review): individual query failures are only recorded, never asserted,
    so this test passes even when queries fail — confirm that whatever
    consumes the JSON is the intended gate.
    """
    if not USER_ID or not SESSION_ID:
        pytest.fail("TEST_USER_ID and TEST_SESSION_ID must be set")

    headers = {"Authorization": f"Test {USER_ID}"}

    async with httpx.AsyncClient(base_url=BASE_URL, timeout=60.0) as client:
        # Run standard queries (isolated by default)
        for q in QUERIES:
            if q["id"] == "Q10-Step1":
                continue
            test_stats.append(await run_single_api_query(client, q, headers))

        # Run Multi-turn Q10
        print("\n--- Testing Multi-turn API Flow (Q10) ---")
        q10_1 = next(q for q in QUERIES if q["id"] == "Q10-Step1")

        # Create a shared session for Q10
        shared_session_resp = await client.post("/api/v1/sessions", headers=headers)
        if shared_session_resp.status_code != 200:
            # Record the failure instead of crashing and losing earlier results.
            test_stats.append({
                "id": q10_1["id"],
                "query": q10_1["text"],
                "success": False,
                "error": f"Session creation failed: {shared_session_resp.text}",
            })
        else:
            # Pass copies so the module-level query records stay untouched.
            shared = {"session_id": shared_session_resp.json()["id"], "isolate": False}

            res10_1 = await run_single_api_query(client, {**q10_1, **shared}, headers)
            test_stats.append(res10_1)

            if res10_1["success"]:
                res10_2 = await run_single_api_query(client, {**Q10_FOLLOW_UP, **shared}, headers)

                # Additional check for Q10-Step2: check if DSL contains combined logic
                if res10_2["success"]:
                    dsl = res10_2["result"].get("geometry_dsl") or ""
                    if "POLYGON_ORDER" not in dsl or "SEGMENT" not in dsl:
                        res10_2["success"] = False
                        res10_2["error"] = "DSL did not merge history correctly"

                test_stats.append(res10_2)

    # Save Results to JSON for the runner script to generate Markdown
    with open("temp_suite_results.json", "w") as f:
        json.dump(test_stats, f)
234
+
235
+ if __name__ == "__main__":
236
+ import asyncio
237
+ asyncio.run(test_full_api_suite())
tests/test_api_metadata_real.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import asyncio
3
+ import uuid
4
+ import time
5
+ from app.routers.solve import process_session_job
6
+ from app.models.schemas import SolveRequest
7
+ from app.supabase_client import get_supabase
8
+
9
@pytest.mark.asyncio
async def test_metadata_persistence():
    """Run one solve job directly and check the stored assistant-message metadata.

    Calls ``process_session_job`` for fixed session/user ids, waits 3 seconds,
    then reads the newest assistant message of that session and requires every
    expected metadata field to be present.
    """
    session_id = "81f87517-88f2-40bd-96a9-7b34f1d14b6a"
    user_id = "8cd3adb0-7964-4575-949c-d0cadcd8b679"
    job_id = str(uuid.uuid4())

    print(f"🚀 Starting sub-pipeline test for job {job_id}...")

    request = SolveRequest(
        text="Cho hình chữ nhật ABCD có AB=10, AD=20. Vẽ đường thẳng d đi qua A và B.",
        request_video=False
    )

    # Trigger the process_session_job directly
    await process_session_job(job_id, session_id, request, user_id)

    print("⏳ Waiting for database sync (3s)...")
    await asyncio.sleep(3)

    # Verify the results in Supabase
    supabase = get_supabase()
    res = supabase.table("messages") \
        .select("metadata, created_at") \
        .eq("session_id", session_id) \
        .eq("role", "assistant") \
        .order("created_at", desc=True) \
        .limit(1) \
        .execute()

    # Previously these failures were only printed, so the test always passed.
    assert res.data, "No assistant message found in database."

    # Treat a null metadata column as empty so the check reports missing fields.
    metadata = res.data[0].get("metadata") or {}
    required_fields = ["job_id", "coordinates", "polygon_order", "drawing_phases", "circles", "lines", "rays"]
    missing = [f for f in required_fields if f not in metadata]
    assert not missing, f"Missing fields in metadata: {missing}"

    print("✅ SUCCESS: All metadata fields (including lines/rays) persisted correctly.")
    print(f" job_id: {metadata.get('job_id')}")
    print(f" polygon_order: {metadata.get('polygon_order')}")
    print(f" lines: {metadata.get('lines')}")
    print(f" phases: {len(metadata.get('drawing_phases', []))}")
54
+
55
+ if __name__ == "__main__":
56
+ asyncio.run(test_metadata_persistence())
tests/test_api_real_e2e.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import httpx
3
+ import time
4
+ import pytest
5
+ import logging
6
+
7
+ # Configuration from environment
8
+ BASE_URL = os.getenv("TEST_BASE_URL", "http://localhost:8000")
9
+ USER_ID = os.getenv("TEST_USER_ID")
10
+ SESSION_ID = os.getenv("TEST_SESSION_ID")
11
+
12
+ logging.basicConfig(level=logging.INFO)
13
+ logger = logging.getLogger(__name__)
14
+
15
@pytest.mark.asyncio
async def test_api_e2e_flow():
    """End-to-end check: health endpoint, solve submission, then job polling.

    Polls the job up to 15 times, 2 seconds apart, and fails on an ``error``
    status or when no terminal status is reached.
    """
    # Local import: the module only imports asyncio under its __main__ guard.
    import asyncio

    if not USER_ID or not SESSION_ID:
        pytest.fail("TEST_USER_ID and TEST_SESSION_ID must be set")

    auth_headers = {"Authorization": f"Test {USER_ID}"}

    async with httpx.AsyncClient(base_url=BASE_URL, timeout=30.0) as client:
        # 1. Health check
        print("\n[1/3] Checking API Health...")
        res = await client.get("/")
        assert res.status_code == 200
        assert "running" in res.json()["message"].lower()
        print(" ✅ Health check passed")

        # 2. Submit Solve Request
        print(f"\n[2/3] Submitting solve request for session {SESSION_ID}...")
        payload = {
            "text": "Cho hình chữ nhật ABCD có AB=5, AD=10. Tính diện tích.",
            "request_video": False
        }
        res = await client.post(f"/api/v1/sessions/{SESSION_ID}/solve", json=payload, headers=auth_headers)

        if res.status_code != 200:
            print(f" ❌ FAILED: {res.text}")
        assert res.status_code == 200

        data = res.json()
        job_id = data["job_id"]
        assert job_id is not None
        print(f" ✅ Request accepted. Job ID: {job_id}")

        # 3. Polling Job Status
        print("\n[3/3] Polling job status...")
        max_attempts = 15
        for i in range(max_attempts):
            # Non-blocking wait: time.sleep() would stall the event loop.
            await asyncio.sleep(2)
            # Send the same auth header the full API suite uses on this endpoint.
            res = await client.get(f"/api/v1/solve/{job_id}", headers=auth_headers)
            assert res.status_code == 200
            job_data = res.json()
            status = job_data["status"]
            print(f" Attempt {i+1}: Status = {status}")

            if status == "success":
                print(" ✅ SUCCESS: API pipeline completed successfully.")
                result = job_data.get("result") or {}
                assert "coordinates" in result
                assert "geometry_dsl" in result
                return

            if status == "error":
                error_msg = (job_data.get("result") or {}).get("error", "Unknown error")
                pytest.fail(f"Job failed with error: {error_msg}")
        else:
            pytest.fail("Timeout waiting for job completion")
71
+
72
+ if __name__ == "__main__":
73
+ # This allows running the script directly if needed
74
+ import asyncio
75
+ asyncio.run(test_api_e2e_flow())
tests/test_direct_task.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ from dotenv import load_dotenv
5
+
6
+ # Ensure we can import from backend
7
+ sys.path.append(os.getcwd())
8
+
9
+ from app.supabase_client import get_supabase
10
+ from worker.tasks import render_geometry_video
11
+
12
def test_celery_task_directly():
    """Call ``render_geometry_video`` synchronously for a mock square.

    Inserts a placeholder row into the ``jobs`` table first, then runs the
    task function directly (not via ``.delay``) and requires a video URL.
    """
    import time

    load_dotenv()

    # Mock data for a square
    data = {
        "session_id": "88888888-8888-8888-8888-888888888888",  # Fake uuid
        "coordinates": {
            "A": [0, 0],
            "B": [5, 0],
            "C": [5, 5],
            "D": [0, 5]
        },
        "polygon_order": ["A", "B", "C", "D"],
        "drawing_phases": [
            {
                "phase": 1,
                "label": "Base",
                "points": ["A", "B", "C", "D"],
                "segments": [["A", "B"], ["B", "C"], ["C", "D"], ["D", "A"]]
            }
        ],
        "semantic_analysis": "Test squere video rendering."
    }

    # The earlier `os.time`-based id was dead code (os has no `time`
    # attribute) and was overwritten immediately; only this id was ever used.
    job_id = f"manual-test-{int(time.time())}"

    print(f"🚀 Running render_geometry_video directly for job {job_id}...")

    try:
        # We WANT to test the real task logic, so the DB is hit for real.
        # The task would fail on DB update if job_id doesn't exist in the
        # 'jobs' table, so create a dummy job first.
        supabase = get_supabase()
        supabase.table("jobs").insert({
            "id": job_id,
            "user_id": None,
            "status": "processing",
            "type": "solve"
        }).execute()

        # Run the task function directly (not via .delay)
        video_url = render_geometry_video(job_id, data)
    except Exception as e:
        # Report, then re-raise: swallowing here made the test unfailable.
        print(f"❌ Error during manual task execution: {e}")
        raise

    if video_url:
        print(f"✅ SUCCESS! Video URL: {video_url}")
    else:
        print("❌ FAIL: No video URL returned.")
    assert video_url, "No video URL returned."
68
+
69
+ if __name__ == "__main__":
70
+ test_celery_task_directly()
tests/test_full_pipeline.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ import json
4
+ import os
5
+ import math
6
+ import time
7
+ from dotenv import load_dotenv
8
+
9
+ from app.logging_setup import setup_application_logging
10
+ setup_application_logging()
11
+ logging.getLogger("agents").setLevel(logging.DEBUG)
12
+ logging.getLogger("solver").setLevel(logging.DEBUG)
13
+ logging.getLogger("app").setLevel(logging.DEBUG)
14
+
15
+ from agents.orchestrator import Orchestrator
16
+ from app.supabase_client import get_supabase
17
+
18
+ QUERIES = [
19
+ {
20
+ "id": "Q1",
21
+ "text": "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10",
22
+ "expect_pts": ["A", "B", "C", "D"],
23
+ "expect_phases": 1,
24
+ },
25
+ {
26
+ "id": "Q2",
27
+ "text": "Tam giác ABC có AB=6, BC=8, AC=10",
28
+ "expect_pts": ["A", "B", "C"],
29
+ "expect_phases": 1,
30
+ },
31
+ {
32
+ "id": "Q3",
33
+ "text": "Cho hình chữ nhật ABCD có AB bằng 10 và AD bằng 20. Vẽ điểm M là trung điểm của AB và N là trung điểm của AD.",
34
+ "expect_pts": ["A", "B", "C", "D", "M", "N"],
35
+ "expect_phases": 2,
36
+ },
37
+ {
38
+ "id": "Q4",
39
+ "text": "Cho hình thang ABCD vuông tại A và D. AB=4, CD=8, AD=5.",
40
+ "expect_pts": ["A", "B", "C", "D"],
41
+ "expect_phases": 1,
42
+ },
43
+ {
44
+ "id": "Q5",
45
+ "text": "Cho hình vuông ABCD có cạnh bằng 6.",
46
+ "expect_pts": ["A", "B", "C", "D"],
47
+ "expect_phases": 1,
48
+ },
49
+ {
50
+ "id": "Q6",
51
+ "text": "Cho tam giác ABC vuông tại A. AB=3, AC=4. Vẽ đường cao AH.",
52
+ "expect_pts": ["A", "B", "C", "H"],
53
+ "expect_phases": 2,
54
+ },
55
+ {
56
+ "id": "Q7",
57
+ "text": "Cho hình thoi ABCD có cạnh bằng 5 và góc A bằng 60 độ.",
58
+ "expect_pts": ["A", "B", "C", "D"],
59
+ "expect_phases": 1,
60
+ },
61
+ {
62
+ "id": "Q8",
63
+ "text": "Cho đường tròn tâm O bán kính bằng 7.",
64
+ "expect_pts": ["O"],
65
+ "expect_phases": 1,
66
+ },
67
+ {
68
+ "id": "Q9",
69
+ "text": "Cho hình bình hành ABCD có AB=8, AD=6. Gọi E là trung điểm của CD. Vẽ đoạn thẳng AE.",
70
+ "expect_pts": ["A", "B", "C", "D", "E"],
71
+ "expect_phases": 2,
72
+ },
73
+ {
74
+ "id": "Q10-Step1",
75
+ "text": "Cho hình chữ nhật ABCD có AB=10, AD=5.",
76
+ "expect_pts": ["A", "B", "C", "D"],
77
+ "expect_phases": 1,
78
+ },
79
+ {
80
+ "id": "Q11-Video",
81
+ "text": "Cho tam giác ABC đều cạnh 5. Vẽ đường tròn ngoại tiếp tam giác.",
82
+ "expect_pts": ["A", "B", "C"],
83
+ "expect_phases": 2,
84
+ "request_video": True
85
+ }
86
+ ]
87
+
88
+ # Q10-Step2 is a follow-up to Q10-Step1
89
+ Q10_FOLLOW_UP = {
90
+ "id": "Q10-Step2",
91
+ "text": "Vẽ thêm đường chéo AC.",
92
+ "expect_pts": ["A", "B", "C", "D"],
93
+ "expect_phases": 2, # Main polygon + diagonal segment
94
+ }
95
+
96
def dist(p1, p2):
    """Return the Euclidean distance between ``p1`` and ``p2``.

    Only the first two components (index 0 and 1) of each point are used.
    """
    delta_x = p1[0] - p2[0]
    delta_y = p1[1] - p2[1]
    return math.sqrt(delta_x ** 2 + delta_y ** 2)
98
+
99
async def run_query(orchestrator, q, history=None):
    """Run one query through the orchestrator and validate the result.

    Args:
        orchestrator: Object with an awaitable ``run(text=..., job_id=...,
            request_video=..., history=...)`` returning a result dict.
        q: Query spec with ``id``, ``text``, ``expect_pts`` and
            ``expect_phases`` keys, plus an optional ``request_video`` flag.
        history: Optional list of prior messages forwarded to the orchestrator.

    Returns:
        The result dict when every check passed; ``None`` when the pipeline
        reported an error, any check failed, video rendering failed or timed
        out, or an exception was raised.
    """
    wants_video = q.get("request_video", False)

    print(f"\n{'='*60}")
    print(f"[{q['id']}] {q['text']}")
    if history:
        print(f" (With history context of {len(history)} messages)")
    if wants_video:
        print(" 🎥 VIDEO RENDERING REQUESTED")
    print('='*60)
    try:
        result = await orchestrator.run(
            text=q["text"],
            job_id=f"test-{q['id']}-{int(time.time())}",
            request_video=wants_video,
            history=history,
        )

        if "error" in result:
            print(f" ❌ PIPELINE ERROR: {result['error']}")
            return None

        # Every failed check flips this flag so the caller's PASS/FAIL summary
        # agrees with the FAIL lines printed below.
        passed = True

        # Check 1: semantic_analysis != original query
        analysis = result.get("semantic_analysis") or ""
        if analysis.strip() == q["text"].strip():
            print(" ❌ FAIL: semantic_analysis is identical to input query")
            passed = False
        else:
            print(f" ✅ semantic_analysis: {analysis[:100]}...")

        # Check 2: all expected points are in coordinates
        coords = result.get("coordinates") or {}
        missing = [pt for pt in q["expect_pts"] if pt not in coords]
        if missing:
            print(f" ❌ FAIL: Missing points in coordinates: {missing}")
            passed = False
        else:
            print(f" ✅ All expected points present: {list(coords.keys())}")

        # Check 4: drawing_phases
        phases = result.get("drawing_phases") or []
        if len(phases) >= q["expect_phases"]:
            print(f" ✅ drawing_phases: {len(phases)} phase(s)")
        else:
            print(f" ❌ FAIL: expected {q['expect_phases']} drawing phase(s), got {len(phases)}")
            passed = False

        # Check 5: Video Polling (if requested)
        if wants_video:
            job_id = result.get("job_id")
            if not job_id:
                print(" ❌ FAIL: Video requested but no job_id returned")
                return None

            print(f" ⏳ Waiting for Manim video (job_id: {job_id})...")
            supabase = get_supabase()
            max_retries = 24  # 2 minutes (24 * 5s)
            success = False
            for _ in range(max_retries):
                job_res = supabase.table("jobs").select("*").eq("id", job_id).execute()
                if job_res.data:
                    job_data = job_res.data[0]
                    status = job_data.get("status")
                    if status == "success":
                        # "result" may be present but None, so a .get default
                        # alone would not protect the chained lookup.
                        video_url = (job_data.get("result") or {}).get("video_url")
                        if video_url:
                            print(f" ✅ VIDEO READY: {video_url}")
                            result["video_url"] = video_url
                            success = True
                            break
                        print(" ❌ FAIL: Job success but no video_url in result")
                        return None
                    if status == "failed":
                        print(" ❌ FAIL: Manim worker job failed")
                        return None
                await asyncio.sleep(5)

            if not success:
                print(" ❌ FAIL: Timeout waiting for video rendering")
                return None

        dsl = result.get('geometry_dsl') or ''
        print(f" DSL ({len(dsl.splitlines())} lines):\n{dsl}")

        # Specific check for Q10-Step2: must contain BOTH ABCD and AC segment
        if q["id"] == "Q10-Step2":
            if "POLYGON_ORDER(A, B, C, D)" in dsl and "SEGMENT(A, C)" in dsl:
                print(" ✅ Multi-turn Success: DSL merged correctly.")
            else:
                print(" ❌ Multi-turn Fail: DSL missing component.")
                passed = False

        return result if passed else None

    except Exception as e:
        import traceback
        print(f" ❌ EXCEPTION: {type(e).__name__}: {e}")
        traceback.print_exc()
        return None
193
+
194
async def main():
    """Run all single-turn queries, then the two-step Q10 flow, and print a summary.

    A query counts as PASS when ``run_query`` returns a non-``None`` result.
    """
    load_dotenv()
    orchestrator = Orchestrator()

    results = []

    # Q10-Step1 is skipped here because it is run below as the first turn of
    # the multi-turn flow.
    for q in QUERIES:
        if q["id"] == "Q10-Step1":
            continue
        res = await run_query(orchestrator, q)
        results.append((q["id"], res is not None))

    # Run Q10 Flow (Multi-turn)
    print("\n--- Starting Multi-turn Flow (Q10) ---")
    q10_1 = next(q for q in QUERIES if q["id"] == "Q10-Step1")
    res10_1 = await run_query(orchestrator, q10_1)
    results.append((q10_1["id"], res10_1 is not None))

    if res10_1:
        # Construct message history to pass to step 2
        history = [
            {"role": "user", "content": q10_1["text"]},
            {
                "role": "assistant",
                "content": res10_1.get("semantic_analysis", ""),
                "metadata": {
                    "geometry_dsl": res10_1.get("geometry_dsl", ""),
                    "coordinates": res10_1.get("coordinates", {}),
                },
            },
        ]
        res10_2 = await run_query(orchestrator, Q10_FOLLOW_UP, history=history)
        results.append((Q10_FOLLOW_UP["id"], res10_2 is not None))
    else:
        # Without a step-1 result there is no history to build on; record the
        # follow-up as failed so it does not vanish from the summary.
        results.append((Q10_FOLLOW_UP["id"], False))

    print(f"\n{'='*60}")
    print("SUMMARY:")
    for qid, ok in results:
        print(f" [{qid}] {'✅ PASS' if ok else '❌ FAIL'}")
235
+
236
+ if __name__ == "__main__":
237
+ asyncio.run(main())
tests/test_openrouter.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import httpx
3
+ import json
4
+ import time
5
+ from dotenv import load_dotenv
6
+
7
+ # Load environment variables from backend/.env
8
+ load_dotenv(dotenv_path="./backend/.env")
9
+
10
+ MODELS = [
11
+ "nvidia/nemotron-3-super-120b-a12b:free",
12
+ "meta-llama/llama-3.3-70b-instruct:free",
13
+ "openai/gpt-oss-120b:free",
14
+ "z-ai/glm-4.5-air:free",
15
+ "minimax/minimax-m2.5:free",
16
+ "google/gemma-4-26b-a4b-it:free",
17
+ "google/gemma-4-31b-it:free",
18
+ ]
19
+
20
+ PROMPT = "Cho hình chữ nhật ABCD có AB bằng 5 và AD bằng 10. Gọi E là điểm nằm trong đoạn CD sao cho CE = 2ED. Vẽ đoạn thẳng AE. Vẽ thêm P là điểm nằm trên đường thẳng BC sao cho BP = 2PC, tính chu vi tam giác PEA"
21
+
22
def test_models():
    """Send PROMPT to every model in MODELS via OpenRouter and print a report.

    Reads the API key from the ``OPENROUTER_API_KEY_1`` environment variable
    and returns early (after printing an error) when it is not set. Each
    model is called once with a 60-second client timeout; failures are
    recorded and reported rather than raised.
    """
    api_key = os.getenv("OPENROUTER_API_KEY_1")
    base_url = "https://openrouter.ai/api/v1/chat/completions"

    if not api_key:
        # Name the variable that is actually read above.
        print("❌ Lỗi: Không tìm thấy OPENROUTER_API_KEY_1 trong file .env")
        return

    print("🚀 Bắt đầu benchmark các model OpenRouter...")
    print(f"📝 Prompt: {PROMPT}\n")

    # Headers do not depend on the model, so build them once.
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://mathsolver.io",
        "X-Title": "MathSolver Benchmark Tool",
    }

    results = []

    # One client for the whole run instead of a new one per model.
    with httpx.Client(timeout=60.0) as client:
        for model in MODELS:
            print(f"📡 Đang gọi model: {model}...", end="", flush=True)

            payload = {
                "model": model,
                "messages": [{"role": "user", "content": PROMPT}],
            }

            start_time = time.time()
            try:
                response = client.post(base_url, headers=headers, json=payload)
                response.raise_for_status()

                duration = time.time() - start_time
                data = response.json()
                answer = data['choices'][0]['message']['content']

                results.append({
                    "model": model,
                    "duration": duration,
                    "answer": answer,
                    "status": "success",
                })
                print(f" ✅ DONE ({duration:.2f}s)")

            except Exception as e:
                # Best-effort benchmark: record the failure and keep going.
                duration = time.time() - start_time
                print(f" ❌ FAILED ({duration:.2f}s)")
                results.append({
                    "model": model,
                    "duration": duration,
                    "error": str(e),
                    "status": "error",
                })

    print("\n" + "="*80)
    print("📊 BÁO CÁO CHI TIẾT BENCHMARK")
    print("="*80)

    for res in results:
        print(f"\n🔹 MODEL: {res['model']}")
        print(f"⏱ Thời gian: {res['duration']:.2f} giây")
        if res['status'] == "success":
            print(f"🤖 Phản hồi:\n{res['answer']}")
        else:
            print(f"❌ Lỗi: {res.get('error')}")
        print("-" * 40)
90
+
91
+ if __name__ == "__main__":
92
+ test_models()
tests/test_real_llm.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import asyncio
3
+ import os
4
+ import logging
5
+ from dotenv import load_dotenv
6
+ from app.llm_client import get_llm_client
7
+
8
+ # Setup logging to see the fallback process
9
+ logging.basicConfig(level=logging.INFO)
10
+ load_dotenv()
11
+
12
@pytest.mark.asyncio
async def test_real_llm():
    """Send one geometry prompt through the LLM client and print the reply.

    Raises:
        Exception: Whatever the client call raised, re-raised after being
            printed so the test fails instead of passing silently.
    """
    client = get_llm_client()

    print("\n--- Testing LLM Call (Complex Prompt) ---")
    try:
        content = await client.chat_completions_create(
            messages=[
                {"role": "system", "content": "You are a Geometry Expert. Formulate a step-by-step reasoning for calculating the distance between two points M and N where M is the midpoint of AB (len=10) and N is the midpoint of AD (len=20) in a rectangle ABCD. Use LaTeX for formulas."},
                {"role": "user", "content": "Solve it carefully."}
            ]
        )
    except Exception as e:
        print(f"\n--- Test Failed: {type(e).__name__}: {e} ---")
        # Propagate so the test runner records the failure.
        raise
    print(f"\nResponse: {content}")
    print("\n--- Test Completed Successfully ---")
28
+
29
+ if __name__ == "__main__":
30
+ asyncio.run(test_real_llm())
tests/test_solver.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
4
+
5
+ from solver.engine import GeometryEngine
6
+ from solver.models import Point, Constraint
7
+
8
def test_triangle_abc():
    """Solve triangle ABC (AB=5, AC=7, angle A=60) and check the side lengths.

    Fails when the engine returns a falsy result or when the solved AB/AC
    distances differ from the requested lengths by 0.01 or more.
    """
    engine = GeometryEngine()

    # Triangle ABC: AB=5, AC=7, angle A=60
    points = [
        Point(id="A"),
        Point(id="B"),
        Point(id="C"),
    ]

    constraints = [
        Constraint(type="length", targets=["A", "B"], value=5.0),
        Constraint(type="length", targets=["A", "C"], value=7.0),
        Constraint(type="angle", targets=["A"], value=60.0),  # Angle at A
    ]

    print("Solving for Triangle ABC (AB=5, AC=7, angle A=60)...")
    results = engine.solve(points, constraints)

    assert results, "Solver failed."

    coords = results["coordinates"]
    print("Success! Coordinates:")
    for pid, c in coords.items():
        print(f"Point {pid}: {c}")

    # Verify distance AB
    dist_ab = ((coords["B"][0] - coords["A"][0])**2 + (coords["B"][1] - coords["A"][1])**2)**0.5
    print(f"Verified AB distance: {dist_ab:.2f}")
    # Tolerance matches the two-decimal precision of the printed value.
    assert abs(dist_ab - 5.0) < 1e-2, f"AB should be 5.0, got {dist_ab}"

    # Verify distance AC
    dist_ac = ((coords["C"][0] - coords["A"][0])**2 + (coords["C"][1] - coords["A"][1])**2)**0.5
    print(f"Verified AC distance: {dist_ac:.2f}")
    assert abs(dist_ac - 7.0) < 1e-2, f"AC should be 7.0, got {dist_ac}"
42
+
43
+ if __name__ == "__main__":
44
+ test_triangle_abc()