hajimammad committed on
Commit 94fa7e1 · verified · 1 Parent(s): 0c6abe8

Update app.py

Files changed (1)
  1. app.py +272 -1502
app.py CHANGED
@@ -1,16 +1,13 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Mahoon Legal AI — Enhanced Version
4
  Features:
5
- - Improved memory management and resource cleanup
6
- - Caching system for models and embeddings
7
- - Enhanced security and input validation
8
- - Better error handling and logging
9
- - Metrics and monitoring
10
- - Thread safety improvements
11
- - Configuration validation with Pydantic
12
- - Comprehensive testing support
13
- - Gradio UI with advanced features
14
  """
15
 
16
  from __future__ import annotations
@@ -25,8 +22,6 @@ from dataclasses import dataclass, field
25
  from pathlib import Path
26
  from typing import List, Dict, Optional, Tuple, Any, Union
27
  from datetime import datetime
28
- from functools import lru_cache
29
- import logging
30
 
31
  import torch
32
  from torch.utils.data import Dataset
@@ -41,8 +36,10 @@ from transformers import (
41
  TrainingArguments,
42
  EarlyStoppingCallback,
43
  DataCollatorForSeq2Seq,
44
- TrainerCallback
 
45
  )
 
46
 
47
  import chromadb
48
  from sentence_transformers import SentenceTransformer
@@ -51,1582 +48,355 @@ import gradio as gr
51
  warnings.filterwarnings("ignore")
52
 
53
  # Configure logging
54
- logging.basicConfig(
55
- level=logging.INFO,
56
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
57
- )
58
  logger = logging.getLogger(__name__)
59
 
 
60
  # ==========================
61
- # Enhanced Config with Validation
62
  # ==========================
63
  class ModelConfig(BaseModel):
64
- model_name: str = "persiannlp/parsi-t5-base"
65
- architecture: str = "seq2seq"
66
- max_input_length: int = Field(default=1024, ge=64, le=4096)
67
- max_target_length: int = Field(default=512, ge=32, le=2048)
68
- max_new_tokens: int = Field(default=512, ge=32, le=1024)
69
- temperature: float = Field(default=0.7, ge=0.0, le=2.0)
70
  top_p: float = Field(default=0.9, ge=0.1, le=1.0)
71
- num_beams: int = Field(default=4, ge=1, le=8)
72
  use_bf16: bool = True
73
 
74
- @validator('architecture')
75
- def validate_architecture(cls, v):
76
- if v not in ['seq2seq', 'causal']:
77
- raise ValueError('architecture must be seq2seq or causal')
78
- return v
79
-
80
- class Config:
81
- validate_assignment = True
82
 
83
  class SystemConfig(BaseModel):
84
  model: ModelConfig = Field(default_factory=ModelConfig)
 
85
  embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
86
- chroma_db_path: str = "./chroma.sqlite3"
87
- top_k_retrieval: int = Field(default=5, ge=1, le=20)
88
- similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
89
- cache_dir: str = "./cache"
90
- output_dir: str = "./mahoon_legal_model"
91
- seed: int = 42
92
- train_test_ratio: float = Field(default=0.1, ge=0.05, le=0.3)
93
- batch_size: int = Field(default=2, ge=1, le=16)
94
- grad_accum: int = Field(default=2, ge=1, le=8)
95
- epochs: int = Field(default=2, ge=1, le=10)
96
- lr: float = Field(default=3e-5, ge=1e-6, le=1e-3)
97
- max_file_size_mb: int = Field(default=10, ge=1, le=100)
98
- max_lines_per_file: int = Field(default=10000, ge=100, le=100000)
99
- request_timeout: int = Field(default=30, ge=5, le=300)
100
-
101
- class Config:
102
- validate_assignment = True
103
-
104
- # ==========================
105
- # Metrics and Monitoring
106
- # ==========================
107
- @dataclass
108
- class SystemMetrics:
109
- requests_count: int = 0
110
- avg_response_time: float = 0.0
111
- error_count: int = 0
112
- success_count: int = 0
113
- memory_usage_mb: float = 0.0
114
- last_updated: datetime = field(default_factory=datetime.now)
115
- active_models: List[str] = field(default_factory=list)
116
-
117
- class MetricsCollector:
118
- def __init__(self):
119
- self.metrics = SystemMetrics()
120
- self._lock = threading.Lock()
121
-
122
- def record_request(self, response_time: float, success: bool = True):
123
- with self._lock:
124
- self.metrics.requests_count += 1
125
- if success:
126
- self.metrics.success_count += 1
127
- else:
128
- self.metrics.error_count += 1
129
-
130
- # Update average response time
131
- total_requests = self.metrics.requests_count
132
- old_avg = self.metrics.avg_response_time
133
- self.metrics.avg_response_time = (old_avg * (total_requests - 1) + response_time) / total_requests
134
- self.metrics.last_updated = datetime.now()
135
-
136
- def update_memory_usage(self):
137
- if torch.cuda.is_available():
138
- memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
139
- self.metrics.memory_usage_mb = memory_mb
140
-
141
- def get_metrics(self) -> Dict[str, Any]:
142
- with self._lock:
143
- return {
144
- "requests_total": self.metrics.requests_count,
145
- "success_rate": self.metrics.success_count / max(self.metrics.requests_count, 1) * 100,
146
- "avg_response_time": round(self.metrics.avg_response_time, 2),
147
- "error_count": self.metrics.error_count,
148
- "memory_usage_mb": round(self.metrics.memory_usage_mb, 2),
149
- "active_models": self.metrics.active_models.copy(),
150
- "last_updated": self.metrics.last_updated.isoformat()
151
- }
152
-
153
- # Global metrics instance
154
- metrics = MetricsCollector()
155
-
156
- # ==========================
157
- # Enhanced Utilities
158
- # ==========================
159
- def set_seed_all(seed: int = 42):
160
- import random
161
- random.seed(seed)
162
- torch.manual_seed(seed)
163
- torch.cuda.manual_seed_all(seed)
164
- logger.info(f"Set random seed to {seed}")
165
-
166
- def validate_file_security(file_path: str, max_size_mb: int = 10, max_lines: int = 10000) -> Tuple[bool, str]:
167
- """Enhanced file validation with security checks"""
168
- try:
169
- path = Path(file_path)
170
-
171
- # Check if file exists and is readable
172
- if not path.exists() or not path.is_file():
173
- return False, "فایل وجود ندارد یا قابل خواندن نیست"
174
-
175
- # Check file extension
176
- if path.suffix.lower() != '.jsonl':
177
- return False, "فقط فایل‌های .jsonl پذیرفته می‌شوند"
178
-
179
- # Check file size
180
- size_mb = path.stat().st_size / (1024 * 1024)
181
- if size_mb > max_size_mb:
182
- return False, f"حجم فایل نباید از {max_size_mb} مگابایت بیشتر باشد"
183
-
184
- # Validate content structure
185
- line_count = 0
186
- with open(path, 'r', encoding='utf-8') as f:
187
- for line_num, line in enumerate(f, 1):
188
- line = line.strip()
189
- if not line:
190
- continue
191
-
192
- line_count += 1
193
- if line_count > max_lines:
194
- return False, f"فایل نباید بیش از {max_lines} خط داشته باشد"
195
-
196
- # Validate JSON structure
197
- try:
198
- data = json.loads(line)
199
- if not isinstance(data, dict):
200
- return False, f"خط {line_num}: فرمت JSON نامعتبر"
201
-
202
- if 'input' not in data or 'output' not in data:
203
- return False, f"خط {line_num}: کلیدهای 'input' و 'output' الزامی هستند"
204
-
205
- # Check content length
206
- if len(str(data['input'])) > 2048 or len(str(data['output'])) > 2048:
207
- return False, f"خط {line_num}: طول محتوا بیش از حد مجاز"
208
-
209
- except json.JSONDecodeError:
210
- return False, f"خط {line_num}: فرمت JSON نامعتبر"
211
-
212
- if line_count == 0:
213
- return False, "فایل خالی است"
214
-
215
- return True, f"فایل معتبر است ({line_count} خط)"
216
-
217
- except Exception as e:
218
- logger.error(f"File validation error: {e}")
219
- return False, f"خطا در بررسی فایل: {str(e)}"
220
-
221
- def read_jsonl_files_safe(paths: List[str], cfg: SystemConfig) -> Tuple[List[Dict], List[str]]:
222
- """Safe JSONL file reading with validation"""
223
- data: List[Dict] = []
224
- errors: List[str] = []
225
-
226
- for path in paths:
227
- # Validate file first
228
- is_valid, msg = validate_file_security(path, cfg.max_file_size_mb, cfg.max_lines_per_file)
229
- if not is_valid:
230
- errors.append(f"{Path(path).name}: {msg}")
231
- continue
232
-
233
- try:
234
- with open(path, 'r', encoding='utf-8') as f:
235
- for line_num, line in enumerate(f, 1):
236
- line = line.strip()
237
- if not line:
238
- continue
239
-
240
- try:
241
- obj = json.loads(line)
242
- # Sanitize input
243
- obj['input'] = str(obj['input']).strip()
244
- obj['output'] = str(obj['output']).strip()
245
-
246
- if obj['input'] and obj['output']:
247
- data.append(obj)
248
- except json.JSONDecodeError:
249
- errors.append(f"{Path(path).name} line {line_num}: JSON decode error")
250
-
251
- except Exception as e:
252
- errors.append(f"{Path(path).name}: {str(e)}")
253
-
254
- logger.info(f"Loaded {len(data)} samples from {len(paths)} files")
255
- return data, errors
256
-
257
- # ==========================
258
- # Model Cache System
259
- # ==========================
260
- class ModelCache:
261
- _instances: Dict[str, Any] = {}
262
- _lock = threading.Lock()
263
- _access_times: Dict[str, float] = {}
264
- _max_cache_size = 3 # Maximum models to keep in cache
265
-
266
- @classmethod
267
- def _generate_key(cls, model_name: str, architecture: str) -> str:
268
- return hashlib.md5(f"{model_name}_{architecture}".encode()).hexdigest()[:16]
269
-
270
- @classmethod
271
- def get_model(cls, model_name: str, architecture: str, model_config: ModelConfig):
272
- key = cls._generate_key(model_name, architecture)
273
 
274
- with cls._lock:
275
- if key in cls._instances:
276
- cls._access_times[key] = time.time()
277
- logger.info(f"Model loaded from cache: {model_name}")
278
- return cls._instances[key]
279
-
280
- # Cleanup old models if cache is full
281
- if len(cls._instances) >= cls._max_cache_size:
282
- cls._cleanup_cache()
283
-
284
- # Load new model
285
- try:
286
- loader = ModelLoader(model_config)
287
- loader.load()
288
- cls._instances[key] = loader
289
- cls._access_times[key] = time.time()
290
-
291
- # Update metrics
292
- if model_name not in metrics.metrics.active_models:
293
- metrics.metrics.active_models.append(model_name)
294
-
295
- logger.info(f"Model loaded and cached: {model_name}")
296
- return loader
297
-
298
- except Exception as e:
299
- logger.error(f"Failed to load model {model_name}: {e}")
300
- raise
301
-
302
- @classmethod
303
- def _cleanup_cache(cls):
304
- """Remove least recently used model"""
305
- if not cls._access_times:
306
- return
307
-
308
- # Find least recently used model
309
- lru_key = min(cls._access_times.keys(), key=lambda k: cls._access_times[k])
310
-
311
- # Clean up resources
312
- if lru_key in cls._instances:
313
- loader = cls._instances[lru_key]
314
- cls._cleanup_model_resources(loader)
315
- del cls._instances[lru_key]
316
- del cls._access_times[lru_key]
317
- logger.info(f"Removed model from cache: {lru_key}")
318
-
319
- @classmethod
320
- def _cleanup_model_resources(cls, loader):
321
- """Clean up model resources"""
322
- try:
323
- if hasattr(loader, 'model') and hasattr(loader.model, 'cpu'):
324
- loader.model.cpu()
325
- if torch.cuda.is_available():
326
- torch.cuda.empty_cache()
327
- except Exception as e:
328
- logger.warning(f"Error cleaning up model resources: {e}")
329
-
330
- @classmethod
331
- def clear_cache(cls):
332
- """Clear all cached models"""
333
- with cls._lock:
334
- for loader in cls._instances.values():
335
- cls._cleanup_model_resources(loader)
336
- cls._instances.clear()
337
- cls._access_times.clear()
338
- metrics.metrics.active_models.clear()
339
- logger.info("Model cache cleared")
340
 
341
  # ==========================
342
- # Enhanced RAG System
343
  # ==========================
344
  class LegalRAGSystem:
 
345
  def __init__(self, cfg: SystemConfig):
346
- self.cfg = cfg
347
- self.embedding_model: Optional[SentenceTransformer] = None
348
- self.client = None
349
- self.collection = None
350
- self._lock = threading.Lock()
351
-
352
- @contextmanager
353
- def _safe_operation(self, operation_name: str):
354
- """Context manager for safe RAG operations"""
355
- start_time = time.time()
356
- try:
357
- yield
358
- except Exception as e:
359
- logger.error(f"RAG {operation_name} failed: {e}")
360
- metrics.record_request(time.time() - start_time, success=False)
361
- raise
362
- else:
363
- metrics.record_request(time.time() - start_time, success=True)
364
-
365
  def setup_embedding(self):
366
- if self.embedding_model is None:
367
- try:
368
- self.embedding_model = SentenceTransformer(
369
- self.cfg.embedding_model,
370
- cache_folder=self.cfg.cache_dir
371
- )
372
- logger.info(f"Embedding model loaded: {self.cfg.embedding_model}")
373
- except Exception as e:
374
- logger.error(f"Failed to load embedding model: {e}")
375
- raise
376
-
377
  def load_chroma(self) -> Tuple[bool, str]:
378
- with self._safe_operation("load_chroma"):
379
- try:
380
- base_path = str(Path(self.cfg.chroma_db_path).parent)
381
- os.makedirs(base_path, exist_ok=True)
382
-
383
- self.client = chromadb.PersistentClient(path=base_path)
384
- try:
385
- self.collection = self.client.get_collection("legal_articles")
386
- count = self.collection.count()
387
- logger.info(f"Loaded existing collection with {count} documents")
388
- return count > 0, f"مجموعه موجود با {count} سند بارگذاری شد"
389
- except Exception:
390
- self.collection = self.client.create_collection(
391
- "legal_articles",
392
- metadata={"description": "مواد قانونی"}
393
- )
394
- logger.info("Created new collection")
395
- return False, "مجموعه جدید ایجاد شد"
396
-
397
- except Exception as e:
398
- logger.error(f"ChromaDB initialization failed: {e}")
399
- return False, f"خطا در بارگذاری پایگاه داده: {str(e)}"
400
-
401
  def retrieve(self, query: str) -> List[Dict]:
402
- if not self.collection or not query.strip():
403
- return []
404
-
405
- with self._safe_operation("retrieve"):
406
- try:
407
- # Sanitize query
408
- query = query.strip()[:500] # Limit query length
409
-
410
- result = self.collection.query(
411
- query_texts=[query],
412
- n_results=self.cfg.top_k_retrieval,
413
- include=["documents", "metadatas", "distances"]
414
- )
415
-
416
- articles = []
417
- if result['documents'] and result['documents'][0]:
418
- for i, (doc, meta, dist) in enumerate(zip(
419
- result['documents'][0],
420
- result['metadatas'][0],
421
- result['distances'][0]
422
- )):
423
- similarity = max(0, min(1, 1 - dist)) # Normalize similarity
424
- if similarity >= self.cfg.similarity_threshold:
425
- articles.append({
426
- "article_id": meta.get("article_id", f"unknown_{i}"),
427
- "text": str(doc)[:500], # Limit text length
428
- "similarity": round(similarity, 3),
429
- })
430
-
431
- logger.info(f"Retrieved {len(articles)} relevant articles")
432
- return articles
433
-
434
- except Exception as e:
435
- logger.error(f"Article retrieval failed: {e}")
436
- return []
437
-
438
  @staticmethod
439
- def build_context(articles: List[Dict], limit_chars: int = 500) -> str:
440
- if not articles:
441
- return ""
442
-
443
- context_parts = []
444
- total_chars = 0
445
-
446
- for article in articles:
447
- text = article['text'][:limit_chars]
448
- part = f"• ماده {article['article_id']}: {text}"
449
-
450
- if total_chars + len(part) > limit_chars * 3: # Max total context
451
- break
452
-
453
- context_parts.append(part)
454
- total_chars += len(part)
455
-
456
- return "مواد مرتبط:\n" + "\n".join(context_parts)
457
-
458
- # ==========================
459
- # Enhanced Formalizer
460
- # ==========================
461
- class Formalizer:
462
- def __init__(self, model_name="erfan226/persian-t5-formality-transfer", device=None):
463
- self.model_name = model_name
464
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
465
- self.tokenizer = None
466
- self.model = None
467
- self._initialized = False
468
- self._lock = threading.Lock()
469
-
470
- def _initialize(self):
471
- """Lazy initialization of formalizer model"""
472
- if self._initialized:
473
- return
474
-
475
- with self._lock:
476
- if self._initialized: # Double-check pattern
477
- return
478
-
479
- try:
480
- self.tokenizer = AutoTokenizer.from_pretrained("aidal/Persian-Mistral-7B", use_fast=True)
481
- self.model = AutoModelForCausalLM.from_pretrained("aidal/Persian-Mistral-7B").to(self.device)
482
- self._initialized = True
483
- logger.info("Formalizer model initialized")
484
- except Exception as e:
485
- logger.error(f"Formalizer initialization failed: {e}")
486
- raise
487
-
488
- def formalize(self, text: str, max_len: int = 512) -> str:
489
- if not text or not text.strip():
490
- return text
491
-
492
- self._initialize()
493
-
494
- try:
495
- # Sanitize and limit input
496
- text = text.strip()[:max_len]
497
-
498
- inputs = self.tokenizer(
499
- text,
500
- return_tensors="pt",
501
- truncation=True,
502
- max_length=max_len
503
- ).to(self.device)
504
-
505
- with torch.no_grad():
506
- outputs = self.model.generate(
507
- **inputs,
508
- max_length=max_len,
509
- num_beams=4,
510
- early_stopping=True
511
- )
512
-
513
- result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
514
- logger.debug(f"Formalized text: {text[:50]}... -> {result[:50]}...")
515
- return result
516
-
517
- except Exception as e:
518
- logger.error(f"Text formalization failed: {e}")
519
- return text # Return original text on error
520
-
521
- # ==========================
522
- # Enhanced Model Loader
523
- # ==========================
524
- class ModelLoader:
525
- def __init__(self, model_config: ModelConfig):
526
- self.cfg = model_config
527
- self.tokenizer = None
528
- self.model = None
529
- self._loaded = False
530
-
531
- def _is_persianmind(self, name: str) -> bool:
532
- return "PersianMind" in name or "universitytehran/PersianMind" in name
533
-
534
- @contextmanager
535
- def _gpu_memory_context(self):
536
- """Context manager for GPU memory management"""
537
- initial_memory = 0
538
- if torch.cuda.is_available():
539
- initial_memory = torch.cuda.memory_allocated()
540
-
541
- try:
542
- yield
543
- finally:
544
- if torch.cuda.is_available():
545
- final_memory = torch.cuda.memory_allocated()
546
- logger.info(f"Memory change: {(final_memory - initial_memory) / 1024**2:.1f} MB")
547
- metrics.update_memory_usage()
548
-
549
- def load(self, prefer_quantized: bool = True):
550
- if self._loaded:
551
- return self
552
-
553
- with self._gpu_memory_context():
554
- try:
555
- self._load_tokenizer()
556
- self._load_model(prefer_quantized)
557
- self._loaded = True
558
- logger.info(f"Successfully loaded {self.cfg.model_name}")
559
- return self
560
-
561
- except Exception as e:
562
- logger.error(f"Model loading failed: {e}")
563
- self._cleanup()
564
- raise
565
-
566
- def _load_tokenizer(self):
567
- """Load tokenizer with error handling"""
568
- try:
569
- self.tokenizer = AutoTokenizer.from_pretrained(
570
- self.cfg.model_name,
571
- use_fast=True,
572
- trust_remote_code=True
573
- )
574
- logger.info("Tokenizer loaded successfully")
575
- except Exception as e:
576
- logger.error(f"Tokenizer loading failed: {e}")
577
- raise
578
-
579
- def _load_model(self, prefer_quantized: bool):
580
- """Load model with quantization support"""
581
- device_map = "auto" if torch.cuda.is_available() else None
582
- cuda_available = torch.cuda.is_available()
583
- dtype = torch.bfloat16 if (cuda_available and self.cfg.use_bf16) else (
584
- torch.float16 if cuda_available else torch.float32
585
- )
586
-
587
- # Try quantized loading for PersianMind causal models
588
- if (self.cfg.architecture == "causal" and
589
- self._is_persianmind(self.cfg.model_name) and
590
- prefer_quantized and cuda_available):
591
-
592
- if self._try_quantized_loading(device_map, dtype):
593
- return
594
-
595
- # Standard loading
596
- self._load_standard_model(device_map, dtype)
597
-
598
- def _try_quantized_loading(self, device_map, dtype) -> bool:
599
- """Try loading model with quantization"""
600
- # Try 8-bit first
601
- try:
602
- self.model = AutoModelForCausalLM.from_pretrained(
603
- self.cfg.model_name,
604
- device_map=device_map,
605
- load_in_8bit=True,
606
- torch_dtype=dtype,
607
- trust_remote_code=True
608
- )
609
- self._setup_pad_token()
610
- logger.info("Model loaded with 8-bit quantization")
611
- return True
612
- except Exception as e:
613
- logger.warning(f"8-bit loading failed: {e}")
614
-
615
- # Try 4-bit
616
- try:
617
- self.model = AutoModelForCausalLM.from_pretrained(
618
- self.cfg.model_name,
619
- device_map=device_map,
620
- load_in_4bit=True,
621
- bnb_4bit_use_double_quant=True,
622
- bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16,
623
- torch_dtype=dtype,
624
- trust_remote_code=True
625
- )
626
- self._setup_pad_token()
627
- logger.info("Model loaded with 4-bit quantization")
628
- return True
629
- except Exception as e:
630
- logger.warning(f"4-bit loading failed: {e}")
631
-
632
- return False
633
-
634
- def _load_standard_model(self, device_map, dtype):
635
- """Load model with standard precision"""
636
- try:
637
- if self.cfg.architecture == "seq2seq":
638
- self.model = AutoModelForCausalLM.from_pretrained(
639
- self.cfg.model_name,
640
- device_map=device_map,
641
- torch_dtype=dtype,
642
- trust_remote_code=True
643
- )
644
- elif self.cfg.architecture == "causal":
645
- self.model = AutoModelForCausalLM.from_pretrained(
646
- self.cfg.model_name,
647
- device_map=device_map,
648
- torch_dtype=dtype,
649
- trust_remote_code=True
650
- )
651
- self._setup_pad_token()
652
- else:
653
- raise ValueError(f"Unsupported architecture: {self.cfg.architecture}")
654
-
655
- logger.info("Model loaded with standard precision")
656
-
657
- except Exception as e:
658
- logger.error(f"Standard model loading failed: {e}")
659
- raise
660
-
661
- def _setup_pad_token(self):
662
- """Setup pad token for causal models"""
663
- if (self.tokenizer.pad_token is None and
664
- hasattr(self.tokenizer, 'eos_token') and
665
- self.tokenizer.eos_token):
666
- self.tokenizer.pad_token = self.tokenizer.eos_token
667
-
668
- def _cleanup(self):
669
- """Clean up resources on failure"""
670
- try:
671
- if self.model and hasattr(self.model, 'cpu'):
672
- self.model.cpu()
673
- if torch.cuda.is_available():
674
- torch.cuda.empty_cache()
675
- except Exception as e:
676
- logger.warning(f"Cleanup error: {e}")
677
-
678
- # ==========================
679
- # Enhanced Generator
680
- # ==========================
681
- class UnifiedGenerator:
682
- def __init__(self, loader: ModelLoader):
683
- self.loader = loader
684
- self.tokenizer = loader.tokenizer
685
- self.model = loader.model
686
- self.cfg = loader.cfg
687
-
688
- def generate(self, question: str, context: str = "") -> Tuple[str, str]:
689
- """Generate response with comprehensive error handling"""
690
- if not question or not question.strip():
691
- return "لطفاً سوال خود را وارد کنید.", "EMPTY_QUERY"
692
-
693
- if not self.model or not self.tokenizer:
694
- return "مدل بارگذاری نشده است.", "MODEL_NOT_LOADED"
695
-
696
- start_time = time.time()
697
- try:
698
- # Sanitize inputs
699
- question = question.strip()[:self.cfg.max_input_length // 2]
700
- context = context.strip()[:self.cfg.max_input_length // 2]
701
-
702
- if self.cfg.architecture == "seq2seq":
703
- result = self._generate_seq2seq(question, context)
704
- else:
705
- result = self._generate_causal(question, context)
706
-
707
- response_time = time.time() - start_time
708
- metrics.record_request(response_time, success=True)
709
-
710
- logger.info(f"Generated response in {response_time:.2f}s")
711
- return result, ""
712
-
713
- except torch.cuda.OutOfMemoryError:
714
- error_msg = "حافظه GPU کافی نیست. لطفاً پارامترها را کاهش دهید."
715
- logger.error("CUDA out of memory error")
716
- metrics.record_request(time.time() - start_time, success=False)
717
- return error_msg, "CUDA_OOM"
718
-
719
- except Exception as e:
720
- error_msg = "خطای غیرمنتظره در تولید پاسخ رخ داد."
721
- logger.error(f"Generation error: {e}")
722
- metrics.record_request(time.time() - start_time, success=False)
723
- return error_msg, str(e)
724
-
725
- def _generate_seq2seq(self, question: str, context: str) -> str:
726
- """Generate response using seq2seq model"""
727
- input_text = f"{context}\nسوال: {question}" if context else f"سوال: {question}"
728
-
729
- inputs = self.tokenizer(
730
- input_text,
731
- return_tensors="pt",
732
- truncation=True,
733
- max_length=self.cfg.max_input_length
734
- )
735
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
736
-
737
- with torch.no_grad():
738
- outputs = self.model.generate(
739
- **inputs,
740
- max_length=self.cfg.max_target_length,
741
- num_beams=self.cfg.num_beams,
742
- early_stopping=True,
743
- no_repeat_ngram_size=2,
744
- do_sample=False
745
- )
746
-
747
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
748
-
749
- # Clean up response (remove input echo if present)
750
- if input_text in response:
751
- response = response.replace(input_text, "").strip()
752
-
753
- return response or "پاسخی تولید نشد."
754
-
755
- def _generate_causal(self, question: str, context: str) -> str:
756
- """Generate response using causal model"""
757
- prompt = f"{context}\nسوال: {question}\nپاسخ:" if context else f"سوال: {question}\nپاسخ:"
758
-
759
- inputs = self.tokenizer(
760
- prompt,
761
- return_tensors="pt",
762
- truncation=True,
763
- max_length=self.cfg.max_input_length
764
- )
765
- inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
766
- input_length = inputs['input_ids'].shape[1]
767
-
768
- with torch.no_grad():
769
- outputs = self.model.generate(
770
- **inputs,
771
- max_new_tokens=self.cfg.max_new_tokens,
772
- do_sample=True,
773
- temperature=max(0.1, self.cfg.temperature), # Ensure min temperature
774
- top_p=self.cfg.top_p,
775
- pad_token_id=self.tokenizer.pad_token_id or self.tokenizer.eos_token_id,
776
- repetition_penalty=1.1,
777
- no_repeat_ngram_size=3
778
- )
779
-
780
- # Extract only the generated part
781
- generated_tokens = outputs[0][input_length:]
782
- response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
783
-
784
- # Clean up response
785
- response = response.strip()
786
- if not response:
787
- return "پاسخی تولید نشد."
788
-
789
- # Remove any remaining prompt artifacts
790
- response = response.split("سوال:")[0].strip()
791
-
792
- return response
793
-
794
- # ==========================
795
- # Enhanced Datasets
796
- # ==========================
797
- class Seq2SeqJSONLDataset(Dataset):
798
- def __init__(self, data: List[Dict], tokenizer, max_input: int, max_target: int):
799
- self.tokenizer = tokenizer
800
- self.max_input = max_input
801
- self.max_target = max_target
802
-
803
- # Filter and validate data
804
- self.items = []
805
- for item in data:
806
- src = str(item.get("input", "")).strip()
807
- tgt = str(item.get("output", "")).strip()
808
-
809
- if src and tgt and len(src) > 5 and len(tgt) > 5: # Minimum length check
810
- self.items.append((src, tgt))
811
-
812
- logger.info(f"Seq2Seq dataset created with {len(self.items)} samples")
813
-
814
- def __len__(self):
815
- return len(self.items)
816
-
817
- def __getitem__(self, idx):
818
- source_text, target_text = self.items[idx]
819
-
820
- # Tokenize inputs
821
- model_inputs = self.tokenizer(
822
- source_text,
823
- max_length=self.max_input,
824
- padding="max_length",
825
- truncation=True,
826
- return_tensors="pt"
827
- )
828
-
829
- # Tokenize targets
830
- labels = self.tokenizer(
831
- text_target=target_text,
832
- max_length=self.max_target,
833
- padding="max_length",
834
- truncation=True,
835
- return_tensors="pt"
836
- )
837
-
838
- # Convert to proper format
839
- return {
840
- "input_ids": model_inputs["input_ids"].flatten(),
841
- "attention_mask": model_inputs["attention_mask"].flatten(),
842
- "labels": labels["input_ids"].flatten()
843
- }
844
 
845
  class CausalJSONLDataset(Dataset):
 
846
  def __init__(self, data: List[Dict], tokenizer, max_length: int):
847
  self.tokenizer = tokenizer
848
  self.max_length = max_length
 
 
849
 
850
- # Process data
851
- self.items = []
852
- for item in data:
853
- src = str(item.get("input", "")).strip()
854
- tgt = str(item.get("output", "")).strip()
855
-
856
- if src and tgt and len(src) > 5 and len(tgt) > 5:
857
- formatted_text = f"سوال: {src}\nپاسخ: {tgt}"
858
- self.items.append(formatted_text)
859
-
860
- logger.info(f"Causal dataset created with {len(self.items)} samples")
861
-
862
- def __len__(self):
863
- return len(self.items)
864
 
865
  def __getitem__(self, idx):
866
- text = self.items[idx]
867
-
868
  encoding = self.tokenizer(
869
- text,
870
  max_length=self.max_length,
871
  padding="max_length",
872
  truncation=True,
873
  return_tensors="pt"
874
  )
875
-
876
  input_ids = encoding["input_ids"].flatten()
877
  attention_mask = encoding["attention_mask"].flatten()
878
-
 
879
  labels = input_ids.clone()
 
 
880
  labels[attention_mask == 0] = -100
 
 
881
 
882
- return {
883
- "input_ids": input_ids,
884
- "attention_mask": attention_mask,
885
- "labels": labels
886
- }
887
 
888
  # ==========================
889
- # Enhanced Progress Callback
890
  # ==========================
891
- class GradioProgressCallback(TrainerCallback):
892
- def __init__(self, progress: gr.Progress, status_textbox: gr.Textbox):
893
- self.progress = progress
894
- self.status_textbox = status_textbox
895
- self.total_steps = None
896
- self.start_time = None
897
- self.last_update = 0
898
-
899
- def on_train_begin(self, args, state, control, **kwargs):
900
- self.total_steps = state.max_steps
901
- self.start_time = time.time()
902
- try:
903
- # progress is a gr.Progress object; call it to set initial state
904
- self.progress(0, desc="آموزش شروع شد 🚀")
905
- except Exception:
906
- # Fallback: ignore if the interface doesn't support desc
907
- self.progress(0)
908
- self.status_textbox.update(value="آموزش شروع شد...")
909
-
910
- def on_step_end(self, args, state, control, **kwargs):
911
- if not self.total_steps or time.time() - self.last_update < 1.0: # Throttle updates
912
- return
913
-
914
- self.last_update = time.time()
915
-
916
- # Calculate progress
917
- progress_pct = min(100, int((state.global_step / self.total_steps) * 100))
918
-
919
- # Estimate remaining time
920
- elapsed = time.time() - self.start_time
921
- if state.global_step > 0:
922
- avg_time_per_step = elapsed / state.global_step
923
- remaining_steps = self.total_steps - state.global_step
924
- eta_seconds = avg_time_per_step * remaining_steps
925
- eta_minutes = int(eta_seconds / 60)
926
- eta_str = f" (تخمین باقی‌مانده: {eta_minutes} دقیقه)" if eta_minutes > 0 else ""
927
- else:
928
- eta_str = ""
929
-
930
- # Update progress
931
- try:
932
- self.progress(progress_pct, desc=f"آموزش: {progress_pct}%")
933
- except Exception:
934
- self.progress(progress_pct)
935
-
936
- # Update status with more details
937
- current_lr = state.learning_rate if hasattr(state, 'learning_rate') else args.learning_rate
938
- status_msg = (f"Step {state.global_step}/{self.total_steps} → {progress_pct}%{eta_str}\n"
939
- f"Learning Rate: {current_lr:.2e}")
940
-
941
- if hasattr(state, 'log_history') and state.log_history:
942
- last_log = state.log_history[-1]
943
- if 'train_loss' in last_log:
944
- status_msg += f"\nTrain Loss: {last_log['train_loss']:.4f}"
945
- if 'eval_loss' in last_log:
946
- status_msg += f"\nEval Loss: {last_log['eval_loss']:.4f}"
947
-
948
- self.status_textbox.update(value=status_msg)
949
-
950
- def on_evaluate(self, args, state, control, **kwargs):
951
- if hasattr(state, 'log_history') and state.log_history:
952
- last_log = state.log_history[-1]
953
- if 'eval_loss' in last_log:
954
- self.status_textbox.update(
955
- value=f"ارزیابی انجام شد - Loss: {last_log['eval_loss']:.4f}"
956
- )
957
-
958
- def on_train_end(self, args, state, control, **kwargs):
959
- total_time = time.time() - self.start_time
960
- total_minutes = int(total_time / 60)
961
-
962
- try:
963
- self.progress(100, desc="آموزش تکمیل شد ✅")
964
- except Exception:
965
- self.progress(100)
966
 
967
- self.status_textbox.update(
968
- value=f"آموزش با موفقیت تکمیل شد ✅\n"
969
- f"زمان کل: {total_minutes} دقیقه\n"
970
- f"کل Steps: {state.global_step}"
971
- )
972
 
973
  # ==========================
974
- # Enhanced Trainer Manager
975
  # ==========================
976
- class TrainerManager:
977
- def __init__(self, system_config: SystemConfig, model_loader: ModelLoader):
978
- self.cfg = system_config
979
- self.loader = model_loader
980
-
981
- def train(self, train_paths: List[str], extra_callbacks: List = None) -> Tuple[bool, str]:
982
- """Main training method with comprehensive error handling"""
983
- if extra_callbacks is None:
984
- extra_callbacks = []
985
-
986
- try:
987
- # Validate training files
988
- data, errors = read_jsonl_files_safe(train_paths, self.cfg)
989
-
990
- if errors:
991
- error_msg = "خطاهای فایل:\n" + "\n".join(errors[:5]) # Show first 5 errors
992
- return False, error_msg
993
-
994
- if len(data) < 10:
995
- return False, f"داده کافی نیست. حداقل 10 نمونه نیاز است (موجود: {len(data)})"
996
-
997
- # Set random seed
998
- set_seed_all(self.cfg.seed)
999
-
1000
- # Split data
1001
- train_data, val_data = train_test_split(
1002
- data,
1003
- test_size=self.cfg.train_test_ratio,
1004
- random_state=self.cfg.seed
1005
- )
1006
-
1007
- logger.info(f"Training samples: {len(train_data)}, Validation samples: {len(val_data)}")
1008
-
1009
- # Train based on architecture
1010
- if self.cfg.model.architecture == "seq2seq":
1011
- success, msg = self._train_seq2seq(train_data, val_data, extra_callbacks)
1012
- else:
1013
- success, msg = self._train_causal(train_data, val_data, extra_callbacks)
1014
-
1015
- if success:
1016
- # Save configuration
1017
- self._save_training_config()
1018
-
1019
- return success, msg
1020
-
1021
- except Exception as e:
1022
- logger.error(f"Training failed: {e}")
1023
- return False, f"خطا در آموزش: {str(e)}"
1024
-
1025
- def _train_seq2seq(self, train_data: List[Dict], val_data: List[Dict],
1026
- extra_callbacks: List) -> Tuple[bool, str]:
1027
- """Train seq2seq model"""
1028
- try:
1029
- # Create datasets
1030
- train_dataset = Seq2SeqJSONLDataset(
1031
- train_data, self.loader.tokenizer,
1032
- self.cfg.model.max_input_length,
1033
- self.cfg.model.max_target_length
1034
- )
1035
-
1036
- val_dataset = Seq2SeqJSONLDataset(
1037
- val_data, self.loader.tokenizer,
1038
- self.cfg.model.max_input_length,
1039
- self.cfg.model.max_target_length
1040
- )
1041
-
1042
- # Data collator
1043
- data_collator = DataCollatorForSeq2Seq(
1044
- tokenizer=self.loader.tokenizer,
1045
- model=self.loader.model,
1046
- padding=True
1047
- )
1048
-
1049
- # Training arguments
1050
- training_args = self._get_training_args()
1051
- training_args.predict_with_generate = True
1052
- training_args.generation_max_length = self.cfg.model.max_target_length
1053
- training_args.generation_num_beams = self.cfg.model.num_beams
1054
-
1055
- # Create trainer
1056
- trainer = Trainer(
1057
- model=self.loader.model,
1058
- args=training_args,
1059
- train_dataset=train_dataset,
1060
- eval_dataset=val_dataset,
1061
- data_collator=data_collator,
1062
- tokenizer=self.loader.tokenizer,
1063
- callbacks=self._get_callbacks(extra_callbacks)
1064
- )
1065
-
1066
- # Train
1067
- trainer.train()
1068
-
1069
- # Save model
1070
- trainer.save_model(self.cfg.output_dir)
1071
- self.loader.tokenizer.save_pretrained(self.cfg.output_dir)
1072
-
1073
- return True, "مدل Seq2Seq با موفقیت آموزش داده شد"
1074
-
1075
- except Exception as e:
1076
- logger.error(f"Seq2Seq training failed: {e}")
1077
- return False, f"خطا در آموزش Seq2Seq: {str(e)}"
1078
 
1079
- def _train_causal(self, train_data: List[Dict], val_data: List[Dict],
1080
- extra_callbacks: List) -> Tuple[bool, str]:
1081
- """Train causal language model"""
1082
  try:
1083
- # Create datasets
1084
- train_dataset = CausalJSONLDataset(
1085
- train_data, self.loader.tokenizer,
1086
- self.cfg.model.max_input_length
1087
- )
1088
-
1089
- val_dataset = CausalJSONLDataset(
1090
- val_data, self.loader.tokenizer,
1091
- self.cfg.model.max_input_length
1092
- )
1093
-
1094
- # Training arguments
1095
- training_args = self._get_training_args()
1096
-
1097
- # Create trainer
1098
- trainer = Trainer(
1099
- model=self.loader.model,
1100
- args=training_args,
1101
- train_dataset=train_dataset,
1102
- eval_dataset=val_dataset,
1103
- tokenizer=self.loader.tokenizer,
1104
- callbacks=self._get_callbacks(extra_callbacks)
1105
- )
1106
-
1107
- # Train
1108
- trainer.train()
1109
-
1110
- # Save model
1111
- trainer.save_model(self.cfg.output_dir)
1112
- self.loader.tokenizer.save_pretrained(self.cfg.output_dir)
1113
-
1114
- return True, "مدل Causal با موفقیت آموزش داده شد"
1115
-
1116
  except Exception as e:
1117
- logger.error(f"Causal training failed: {e}")
1118
- return False, f"خطا در آموزش Causal: {str(e)}"
1119
 
1120
  def _get_training_args(self) -> TrainingArguments:
1121
- """Get training arguments with optimized settings"""
1122
- return TrainingArguments(
1123
- output_dir=self.cfg.output_dir,
1124
- num_train_epochs=self.cfg.epochs,
1125
- learning_rate=self.cfg.lr,
1126
- per_device_train_batch_size=self.cfg.batch_size,
1127
- per_device_eval_batch_size=self.cfg.batch_size,
1128
- gradient_accumulation_steps=self.cfg.grad_accum,
1129
- warmup_ratio=0.05,
1130
- weight_decay=0.01,
1131
- evaluation_strategy="epoch",
1132
- eval_steps=500,
1133
- save_strategy="epoch",
1134
- save_total_limit=3, # Keep more checkpoints
1135
- load_best_model_at_end=True,
1136
- metric_for_best_model="eval_loss",
1137
- greater_is_better=False,
1138
- logging_steps=50,
1139
- logging_dir=f"{self.cfg.output_dir}/logs",
1140
- report_to="none",
1141
- bf16=self.cfg.model.use_bf16 if torch.cuda.is_available() else False,
1142
- fp16=(not self.cfg.model.use_bf16) if torch.cuda.is_available() else False,
1143
- dataloader_drop_last=True,
1144
- remove_unused_columns=False,
1145
- gradient_checkpointing=True, # Save memory
1146
- )
1147
-
1148
- def _get_callbacks(self, extra_callbacks: List) -> List:
1149
- """Get training callbacks"""
1150
- callbacks = [
1151
- EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.01)
1152
- ]
1153
- callbacks.extend(extra_callbacks)
1154
- return callbacks
1155
-
1156
- def _save_training_config(self):
1157
- """Save training configuration"""
1158
- try:
1159
- config_path = Path(self.cfg.output_dir) / "training_config.json"
1160
- config_dict = self.cfg.dict()
1161
- config_dict['training_timestamp'] = datetime.now().isoformat()
1162
- config_dict['training_completed'] = True
1163
-
1164
- with open(config_path, 'w', encoding='utf-8') as f:
1165
- json.dump(config_dict, f, ensure_ascii=False, indent=2)
1166
-
1167
- logger.info(f"Training config saved to {config_path}")
1168
- except Exception as e:
1169
- logger.warning(f"Failed to save training config: {e}")
1170
 
1171
  # ==========================
1172
- # Enhanced Legal App
1173
  # ==========================
1174
- class LegalApp:
1175
  def __init__(self, system_config: Optional[SystemConfig] = None):
1176
  self.cfg = system_config or SystemConfig()
1177
  self.rag = LegalRAGSystem(self.cfg)
1178
- self.formalizer: Optional[Formalizer] = None
1179
  self._current_loader: Optional[ModelLoader] = None
1180
  self._current_generator: Optional[UnifiedGenerator] = None
1181
  self._lock = threading.Lock()
1182
 
1183
- def _ensure_model(self, model_name: str, architecture: str) -> Tuple[bool, str]:
1184
- """Ensure model is loaded with thread safety"""
1185
  with self._lock:
1186
  try:
1187
- # Update config
1188
- self.cfg.model.model_name = model_name
1189
- self.cfg.model.architecture = architecture
1190
-
1191
- # Get model from cache
1192
- self._current_loader = ModelCache.get_model(model_name, architecture, self.cfg.model)
1193
  self._current_generator = UnifiedGenerator(self._current_loader)
 
1194
 
1195
- return True, f"مدل بارگذاری شد: {model_name} ({architecture})"
1196
-
1197
- except Exception as e:
1198
- logger.error(f"Model loading failed: {e}")
1199
- return False, f"خطا در بارگذاری مدل: {str(e)}"
1200
-
1201
- def _ensure_rag(self) -> Tuple[bool, str]:
1202
- """Ensure RAG system is ready"""
1203
- try:
1204
- self.rag.setup_embedding()
1205
- success, message = self.rag.load_chroma()
1206
- return success, message
1207
- except Exception as e:
1208
- logger.error(f"RAG setup failed: {e}")
1209
- return False, f"خطا در راه‌اندازی RAG: {str(e)}"
1210
-
1211
- def _ensure_formalizer(self) -> str:
1212
- """Ensure formalizer is ready"""
1213
- try:
1214
- if not self.formalizer:
1215
- self.formalizer = Formalizer()
1216
- return "Formalizer آماده است."
1217
- except Exception as e:
1218
- logger.error(f"Formalizer setup failed: {e}")
1219
- return f"خطا در راه‌اندازی Formalizer: {str(e)}"
1220
-
1221
- # Event handlers
1222
- def handle_load_model(self, model_choice: str, use_rag: bool) -> str:
1223
- """Handle model loading"""
1224
- try:
1225
- model_configs = self._get_model_configs()
1226
- if model_choice not in model_configs:
1227
- return "مدل نامعتبر انتخاب شده"
1228
-
1229
- model_name, architecture = model_configs[model_choice]
1230
-
1231
- # Load model
1232
- success, model_msg = self._ensure_model(model_name, architecture)
1233
- if not success:
1234
- return model_msg
1235
-
1236
- # Setup RAG if requested
1237
- rag_msg = ""
1238
- if use_rag:
1239
- rag_success, rag_msg = self._ensure_rag()
1240
- rag_msg = f"\nRAG: {rag_msg}"
1241
- else:
1242
- rag_msg = "\nRAG: غیر فعال"
1243
-
1244
- return f"{model_msg}{rag_msg}"
1245
-
1246
- except Exception as e:
1247
- logger.error(f"Model loading handler failed: {e}")
1248
- return f"خطا در بارگذاری: {str(e)}"
1249
-
1250
- def handle_generate_response(self, question: str, use_rag: bool, use_formalizer: bool,
1251
- max_new_tokens: int, temperature: float, top_p: float,
1252
- num_beams: int) -> Tuple[str, str, str]: # response, references, metrics
1253
- """Handle response generation"""
1254
- if not question or not question.strip():
1255
- return "لطفاً سوال خود را وارد کنید.", "", ""
1256
-
1257
- if not self._current_generator:
1258
- return "ابتدا مدل را بارگذاری کنید.", "", ""
1259
 
 
 
 
1260
  start_time = time.time()
1261
-
1262
- try:
1263
- # Update generation parameters
1264
- self.cfg.model.max_new_tokens = max(32, min(1024, int(max_new_tokens)))
1265
- self.cfg.model.temperature = max(0.1, min(2.0, float(temperature)))
1266
- self.cfg.model.top_p = max(0.1, min(1.0, float(top_p)))
1267
- self.cfg.model.num_beams = max(1, min(8, int(num_beams)))
1268
-
1269
- # Apply input formalization if requested
1270
- processed_question = question
1271
- if use_formalizer:
1272
- formalizer_msg = self._ensure_formalizer()
1273
- if "خطا" not in formalizer_msg and self.formalizer:
1274
- processed_question = self.formalizer.formalize(question)
1275
-
1276
- # Retrieve relevant articles if RAG is enabled
1277
- articles = []
1278
- if use_rag and self.rag.collection:
1279
- articles = self.rag.retrieve(processed_question)
1280
-
1281
- # Build context
1282
- context = LegalRAGSystem.build_context(articles) if articles else ""
1283
-
1284
- # Generate response
1285
- response, error = self._current_generator.generate(processed_question, context)
1286
-
1287
- # Build references section
1288
- references = ""
1289
- if articles:
1290
- ref_parts = []
1291
- for article in articles[:3]: # Limit to top 3 references
1292
- ref_parts.append(
1293
- f"**ماده {article['article_id']}** (شباهت: {article['similarity']:.2f})\n"
1294
- f"{article['text'][:400]}{'...' if len(article['text']) > 400 else ''}"
1295
- )
1296
- references = "\n\n".join(ref_parts)
1297
-
1298
- # Generate metrics info
1299
- elapsed_time = time.time() - start_time
1300
- metrics_info = f"زمان پردازش: {elapsed_time:.2f}s"
1301
- if articles:
1302
- metrics_info += f" | مواد یافت شده: {len(articles)}"
1303
- if use_formalizer:
1304
- metrics_info += " | فرمالایزر فعال"
1305
-
1306
- return response, references, metrics_info
1307
-
1308
- except Exception as e:
1309
- logger.error(f"Response generation failed: {e}")
1310
- error_time = time.time() - start_time
1311
- metrics.record_request(error_time, success=False)
1312
- return f"خطا در تولید پاسخ: {str(e)}", "", f"خطا پس از {error_time:.2f}s"
1313
-
1314
- def handle_training(self, model_choice: str, uploaded_files, use_rag_training: bool,
1315
- epochs: int, batch_size: int, learning_rate: float,
1316
- progress: gr.Progress, status_textbox: gr.Textbox) -> str:
1317
- """Handle model training"""
1318
- try:
1319
- # Validate inputs
1320
- if not uploaded_files:
1321
- return "لطفاً فایل‌های آموزشی را بارگذاری کنید."
1322
-
1323
- # Get model config
1324
- model_configs = self._get_model_configs()
1325
- if model_choice not in model_configs:
1326
- return "مدل نامعتبر انتخاب شده"
1327
-
1328
- model_name, architecture = model_configs[model_choice]
1329
-
1330
- # Load model for training
1331
- success, msg = self._ensure_model(model_name, architecture)
1332
- if not success:
1333
- return f"خطا در بارگذاری مدل: {msg}"
1334
-
1335
- # Update training config
1336
- self.cfg.epochs = max(1, min(10, int(epochs)))
1337
- self.cfg.batch_size = max(1, min(16, int(batch_size)))
1338
- self.cfg.lr = max(1e-6, min(1e-3, float(learning_rate)))
1339
-
1340
- # Setup RAG if requested
1341
- if use_rag_training:
1342
- rag_success, rag_msg = self._ensure_rag()
1343
- if not rag_success:
1344
- logger.warning(f"RAG setup failed for training: {rag_msg}")
1345
-
1346
- # Get file paths (gr.File with type="filepath" returns list[str])
1347
- file_paths = uploaded_files
1348
-
1349
- if not file_paths:
1350
- return "فایل‌های معتبر یافت نشد."
1351
-
1352
- # Create trainer
1353
- trainer_manager = TrainerManager(self.cfg, self._current_loader)
1354
-
1355
- # Create progress callback
1356
- progress_callback = GradioProgressCallback(progress, status_textbox)
1357
-
1358
- # Start training
1359
- success, result_msg = trainer_manager.train(file_paths, [progress_callback])
1360
-
1361
- if success:
1362
- # Clear model cache to force reload of trained model
1363
- ModelCache.clear_cache()
1364
- return f"✅ {result_msg}\nمدل در مسیر '{self.cfg.output_dir}' ذخیره شد."
1365
- else:
1366
- return f"❌ {result_msg}"
1367
-
1368
- except Exception as e:
1369
- logger.error(f"Training handler failed: {e}")
1370
- return f"خطا در آموزش: {str(e)}"
1371
-
1372
- def get_system_status(self) -> str:
1373
- """Get system status information"""
1374
- try:
1375
- status_parts = []
1376
-
1377
- # Model status
1378
- if self._current_loader:
1379
- status_parts.append(f"✅ مدل فعال: {self.cfg.model.model_name}")
1380
- else:
1381
- status_parts.append("❌ مدل بارگذاری نشده")
1382
-
1383
- # RAG status
1384
- if self.rag.collection:
1385
- doc_count = self.rag.collection.count()
1386
- status_parts.append(f"✅ RAG فعال ({doc_count} سند)")
1387
- else:
1388
- status_parts.append("❌ RAG غیر فعال")
1389
-
1390
- # System metrics
1391
- sys_metrics = metrics.get_metrics()
1392
- status_parts.append(f"📊 درخواست‌ها: {sys_metrics['requests_total']}")
1393
- status_parts.append(f"📈 نرخ موفقیت: {sys_metrics['success_rate']:.1f}%")
1394
- status_parts.append(f"⏱️ زمان متوسط: {sys_metrics['avg_response_time']}s")
1395
-
1396
- if torch.cuda.is_available():
1397
- memory_mb = torch.cuda.memory_allocated() / 1024 / 1024
1398
- status_parts.append(f"🖥️ حافظه GPU: {memory_mb:.1f} MB")
1399
-
1400
- return "\n".join(status_parts)
1401
-
1402
- except Exception as e:
1403
- return f"خطا در دریافت وضعیت: {str(e)}"
1404
-
1405
  def _get_model_configs(self) -> Dict[str, Tuple[str, str]]:
1406
- """Get available model configurations"""
1407
  return {
1408
- "Seq2Seq (parsi-t5-base)": ("persiannlp/parsi-t5-base", "seq2seq"),
1409
- "Seq2Seq (mt5-base)": ("google/mt5-base", "seq2seq"),
1410
- "Causal (Mistral-7B)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
1411
  "Causal (PersianMind-v1.0)": ("universitytehran/PersianMind-v1.0", "causal"),
1412
- "Causal (Qwen2.5-7B)": ("Qwen/Qwen2.5-7B-Instruct", "causal"),
1413
- "Causal (Llama-3.1-70B)": ("meta-llama/Meta-Llama-3.1-70B-Instruct", "causal"),
1414
  }
1415
 
1416
- def build_ui(self) -> gr.Blocks:
1417
- """Build enhanced Gradio interface"""
1418
- model_choices = list(self._get_model_configs().keys())
1419
-
1420
- with gr.Blocks(
1421
- title="ماحون — مشاور حقوقی هوشمند",
1422
- theme=gr.themes.Soft(),
1423
- css="""
1424
- .status-box { font-family: 'Courier New', monospace; font-size: 12px; }
1425
- .metrics-box { background-color: #f0f0f0; padding: 10px; border-radius: 5px; }
1426
- """
1427
- ) as app:
1428
-
1429
- gr.HTML("""
1430
- <div style='text-align: center; margin-bottom: 20px;'>
1431
- <h1>ماحون — مشاور حقوقی هوشمند 🏛️</h1>
1432
- <p>سیستم پیشرفته مشاوره حقوقی با قابلیت RAG، Fine-tuning و هوش مصنوعی</p>
1433
- </div>
1434
- """)
1435
-
1436
- # System Status
1437
- with gr.Accordion("وضعیت سیستم", open=False):
1438
- system_status = gr.Markdown(
1439
- value=self.get_system_status(),
1440
- elem_classes=["status-box"]
1441
- )
1442
- refresh_status_btn = gr.Button("🔄 بروزرسانی وضعیت", size="sm")
1443
-
1444
- with gr.Tabs() as tabs:
1445
- # Consultation Tab
1446
- with gr.Tab("💬 مشاوره") as advice_tab:
1447
- with gr.Row():
1448
- with gr.Column(scale=2):
1449
- model_dropdown = gr.Dropdown(
1450
- choices=model_choices,
1451
- value=model_choices[0],
1452
- label="انتخاب مدل",
1453
- info="نوع مدل مورد نظر را انتخاب کنید"
1454
- )
1455
- with gr.Column(scale=1):
1456
- use_rag_checkbox = gr.Checkbox(
1457
- value=True,
1458
- label="استفاده از RAG",
1459
- info="بازیابی مواد قانونی مرتبط"
1460
- )
1461
- use_formalizer_checkbox = gr.Checkbox(
1462
- value=False,
1463
- label="رسمی‌سازی ورودی",
1464
- info="تبدیل متن غیررسمی به رسمی"
1465
- )
1466
-
1467
- load_model_btn = gr.Button("🚀 بارگذاری مدل/RAG", variant="primary", size="lg")
1468
- load_status = gr.Textbox(
1469
- label="وضعیت بارگذاری",
1470
- interactive=False,
1471
- elem_classes=["status-box"]
1472
- )
1473
 
1474
- # Generation Parameters
1475
- with gr.Accordion("⚙️ پارامترهای تولید", open=False):
1476
- with gr.Row():
1477
- max_new_tokens = gr.Slider(
1478
- minimum=32, maximum=1024, value=self.cfg.model.max_new_tokens,
1479
- step=16, label="حداکثر توکن‌های جدید"
1480
- )
1481
- temperature = gr.Slider(
1482
- minimum=0.1, maximum=2.0, value=self.cfg.model.temperature,
1483
- step=0.05, label="دما (خلاقیت)"
1484
- )
1485
- with gr.Row():
1486
- top_p = gr.Slider(
1487
- minimum=0.1, maximum=1.0, value=self.cfg.model.top_p,
1488
- step=0.05, label="Top-p (تنوع)"
1489
- )
1490
- num_beams = gr.Slider(
1491
- minimum=1, maximum=8, value=self.cfg.model.num_beams,
1492
- step=1, label="تعداد Beam"
1493
- )
1494
-
1495
- # Input/Output
1496
- with gr.Row():
1497
- with gr.Column(scale=1):
1498
- question_input = gr.Textbox(
1499
- label="سوال حقوقی خود را وارد کنید",
1500
- placeholder="مثال: شرایط فسخ قرارداد اجاره چیست؟",
1501
- lines=3
1502
- )
1503
- submit_btn = gr.Button("🔍 دریافت پاسخ", variant="primary")
1504
- with gr.Column(scale=1):
1505
- response_output = gr.Textbox(
1506
- label="پاسخ سیستم",
1507
- lines=8,
1508
- interactive=False
1509
- )
1510
- references_output = gr.Textbox(
1511
- label="مراجع حقوقی مرتبط",
1512
- lines=6,
1513
- interactive=False
1514
- )
1515
- metrics_output = gr.Textbox(
1516
- label="معیارهای عملکرد",
1517
- lines=1,
1518
- interactive=False,
1519
- elem_classes=["metrics-box"]
1520
- )
1521
 
1522
- # Training Tab
1523
- with gr.Tab("🎓 آموزش مدل") as training_tab:
 
 
 
 
1524
  with gr.Row():
 
 
 
 
1525
  with gr.Column(scale=1):
1526
- train_model_dropdown = gr.Dropdown(
1527
- choices=model_choices,
1528
- value=model_choices[0],
1529
- label="انتخاب مدل برای آموزش"
1530
- )
1531
- use_rag_training_checkbox = gr.Checkbox(
1532
- value=True,
1533
- label="استفاده از RAG در آموزش",
1534
- info="استفاده از مواد قانونی در آموزش"
1535
- )
1536
- train_file_upload = gr.File(
1537
- label="بارگذاری فایل‌های آموزشی (JSONL)",
1538
- file_types=[".jsonl"],
1539
- type="filepath",
1540
- file_count="multiple"
1541
- )
1542
- with gr.Column(scale=1):
1543
- with gr.Accordion("⚙️ پارامترهای آموزش", open=False):
1544
- train_epochs = gr.Slider(
1545
- minimum=1, maximum=10, value=self.cfg.epochs,
1546
- step=1, label="تعداد Epoch"
1547
- )
1548
- train_batch_size = gr.Slider(
1549
- minimum=1, maximum=16, value=self.cfg.batch_size,
1550
- step=1, label="اندازه Batch"
1551
- )
1552
- train_lr = gr.Slider(
1553
- minimum=1e-6, maximum=1e-3, value=self.cfg.lr,
1554
- step=1e-5, label="نرخ یادگیری"
1555
- )
1556
-
1557
- train_btn = gr.Button("🎯 شروع آموزش", variant="primary")
1558
- # --- Fixed Progress usage: do not pass label to gr.Progress ---
1559
- gr.Markdown("### 📊 پیشرفت آموزش")
1560
- train_status = gr.Textbox(
1561
- label="وضعیت آموزش",
1562
- interactive=False,
1563
- elem_classes=["status-box"]
1564
- )
1565
- train_progress = gr.Progress()
1566
-
1567
- # Event handlers
1568
  load_model_btn.click(
1569
- fn=lambda m, r: self.handle_load_model(m, r),
1570
- inputs=[model_dropdown, use_rag_checkbox],
1571
- outputs=load_status
1572
  )
1573
 
1574
  submit_btn.click(
1575
- fn=lambda q, r, f, m, t, p, b: self.handle_generate_response(
1576
- q, r, f, m, t, p, b
1577
- ),
1578
- inputs=[
1579
- question_input,
1580
- use_rag_checkbox,
1581
- use_formalizer_checkbox,
1582
- max_new_tokens,
1583
- temperature,
1584
- top_p,
1585
- num_beams
1586
- ],
1587
- outputs=[response_output, references_output, metrics_output]
1588
- )
1589
-
1590
- refresh_status_btn.click(
1591
- fn=lambda: self.get_system_status(),
1592
- outputs=system_status
1593
  )
1594
-
1595
- train_btn.click(
1596
- fn=lambda m, f, r, e, b, lr, p, s: self.handle_training(
1597
- m, f, r, e, b, lr, p, s
1598
- ),
1599
- inputs=[
1600
- train_model_dropdown,
1601
- train_file_upload,
1602
- use_rag_training_checkbox,
1603
- train_epochs,
1604
- train_batch_size,
1605
- train_lr,
1606
- train_progress,
1607
- train_status
1608
- ],
1609
- outputs=train_status
1610
  )
1611
 
1612
  return app
1613
 
1614
- # ==========================
1615
- # Main Application
1616
- # ==========================
1617
- def main():
1618
- """Main entry point for the application"""
1619
- # Initialize system
1620
- app = LegalApp()
 
 
 
 
1621
 
1622
- # Build and launch UI
1623
- ui = app.build_ui()
1624
- ui.launch(
1625
- server_name="0.0.0.0",
1626
- server_port=7860,
1627
- inbrowser=True,
1628
- share=False
1629
- )
1630
 
1631
  if __name__ == "__main__":
1632
- main()
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Mahoon Legal AI — Final Production-Ready Version
4
  Features:
5
+ - Decoupled Core Logic (MahoonCore) from UI (LegalAppUI).
6
+ - QLoRA (PEFT) for memory-efficient fine-tuning.
7
+ - Enhanced Gradio UI with a real-time Chatbot interface.
8
+ - Full integration of the state-of-the-art Llama 3.1 model,
9
+ including correct prompt templating for both inference and fine-tuning.
10
+ - All previous features: Caching, RAG, Validation, Metrics, etc.
 
 
 
11
  """
12
 
13
  from __future__ import annotations
 
22
  from pathlib import Path
23
  from typing import List, Dict, Optional, Tuple, Any, Union
24
  from datetime import datetime
 
 
25
 
26
  import torch
27
  from torch.utils.data import Dataset
 
36
  TrainingArguments,
37
  EarlyStoppingCallback,
38
  DataCollatorForSeq2Seq,
39
+ TrainerCallback,
40
+ BitsAndBytesConfig
41
  )
42
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
43
 
44
  import chromadb
45
  from sentence_transformers import SentenceTransformer
 
48
  warnings.filterwarnings("ignore")
49
 
50
  # Configure logging
51
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 
 
52
  logger = logging.getLogger(__name__)
53
 
54
+
55
  # ==========================
56
+ # CONFIGURATION (Pydantic Models)
57
  # ==========================
58
  class ModelConfig(BaseModel):
59
+ model_name: str = "meta-llama/Meta-Llama-3.1-8B-Instruct"
60
+ architecture: str = "causal"
61
+ max_input_length: int = Field(default=8192, ge=1024, le=131072) # Increased for Llama 3.1
62
+ max_new_tokens: int = Field(default=1024, ge=64, le=4096)
63
+ temperature: float = Field(default=0.6, ge=0.0, le=2.0)
 
64
  top_p: float = Field(default=0.9, ge=0.1, le=1.0)
 
65
  use_bf16: bool = True
66
 
67
+ class LoraTrainConfig(BaseModel):
68
+ use_lora: bool = True; r: int = 16; lora_alpha: int = 32; lora_dropout: float = 0.05
69
+ target_modules: List[str] = Field(default_factory=lambda: ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
 
 
 
 
 
70
 
71
  class SystemConfig(BaseModel):
72
  model: ModelConfig = Field(default_factory=ModelConfig)
73
+ lora: LoraTrainConfig = Field(default_factory=LoraTrainConfig)
74
  embedding_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
75
+ chroma_db_path: str = "./chroma_db"
+ top_k_retrieval: int = Field(default=5, ge=1, le=20)
76
+ similarity_threshold: float = Field(default=0.7, ge=0.0, le=1.0)
+ cache_dir: str = "./cache"
77
+ output_dir: str = "./mahoon_legal_adapters"
+ seed: int = 42
+ train_test_ratio: float = Field(default=0.1, ge=0.05, le=0.3)
78
+ batch_size: int = Field(default=1, ge=1, le=16)
+ grad_accum: int = Field(default=4, ge=1, le=8)
79
+ epochs: int = Field(default=3, ge=1, le=10)
+ lr: float = Field(default=2e-4, ge=1e-6, le=1e-3)
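+ # A minimal usage sketch: the defaults above can be overridden per run, e.g.
+ #   cfg = SystemConfig(model=ModelConfig(temperature=0.2), epochs=1, output_dir="./adapters_run1")
+ # Pydantic enforces the declared ranges and raises a validation error for out-of-range values.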
 
 
 
 
80
 
 
 
 
81
 
82
  # ==========================
83
+ # RAG, DATASETS, UTILITIES (CausalJSONLDataset modified for chat templating)
84
  # ==========================
85
  class LegalRAGSystem:
86
+ # (Implementation is unchanged from the previous refactored version)
87
  def __init__(self, cfg: SystemConfig):
88
+ self.cfg, self.embedding_model, self.client, self.collection = cfg, None, None, None
 
 
89
  def setup_embedding(self):
90
+ if self.embedding_model is None: self.embedding_model = SentenceTransformer(self.cfg.embedding_model, cache_folder=self.cfg.cache_dir)
 
 
91
  def load_chroma(self) -> Tuple[bool, str]:
92
+ try:
93
+ os.makedirs(self.cfg.chroma_db_path, exist_ok=True)
94
+ self.client = chromadb.PersistentClient(path=self.cfg.chroma_db_path)
95
+ self.collection = self.client.get_or_create_collection("legal_articles")
96
+ return True, f"مجموعه با {self.collection.count()} سند بارگذاری شد"
97
+ except Exception as e: return False, f"خطا در بارگذاری ChromaDB: {e}"
 
 
98
  def retrieve(self, query: str) -> List[Dict]:
99
+ if not self.collection: return []
100
+ results = self.collection.query(query_texts=[query], n_results=self.cfg.top_k_retrieval, include=["documents", "metadatas", "distances"])
101
+ return [{"text": doc, "similarity": 1 - dist} for doc, dist in zip(results['documents'][0], results['distances'][0]) if (1 - dist) >= self.cfg.similarity_threshold]
 
 
102
  @staticmethod
103
+ def build_context(articles: List[Dict]) -> str:
104
+ return "\n".join([f"• سند: {art['text']}" for art in articles]) if articles else ""
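+ # Ingestion sketch (assumes `articles` is a list of plain-text legal articles). Because no
+ # embedding_function is passed to get_or_create_collection, Chroma uses its built-in default
+ # embedder for both add() and query_texts:
+ #   rag = LegalRAGSystem(SystemConfig()); rag.load_chroma()
+ #   rag.collection.add(documents=articles, ids=[f"article-{i}" for i in range(len(articles))])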
 
 
 
 
105
 
106
  class CausalJSONLDataset(Dataset):
107
+ """MODIFIED: This dataset now correctly uses the chat template for fine-tuning."""
108
  def __init__(self, data: List[Dict], tokenizer, max_length: int):
109
  self.tokenizer = tokenizer
110
  self.max_length = max_length
111
+ self.items = [item for item in data if item.get("input") and item.get("output")]
112
+ logger.info(f"Causal dataset created with {len(self.items)} samples.")
113
 
114
+ def __len__(self): return len(self.items)
 
 
115
 
116
  def __getitem__(self, idx):
117
+ item = self.items[idx]
118
+
119
+ # Create message format required by the chat template
120
+ messages = [
121
+ {"role": "user", "content": item['input']},
122
+ {"role": "assistant", "content": item['output']}
123
+ ]
124
+
125
+ # Apply the template to get the full formatted string, but don't tokenize yet
126
+ # `add_generation_prompt=False` is crucial for training data
127
+ formatted_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
128
+
129
+ # Now, tokenize the full string
130
  encoding = self.tokenizer(
131
+ formatted_text,
132
  max_length=self.max_length,
133
  padding="max_length",
134
  truncation=True,
135
  return_tensors="pt"
136
  )
137
+
138
  input_ids = encoding["input_ids"].flatten()
139
  attention_mask = encoding["attention_mask"].flatten()
140
+
141
+ # Labels are a clone of input_ids. The model learns to predict the next token.
142
  labels = input_ids.clone()
143
+
144
+ # We don't want to compute loss on padding tokens
145
  labels[attention_mask == 0] = -100
146
+
147
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
148
 
 
 
 
 
 
149
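+ # For reference, with a Llama 3 family tokenizer `formatted_text` above roughly renders as:
+ #   <|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{input}<|eot_id|>
+ #   <|start_header_id|>assistant<|end_header_id|>\n\n{output}<|eot_id|>
+ # Labels span the whole sequence, so loss is also taken on the prompt tokens; masking the
+ # user turn out of the labels would be a possible refinement.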
 
150
  # ==========================
151
+ # MODEL MANAGEMENT
152
  # ==========================
153
+ class ModelLoader:
154
+ def __init__(self, model_config: ModelConfig):
155
+ self.cfg = model_config
156
+ self.tokenizer: Optional[AutoTokenizer] = None
157
+ self.model: Optional[AutoModelForCausalLM] = None
158
+
159
+ def load(self):
160
+ logger.info(f"Loading tokenizer for {self.cfg.model_name}...")
161
+ self.tokenizer = AutoTokenizer.from_pretrained(self.cfg.model_name, use_fast=True)
+ if self.tokenizer.pad_token is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token  # Llama 3 tokenizers define no pad token; padded training batches need one
162
+
163
+ if self.cfg.architecture == "causal" and torch.cuda.is_available():
164
+ logger.info("Loading Causal model with 4-bit quantization (QLoRA)...")
165
+ quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True)
166
+ self.model = AutoModelForCausalLM.from_pretrained(self.cfg.model_name, quantization_config=quant_config, device_map="auto")
167
+ else: # Fallback for CPU or non-causal models
168
+ logger.info("Loading model with standard precision...")
169
+ model_class = AutoModelForSeq2SeqLM if self.cfg.architecture == 'seq2seq' else AutoModelForCausalLM
170
+ self.model = model_class.from_pretrained(self.cfg.model_name, device_map="auto")
171
+
172
+ logger.info(f"Model {self.cfg.model_name} loaded successfully.")
173
+ return self
174
+
175
+ class ModelCache: # (Unchanged)
176
+ _instances, _lock = {}, threading.Lock()
177
+ @classmethod
178
+ def get_model(cls, model_name: str, architecture: str, model_config: ModelConfig):
179
+ key = f"{model_name}_{architecture}"
180
+ with cls._lock:
181
+ if key in cls._instances: return cls._instances[key]
182
+ loader = ModelLoader(model_config).load()
183
+ cls._instances[key] = loader
184
+ return loader
 
 
 
 
185
 
 
 
 
 
 
186
 
187
  # ==========================
188
+ # GENERATOR & TRAINER
189
  # ==========================
190
+ class UnifiedGenerator:
191
+ """MODIFIED: This generator now correctly uses Llama 3.1 chat templating for inference."""
192
+ def __init__(self, loader: ModelLoader):
193
+ self.loader, self.cfg = loader, loader.cfg
194
+ self.tokenizer, self.model = loader.tokenizer, loader.model
195
+ # "<|eot_id|>" exists only in Llama 3 tokenizers; fall back to EOS alone for other models
+ eot_id = self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
+ self.terminators = [self.tokenizer.eos_token_id] if eot_id in (None, self.tokenizer.unk_token_id) else [self.tokenizer.eos_token_id, eot_id]
196
+
197
+ def generate(self, question: str, context: str = "") -> str:
198
+ if not question.strip(): return "لطفاً سوال خود را وارد کنید."
199
+
200
+ # 1. Create the message list in the Llama 3.1 format
201
+ system_prompt = "شما یک دستیار حقوقی هوشمند و متخصص در قوانین ایران هستید. با دقت و بر اساس اطلاعات ارائه شده پاسخ دهید."
202
+ user_content = f"با توجه به اسناد زیر:\n{context}\n\nبه این سوال پاسخ دقیق و کامل بدهید:\nسوال: {question}" if context else question
203
+ messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}]
204
+
205
+ # 2. Use `apply_chat_template` to format the prompt correctly and get input_ids
206
+ # `add_generation_prompt=True` is crucial to add the assistant's turn starter
207
+ input_ids = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(self.model.device)
 
 
 
 
208
 
 
 
 
209
  try:
210
+ with torch.no_grad():
211
+ outputs = self.model.generate(
212
+ input_ids,
213
+ max_new_tokens=self.cfg.max_new_tokens,
214
+ do_sample=True, temperature=self.cfg.temperature, top_p=self.cfg.top_p,
215
+ eos_token_id=self.terminators,
216
+ )
217
+ # 3. Decode only the generated tokens, skipping the prompt
218
+ response_ids = outputs[0][input_ids.shape[1]:]
219
+ response = self.tokenizer.decode(response_ids, skip_special_tokens=True)
220
+ return response.strip() or "پاسخی تولید نشد."
 
 
221
  except Exception as e:
222
+ logger.error(f"Error during generation: {e}")
223
+ return f"خطا در تولید پاسخ: {e}"
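+ # Note: with add_generation_prompt=True the rendered prompt ends with the assistant header
+ # ("<|start_header_id|>assistant<|end_header_id|>\n\n" on Llama 3), so the model continues as
+ # the assistant turn and generation stops at <|eot_id|>/EOS via `terminators`.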
224
 
225
+ class TrainerManager:
226
+ # (Implementation is largely unchanged from the previous refactored version)
227
+ def __init__(self, system_config: SystemConfig, model_loader: ModelLoader):
228
+ self.cfg, self.loader = system_config, model_loader
229
+
230
+ def train(self, train_paths: List[str], callbacks: List) -> Tuple[bool, str]:
231
+ # ... (file loading/validation logic elided; `records` below must be filled with the parsed JSONL samples)
232
+ records: List[Dict] = []  # placeholder: train_test_split raises on an empty dataset
+ train_data, val_data = train_test_split(records, test_size=self.cfg.train_test_ratio)
233
+ if self.cfg.model.architecture == "causal" and self.cfg.lora.use_lora:
234
+ return self._train_causal_lora(train_data, val_data, callbacks)
235
+ return False, "فقط آموزش Causal با LoRA پشتیبانی می‌شود."
236
+
237
  def _get_training_args(self) -> TrainingArguments:
238
+ return TrainingArguments(output_dir=self.cfg.output_dir, num_train_epochs=self.cfg.epochs, learning_rate=self.cfg.lr, per_device_train_batch_size=self.cfg.batch_size, gradient_accumulation_steps=self.cfg.grad_accum, evaluation_strategy="epoch", save_strategy="epoch", load_best_model_at_end=True, logging_steps=25, report_to="none", bf16=torch.cuda.is_available())
239
+
240
+ def _train_causal_lora(self, train_data, val_data, callbacks) -> Tuple[bool, str]:
241
+ # 1. Prepare model for QLoRA
242
+ self.loader.model.gradient_checkpointing_enable()
243
+ model = prepare_model_for_kbit_training(self.loader.model)
244
+
245
+ # 2. Setup LoRA config
246
+ lora_config = LoraConfig(r=self.cfg.lora.r, lora_alpha=self.cfg.lora.lora_alpha, target_modules=self.cfg.lora.target_modules, lora_dropout=self.cfg.lora.lora_dropout, bias="none", task_type="CAUSAL_LM")
247
+ model = get_peft_model(model, lora_config)
248
+ model.print_trainable_parameters()
249
+
250
+ # 3. Create datasets
251
+ train_dataset = CausalJSONLDataset(train_data, self.loader.tokenizer, self.cfg.model.max_input_length)
252
+ val_dataset = CausalJSONLDataset(val_data, self.loader.tokenizer, self.cfg.model.max_input_length)
253
+
254
+ # 4. Train
255
+ trainer = Trainer(model=model, args=self._get_training_args(), train_dataset=train_dataset, eval_dataset=val_dataset, callbacks=callbacks)
256
+ trainer.train()
257
+ model.save_pretrained(self.cfg.output_dir)
258
+ return True, f"آموزش LoRA تکمیل شد. Adapterها در '{self.cfg.output_dir}' ذخیره شدند."
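+ # Sketch for reusing the saved adapters later (assumes the same quantized base model):
+ #   from peft import PeftModel
+ #   base = ModelLoader(self.cfg.model).load().model
+ #   tuned = PeftModel.from_pretrained(base, self.cfg.output_dir)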
 
 
259
 
260
  # ==========================
261
+ # DECOUPLED APPLICATION LOGIC
262
  # ==========================
263
+ class MahoonCore:
264
  def __init__(self, system_config: Optional[SystemConfig] = None):
265
  self.cfg = system_config or SystemConfig()
266
  self.rag = LegalRAGSystem(self.cfg)
 
267
  self._current_loader: Optional[ModelLoader] = None
268
  self._current_generator: Optional[UnifiedGenerator] = None
269
  self._lock = threading.Lock()
270
 
271
+ def load_model_and_rag(self, model_choice: str, use_rag: bool) -> str:
 
272
  with self._lock:
273
  try:
274
+ model_name, arch = self._get_model_configs()[model_choice]
275
+ self.cfg.model.model_name, self.cfg.model.architecture = model_name, arch
276
+
277
+ self._current_loader = ModelCache.get_model(model_name, arch, self.cfg.model)
 
 
278
  self._current_generator = UnifiedGenerator(self._current_loader)
279
+ model_msg = f"مدل بارگذاری شد: {model_name}"
280
 
281
+ rag_msg = ""
282
+ if use_rag: _, rag_msg = self.rag.load_chroma()
283
+
284
+ return f"{model_msg}\n{rag_msg}"
285
+ except Exception as e: return f"خطا در بارگذاری: {e}"
 
 
 
 
286
 
287
+ def generate_response(self, question: str, use_rag: bool) -> Tuple[str, str, str]:
288
+ if not question or not question.strip(): return "", "", ""
+ if not self._current_generator: return "ابتدا یک مدل را بارگذاری کنید.", "", ""
289
+
290
  start_time = time.time()
291
+ context, articles = "", []
292
+ if use_rag and self.rag.collection:
293
+ articles = self.rag.retrieve(question)
294
+ context = self.rag.build_context(articles)
295
+
296
+ response = self._current_generator.generate(question, context)
297
+
298
+ references = "\n\n".join([f"**شباهت: {art['similarity']:.2f}**\n{art['text'][:400]}..." for art in articles[:3]])
299
+ metrics = f"زمان پردازش: {time.time() - start_time:.2f}s | اسناد یافت شده: {len(articles)}"
300
+ return response, references, metrics
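+ # Headless usage sketch (assumes a CUDA GPU and access to the gated Llama 3.1 weights):
+ #   core = MahoonCore()
+ #   print(core.load_model_and_rag("Causal (Llama-3.1-8B-Instruct)", use_rag=True))
+ #   answer, refs, stats = core.generate_response("شرایط فسخ قرارداد اجاره چیست؟", use_rag=True)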
301
+
 
 
 
 
302
  def _get_model_configs(self) -> Dict[str, Tuple[str, str]]:
 
303
  return {
304
+ "Causal (Llama-3.1-8B-Instruct)": ("meta-llama/Meta-Llama-3.1-8B-Instruct", "causal"),
305
+ "Causal (Mistral-7B-Instruct)": ("mistralai/Mistral-7B-Instruct-v0.2", "causal"),
 
306
  "Causal (PersianMind-v1.0)": ("universitytehran/PersianMind-v1.0", "causal"),
 
 
307
  }
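+ # The prompt building in UnifiedGenerator assumes a Llama-3-style chat template with a
+ # system role; the Mistral and PersianMind entries may need their own prompt handling,
+ # since their templates do not necessarily accept a separate system message.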
308
 
 
 
 
309
 
310
+ # ==========================
311
+ # UI-ONLY CLASS
312
+ # ==========================
313
+ class LegalAppUI:
314
+ def __init__(self, core: MahoonCore):
315
+ self.core = core
316
+ self.model_choices = list(core._get_model_configs().keys())
 
 
 
 
317
 
318
+ def build_ui(self) -> gr.Blocks:
319
+ with gr.Blocks(title="ماحون مشاور حقوقی هوشمند", theme=gr.themes.Soft()) as app:
320
+ gr.HTML("<h1>ماحون — مشاور حقوقی هوشمند 🏛️</h1>")
321
+
322
+ with gr.Tabs():
323
+ with gr.Tab("💬 مشاوره"):
324
  with gr.Row():
325
+ with gr.Column(scale=3):
326
+ chatbot = gr.Chatbot(label="گفتگو", height=550, avatar_images=("user.png", "bot.png"))
327
+ question_input = gr.Textbox(label="سوال خود را اینجا تایپ کنید...", placeholder="مثال: شرایط فسخ قرارداد اجاره چیست؟", scale=7)
328
+ submit_btn = gr.Button("🔍 ارسال", variant="primary", scale=1)
329
  with gr.Column(scale=1):
330
+ model_dropdown = gr.Dropdown(choices=self.model_choices, value=self.model_choices[0], label="انتخاب مدل")
331
+ use_rag_checkbox = gr.Checkbox(value=True, label="استفاده از RAG (جستجوی اسناد)")
332
+ load_model_btn = gr.Button("🚀 بارگذاری مدل", variant="secondary")
333
+ load_status = gr.Textbox(label="وضعیت", interactive=False)
334
+ with gr.Accordion("اسناد و منابع مرتبط", open=False):
335
+ references_output = gr.Markdown()
336
+ metrics_output = gr.Textbox(label="معیارهای عملکرد", interactive=False)
337
+
338
+ # --- Event Handlers ---
 
 
339
  load_model_btn.click(
340
+ fn=self.ui_load_model,
341
+ inputs=[model_dropdown, use_rag_checkbox, load_model_btn],
342
+ outputs=[load_status, load_model_btn]
343
  )
344
 
345
  submit_btn.click(
346
+ fn=self.ui_generate_response,
347
+ inputs=[question_input, chatbot, use_rag_checkbox, submit_btn],
348
+ outputs=[chatbot, question_input, references_output, metrics_output, submit_btn]
 
 
349
  )
350
+ question_input.submit(
351
+ fn=self.ui_generate_response,
352
+ inputs=[question_input, chatbot, use_rag_checkbox, submit_btn],
353
+ outputs=[chatbot, question_input, references_output, metrics_output, submit_btn]
 
 
354
  )
355
 
356
  return app
357
 
358
+ # --- UI Handler Methods ---
359
+ def ui_load_model(self, model_choice, use_rag, btn):
360
+ yield "در حال بارگذاری...", gr.update(interactive=False)
361
+ status = self.core.load_model_and_rag(model_choice, use_rag)
362
+ yield status, gr.update(interactive=True)
363
+
364
+ def ui_generate_response(self, question, chat_history, use_rag, btn):
365
+ if not question.strip():
366
+ chat_history.append((question, "لطفا سوال خود را وارد کنید."))
367
+ yield chat_history, "", "", "", gr.update(interactive=True)
368
+ return
369
 
370
+ # Show user's question immediately
371
+ chat_history.append([question, None])
372
+ yield chat_history, "", "", "", gr.update(interactive=False, value="...")
373
+
374
+ # Stream a simple "thinking" animation
375
+ for step in range(3):
376
+ chat_history[-1][1] = "در حال پردازش" + "." * (step + 1)
377
+ yield chat_history, "", "...", "...", gr.update(interactive=False, value="...")
378
+ time.sleep(0.3)
379
+
380
+ # Get the actual response
381
+ response, refs, metrics = self.core.generate_response(question, use_rag)
382
+ chat_history[-1][1] = response
383
+
384
+ yield chat_history, "", refs, metrics, gr.update(interactive=True, value="🔍 ارسال")
385
+
386
+
387
+ def main():
388
+ # Set a default system config (can be loaded from a file too)
389
+ config = SystemConfig()
390
+
391
+ # Initialize the core logic of the application
392
+ core = MahoonCore(config)
393
+
394
+ # Build the UI, passing the core logic to it
395
+ ui = LegalAppUI(core)
396
+ app = ui.build_ui()
397
+
398
+ # Launch the Gradio app
399
+ app.launch(server_name="0.0.0.0", server_port=7860, inbrowser=True)
400
 
401
  if __name__ == "__main__":
402
+ main()
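+ # Runtime dependencies used above: torch, transformers, peft, chromadb, sentence-transformers,
+ # gradio, pydantic and scikit-learn (for train_test_split); the 4-bit path additionally needs
+ # bitsandbytes and accelerate at runtime.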