krmin commited on
Commit
401cf69
·
1 Parent(s): a2e450d

Add model download script and local model files for Mahjong Soul Vision

Browse files

- Implemented a script to download and save the Mahjong Soul Vision model locally.
- Added configuration files including config.json, preprocessor_config.json, and model.safetensors to the local model directory.
- Configured image processing parameters and model architecture settings for image classification tasks.

__pycache__/tools.cpython-310.pyc ADDED
Binary file (10.1 kB). View file
 
debug_coordinates.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 座標デバッグ用スクリプト
3
+ 雀魂のウィンドウから手牌領域をキャプチャして確認
4
+ """
5
+ import pygetwindow as gw
6
+ from PIL import ImageGrab
7
+ import numpy as np
8
+ import cv2
9
+
10
+ # ウィンドウを取得
11
+ window_title = "雀魂"
12
+ try:
13
+ window = gw.getWindowsWithTitle(window_title)[0]
14
+ print(f"✓ ウィンドウ検出: {window.title}")
15
+ print(f" 位置: ({window.left}, {window.top})")
16
+ print(f" サイズ: {window.width}x{window.height}")
17
+ except IndexError:
18
+ print(f"✗ '{window_title}' が見つかりません")
19
+ exit(1)
20
+
21
+ # スクリーンショット取得
22
+ screenshot = ImageGrab.grab(bbox=(window.left, window.top, window.right, window.bottom), all_screens=True)
23
+ frame = np.array(screenshot)
24
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
25
+
26
+ # 手牌座標(live_feed.pyと同じ)
27
+ PLAYER_HAND_X = 105
28
+ PLAYER_HAND_Y = 759
29
+ PLAYER_HAND_W = 627
30
+ PLAYER_HAND_H = 84
31
+
32
+ # 座標が画面内か確認
33
+ if PLAYER_HAND_Y + PLAYER_HAND_H > frame.shape[0]:
34
+ print(f"\n⚠ 警告: 手牌のY座標が画面外です")
35
+ print(f" 画面の高さ: {frame.shape[0]}")
36
+ print(f" 手牌の範囲: Y {PLAYER_HAND_Y} - {PLAYER_HAND_Y + PLAYER_HAND_H}")
37
+ # 座標を修正
38
+ PLAYER_HAND_Y = frame.shape[0] - PLAYER_HAND_H - 10
39
+ print(f" 修正後のY座標: {PLAYER_HAND_Y}")
40
+
41
+ # 手牌領域を抽出
42
+ hand_region = frame[PLAYER_HAND_Y:PLAYER_HAND_Y+PLAYER_HAND_H,
43
+ PLAYER_HAND_X:PLAYER_HAND_X+PLAYER_HAND_W]
44
+
45
+ # 矩形を描画
46
+ debug_frame = frame.copy()
47
+ cv2.rectangle(debug_frame,
48
+ (PLAYER_HAND_X, PLAYER_HAND_Y),
49
+ (PLAYER_HAND_X + PLAYER_HAND_W, PLAYER_HAND_Y + PLAYER_HAND_H),
50
+ (0, 255, 0), 3)
51
+
52
+ # 保存
53
+ cv2.imwrite("debug_full_screen.png", debug_frame)
54
+ cv2.imwrite("debug_hand_region.png", hand_region)
55
+
56
+ print(f"\n保存完了:")
57
+ print(f" debug_full_screen.png - 全画面(緑の矩形が手牌領域)")
58
+ print(f" debug_hand_region.png - 手牌領域のみ")
59
+ print(f"\n手牌領域:")
60
+ print(f" X: {PLAYER_HAND_X}")
61
+ print(f" Y: {PLAYER_HAND_Y}")
62
+ print(f" 幅: {PLAYER_HAND_W}")
63
+ print(f" 高さ: {PLAYER_HAND_H}")
64
+ print(f" サイズ: {hand_region.shape}")
download_model.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # モデルをローカルにダウンロードして保存
2
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
3
+ import os
4
+
5
+ model_name = "krmin/mahjong_soul_vision"
6
+ local_model_path = "./vision_transformer_local"
7
+
8
+ print(f"モデルをダウンロード中: {model_name}")
9
+ print(f"保存先: {local_model_path}")
10
+
11
+ # モデルとプロセッサをダウンロード
12
+ processor = AutoImageProcessor.from_pretrained(model_name)
13
+ model = AutoModelForImageClassification.from_pretrained(model_name)
14
+
15
+ # ローカルに保存
16
+ print("ローカルに保存中...")
17
+ processor.save_pretrained(local_model_path)
18
+ model.save_pretrained(local_model_path)
19
+
20
+ print("✓ 完了!")
21
+ print(f"\n次回からは以下のように読み込めます:")
22
+ print(f'pipe = pipeline("image-classification", model="{local_model_path}", device=device)')
live_feed.py CHANGED
@@ -2,6 +2,7 @@
2
 
3
  # %%
4
  import time
 
5
 
6
  import cv2
7
  from PIL import Image, ImageGrab
@@ -18,6 +19,105 @@ import torch.nn as nn
18
  from safetensors.torch import load_file
19
  # Load model directly
20
  from transformers import AutoModel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class ImprovedNN(nn.Module):
22
  def __init__(self, input_dim, output_dim):
23
  super(ImprovedNN, self).__init__()
@@ -45,38 +145,68 @@ class ImprovedNN(nn.Module):
45
 
46
 
47
  if torch.cuda.is_available():
48
- print("CUDA available")
49
  device = torch.device("cuda")
50
  else:
51
- print("No CUDA")
52
  device = torch.device("cpu")
53
 
54
- pipe = pipeline("image-classification", model="pjura/mahjong_soul_vision", device=device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  input_dim = 204
57
  output_dim = 34
58
- model = ImprovedNN(input_dim=input_dim, output_dim=output_dim)
59
 
60
  model_path = "model.safetensors"
61
  state_dict = load_file(model_path)
62
 
63
- model.load_state_dict(state_dict)
 
64
 
65
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
- model.to(device)
67
 
68
  global_debug = False
69
- model.to(device)
 
 
 
 
70
 
71
- PLAYER_HAND_X = 300
72
- PLAYER_HAND_Y = 1048 - 200
73
- PLAYER_HAND_W = 1250
74
- PLAYER_HAND_H = 200
 
 
75
 
76
- PLAYER_PON_X = 300 + PLAYER_HAND_W
77
- PLAYER_PON_Y = 1048 - 200
78
  PLAYER_PON_W = 200
79
- PLAYER_PON_H = 200
80
 
81
  PLAYER_THROW_X = 790
82
  PLAYER_THROW_Y = 1048 - 490
@@ -114,10 +244,13 @@ def nothing(x):
114
 
115
 
116
  # Get the window by its title. Adjust this to the title of the window you want to capture.
117
- window_title = "MahjongSoul"
118
  try:
119
  window = gw.getWindowsWithTitle(window_title)[0]
 
120
  except IndexError:
 
 
121
  raise Exception(f"No window with title '{window_title}' found.")
122
 
123
  if global_debug:
@@ -126,7 +259,7 @@ if global_debug:
126
  cv2.createTrackbar('Upper', 'Trackbars', 255, 255, nothing)
127
 
128
 
129
- def analyze_region(frame, x, y, w, h, lower=150, upper=255, debug=False):
130
  if global_debug:
131
  lower = cv2.getTrackbarPos('Lower', 'Trackbars')
132
  upper = cv2.getTrackbarPos('Upper', 'Trackbars')
@@ -148,8 +281,11 @@ def analyze_region(frame, x, y, w, h, lower=150, upper=255, debug=False):
148
  rois = [] # Liste zur Sammlung von Regionen von Interesse (ROIs)
149
  boxes_temp = [] # Temporäre Liste zur Sammlung von Bounding-Box-Koordinaten
150
  contours, _ = cv2.findContours(roi_threshed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
151
  for contour in contours:
152
- if cv2.contourArea(contour) > 500:
 
 
153
  x_rect, y_rect, w_rect, h_rect = cv2.boundingRect(contour)
154
  aspect_ratio = w_rect / h_rect
155
 
@@ -174,10 +310,13 @@ def analyze_region(frame, x, y, w, h, lower=150, upper=255, debug=False):
174
  label = predictions[0]['label']
175
  prob = predictions[0]['score']
176
 
177
- if prob > 0.0:
 
178
  boxes.append(boxes_temp[idx]) # idx wird hier verwendet
179
  probs.append(prob)
180
  labels.append(label)
 
 
181
  return boxes, labels, probs
182
 
183
 
@@ -239,7 +378,7 @@ def click_hand_tile(all_boxes, frame):
239
 
240
  translated_tensor = translate_boxes_to_tensors(all_boxes)
241
  # Stellen Sie sicher, dass Ihre make_prediction Funktion die Rohwahrscheinlichkeiten zurückgibt
242
- probs = make_prediction(model, translated_tensor)
243
 
244
  # Sortieren Sie die Wahrscheinlichkeiten in absteigender Reihenfolge und erhalten Sie die Indizes
245
  sorted_indices = probs.argsort(descending=True)
@@ -308,10 +447,60 @@ def draw_boxes(frame, boxes, labels, probs):
308
 
309
  PLAYER_PON_X_TEMP = PLAYER_PON_X
310
  PLAYER_PON_W_TEMP = PLAYER_PON_W
311
- while True:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  screenshot = ImageGrab.grab(bbox=(window.left, window.top, window.right, window.bottom), all_screens=True)
313
  frame = np.array(screenshot)
314
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  # Analyze regions and get boxes, labels, and probabilities
317
  player_pon_boxes, player_pon_labels, player_pon_probs = analyze_region(frame, PLAYER_PON_X_TEMP, PLAYER_PON_Y,
@@ -323,6 +512,13 @@ while True:
323
 
324
  player_hand_boxes, player_hand_labels, player_hand_probs = analyze_region(frame, PLAYER_HAND_X, PLAYER_HAND_Y,
325
  PLAYER_HAND_W_TEMP, PLAYER_HAND_H)
 
 
 
 
 
 
 
326
  player_throw_boxes, player_throw_labels, player_throw_probs = analyze_region(frame, PLAYER_THROW_X, PLAYER_THROW_Y,
327
  PLAYER_THROW_W, PLAYER_THROW_H)
328
  right_player_throw_boxes, right_player_throw_labels, right_player_throw_probs = analyze_region(frame,
@@ -340,38 +536,30 @@ while True:
340
  OPPOSITE_PLAYER_THROW_Y,
341
  OPPOSITE_PLAYER_THROW_W,
342
  OPPOSITE_PLAYER_THROW_H)
343
- # Draw bounding boxes, labels, and probabilities
344
- draw_boxes(frame, player_hand_boxes, player_hand_labels, player_hand_probs)
345
- draw_boxes(frame, player_pon_boxes, player_pon_labels, player_pon_probs)
346
- draw_boxes(frame, player_throw_boxes, player_throw_labels, player_throw_probs)
347
- draw_boxes(frame, right_player_throw_boxes, right_player_throw_labels, right_player_throw_probs)
348
- draw_boxes(frame, left_player_throw_boxes, left_player_throw_labels, left_player_throw_probs)
349
- draw_boxes(frame, opposite_player_throw_boxes, opposite_player_throw_labels, opposite_player_throw_probs)
350
-
351
- cv2.rectangle(frame, (PLAYER_HAND_X, PLAYER_HAND_Y),
352
- (PLAYER_HAND_X + PLAYER_HAND_W_TEMP, PLAYER_HAND_Y + PLAYER_HAND_H),
353
- (0, 255, 0), 2)
354
- cv2.rectangle(frame, (PLAYER_PON_X_TEMP, PLAYER_PON_Y),
355
- (PLAYER_PON_X_TEMP + PLAYER_PON_W_TEMP, PLAYER_PON_Y + PLAYER_PON_H),
356
- (255, 255, 0), 2)
357
-
358
- cv2.rectangle(frame, (PLAYER_THROW_X, PLAYER_THROW_Y),
359
- (PLAYER_THROW_X + PLAYER_THROW_W, PLAYER_THROW_Y + PLAYER_THROW_H),
360
- (0, 0, 255), 2)
361
-
362
- # Zeichnen Sie die Boxen für die anderen Spieler
363
- cv2.rectangle(frame, (RIGHT_PLAYER_THROW_X, RIGHT_PLAYER_THROW_Y),
364
- (RIGHT_PLAYER_THROW_X + RIGHT_PLAYER_THROW_W, RIGHT_PLAYER_THROW_Y + RIGHT_PLAYER_THROW_H),
365
- (255, 0, 0), 2) # Blaue Farbe für den rechten Spieler
366
- cv2.rectangle(frame, (LEFT_PLAYER_THROW_X, LEFT_PLAYER_THROW_Y),
367
- (LEFT_PLAYER_THROW_X + LEFT_PLAYER_THROW_W, LEFT_PLAYER_THROW_Y + LEFT_PLAYER_THROW_H),
368
- (0, 255, 255), 2) # Gelbe Farbe für den linken Spieler
369
- cv2.rectangle(frame, (OPPOSITE_PLAYER_THROW_X, OPPOSITE_PLAYER_THROW_Y),
370
- (
371
- OPPOSITE_PLAYER_THROW_X + OPPOSITE_PLAYER_THROW_W,
372
- OPPOSITE_PLAYER_THROW_Y + OPPOSITE_PLAYER_THROW_H),
373
- (255, 0, 255), 2) # Violette Farbe für den gegenüberliegenden Spieler
374
-
375
  all_boxes = {
376
  "player_hand": player_hand_boxes,
377
  "player_hand_labels": player_hand_labels,
@@ -387,28 +575,91 @@ while True:
387
  "opposite_player_throw_labels": opposite_player_throw_labels
388
  }
389
 
 
390
  if len(player_hand_labels) + len(player_pon_labels) >= 14:
391
- print("Your turn!")
392
- click_hand_tile(all_boxes, frame)
393
- time.sleep(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
394
 
395
- if global_debug:
396
- # Erstellt die Trackbars
397
- lower = cv2.getTrackbarPos('Lower', 'Trackbars')
398
- upper = cv2.getTrackbarPos('Upper', 'Trackbars')
399
 
400
- if global_debug:
401
- frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) # Graustufen-Frame
402
- _, frame_threshed = cv2.threshold(frame_gray, lower, upper, cv2.THRESH_BINARY)
403
- cv2.imshow("Full Frame Gray", frame_gray) # Zeigt den grauen Frame
404
- cv2.imshow("Full Frame Threshold", frame_threshed)
405
- else:
406
- cv2.imshow("Mahjong Tile Recognition v2", frame)
407
- # Break the loop if 'q' is pressed
408
- if cv2.waitKey(1) & 0xFF == ord('q'):
409
- break
410
- # time.sleep(2)
411
- cv2.destroyAllWindows()
412
 
413
  # %%
414
 
 
2
 
3
  # %%
4
  import time
5
+ import sys
6
 
7
  import cv2
8
  from PIL import Image, ImageGrab
 
19
  from safetensors.torch import load_file
20
  # Load model directly
21
  from transformers import AutoModel
22
+ from PyQt5.QtWidgets import QApplication, QLabel, QVBoxLayout, QWidget
23
+ from PyQt5.QtCore import Qt, QTimer
24
+ from PyQt5.QtGui import QFont
25
+
26
+ # 透明オーバーレイウィンドウクラス
27
+ class TransparentOverlay(QWidget):
28
+ def __init__(self):
29
+ super().__init__()
30
+ self.initUI()
31
+
32
+ def initUI(self):
33
+ # ウィンドウ設定
34
+ self.setWindowFlags(
35
+ Qt.WindowStaysOnTopHint | # 最前面
36
+ Qt.FramelessWindowHint | # フレームなし
37
+ Qt.Tool # タスクバーに表示しない
38
+ )
39
+ self.setAttribute(Qt.WA_TranslucentBackground) # 透明背景
40
+
41
+ # 位置とサイズ(左上)
42
+ self.setGeometry(10, 10, 400, 150)
43
+
44
+ # レイアウト
45
+ layout = QVBoxLayout()
46
+ layout.setContentsMargins(10, 10, 10, 10)
47
+
48
+ # 手牌ラベル
49
+ self.hand_label = QLabel("手牌: 雀魂で牌が配られるまで待機中...")
50
+ self.hand_label.setFont(QFont("Yu Gothic UI", 12, QFont.Bold))
51
+ self.hand_label.setStyleSheet("""
52
+ QLabel {
53
+ color: white;
54
+ background-color: rgba(0, 0, 0, 180);
55
+ padding: 8px;
56
+ border-radius: 5px;
57
+ }
58
+ """)
59
+ layout.addWidget(self.hand_label)
60
+
61
+ # 推奨打牌ラベル
62
+ self.recommendation_label = QLabel("推奨: -")
63
+ self.recommendation_label.setFont(QFont("Yu Gothic UI", 16, QFont.Bold))
64
+ self.recommendation_label.setStyleSheet("""
65
+ QLabel {
66
+ color: #FFD700;
67
+ background-color: rgba(0, 0, 0, 180);
68
+ padding: 10px;
69
+ border-radius: 5px;
70
+ border: 2px solid #FFD700;
71
+ }
72
+ """)
73
+ layout.addWidget(self.recommendation_label)
74
+
75
+ # ステータスラベル
76
+ self.status_label = QLabel("✓ 起動完了 | Space: 自動クリック | 更新: 0.2秒毎")
77
+ self.status_label.setFont(QFont("Yu Gothic UI", 9))
78
+ self.status_label.setStyleSheet("""
79
+ QLabel {
80
+ color: #00FF00;
81
+ background-color: rgba(0, 0, 0, 150);
82
+ padding: 5px;
83
+ border-radius: 3px;
84
+ }
85
+ """)
86
+ layout.addWidget(self.status_label)
87
+
88
+ self.setLayout(layout)
89
+
90
+ def update_hand(self, tiles):
91
+ """手牌を更新"""
92
+ if tiles:
93
+ self.hand_label.setText(f"手牌: {' '.join(tiles)}")
94
+
95
+ def update_recommendation(self, tile):
96
+ """推奨打牌を更新"""
97
+ if tile:
98
+ self.recommendation_label.setText(f"推奨: {tile}")
99
+ self.recommendation_label.setStyleSheet("""
100
+ QLabel {
101
+ color: #FF4444;
102
+ background-color: rgba(0, 0, 0, 200);
103
+ padding: 10px;
104
+ border-radius: 5px;
105
+ border: 3px solid #FF4444;
106
+ }
107
+ """)
108
+ else:
109
+ self.recommendation_label.setText("推奨: -")
110
+ self.recommendation_label.setStyleSheet("""
111
+ QLabel {
112
+ color: #FFD700;
113
+ background-color: rgba(0, 0, 0, 180);
114
+ padding: 10px;
115
+ border-radius: 5px;
116
+ border: 2px solid #FFD700;
117
+ }
118
+ """)
119
+
120
+
121
  class ImprovedNN(nn.Module):
122
  def __init__(self, input_dim, output_dim):
123
  super(ImprovedNN, self).__init__()
 
145
 
146
 
147
  if torch.cuda.is_available():
148
+ print("CUDA利用可能")
149
  device = torch.device("cuda")
150
  else:
151
+ print(" CUDA利用不可 - CPUモード")
152
  device = torch.device("cpu")
153
 
154
+ # モデル読み込み(ローカルキャッシュを優先)
155
+ print("モデル読み込み中...")
156
+ import os
157
+ local_model_path = "./vision_transformer_local"
158
+ model_name = "krmin/mahjong_soul_vision"
159
+
160
+ # ローカルにモデルがあればそれを使用、なければHuggingFaceから
161
+ if os.path.exists(local_model_path):
162
+ print(f" ローカルモデルを使用: {local_model_path}")
163
+ pipe = pipeline("image-classification", model=local_model_path, device=device)
164
+ else:
165
+ print(f" HuggingFaceからダウンロード: {model_name}")
166
+ print(" 初回は30-60秒かかります")
167
+ pipe = pipeline("image-classification", model=model_name, device=device)
168
+ # ダウンロード後、ローカルに保存
169
+ try:
170
+ print(" 次回用にローカル保存中...")
171
+ pipe.model.save_pretrained(local_model_path)
172
+ pipe.feature_extractor.save_pretrained(local_model_path)
173
+ print(f" ✓ ローカルに保存完了: {local_model_path}")
174
+ except Exception as e:
175
+ print(f" ⚠ ローカル保存失敗: {e}")
176
+
177
+ print(" ✓ Vision Transformer読み込み完了")
178
 
179
  input_dim = 204
180
  output_dim = 34
181
+ discard_model = ImprovedNN(input_dim=input_dim, output_dim=output_dim)
182
 
183
  model_path = "model.safetensors"
184
  state_dict = load_file(model_path)
185
 
186
+ discard_model.load_state_dict(state_dict)
187
+ print(" ✓ 打牌予測モデル読み込み完了")
188
 
189
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
190
+ discard_model.to(device)
191
 
192
  global_debug = False
193
+ discard_model.to(device)
194
+
195
+ # グローバル変数
196
+ window = None
197
+ window_title = "雀魂"
198
 
199
+ # 雀魂の手牌座標(実際の画面から確認済み)
200
+ # ウィンドウ相対座標: x=105, y=759, width=627, height=84
201
+ PLAYER_HAND_X = 105
202
+ PLAYER_HAND_Y = 759
203
+ PLAYER_HAND_W = 627
204
+ PLAYER_HAND_H = 84
205
 
206
+ PLAYER_PON_X = PLAYER_HAND_X + PLAYER_HAND_W
207
+ PLAYER_PON_Y = PLAYER_HAND_Y
208
  PLAYER_PON_W = 200
209
+ PLAYER_PON_H = 84
210
 
211
  PLAYER_THROW_X = 790
212
  PLAYER_THROW_Y = 1048 - 490
 
244
 
245
 
246
  # Get the window by its title. Adjust this to the title of the window you want to capture.
247
+ print(f"雀魂ウィンドウを検索中...")
248
  try:
249
  window = gw.getWindowsWithTitle(window_title)[0]
250
+ print(f" ✓ ウィンドウ検出: {window.title}")
251
  except IndexError:
252
+ print(f" ✗ エラー: '{window_title}' というタイトルのウィンドウが見つかりません")
253
+ print(f" 雀魂を起動してからもう一度お試しください")
254
  raise Exception(f"No window with title '{window_title}' found.")
255
 
256
  if global_debug:
 
259
  cv2.createTrackbar('Upper', 'Trackbars', 255, 255, nothing)
260
 
261
 
262
+ def analyze_region(frame, x, y, w, h, lower=100, upper=255, debug=False):
263
  if global_debug:
264
  lower = cv2.getTrackbarPos('Lower', 'Trackbars')
265
  upper = cv2.getTrackbarPos('Upper', 'Trackbars')
 
281
  rois = [] # Liste zur Sammlung von Regionen von Interesse (ROIs)
282
  boxes_temp = [] # Temporäre Liste zur Sammlung von Bounding-Box-Koordinaten
283
  contours, _ = cv2.findContours(roi_threshed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
284
+
285
  for contour in contours:
286
+ area = cv2.contourArea(contour)
287
+ # 最小面積を200に下げて小さい牌も認識
288
+ if area > 200:
289
  x_rect, y_rect, w_rect, h_rect = cv2.boundingRect(contour)
290
  aspect_ratio = w_rect / h_rect
291
 
 
310
  label = predictions[0]['label']
311
  prob = predictions[0]['score']
312
 
313
+ # 確率が85%以上の認識結果のみ採用(Vision Transformerは99.7%の精度)
314
+ if prob > 0.85:
315
  boxes.append(boxes_temp[idx]) # idx wird hier verwendet
316
  probs.append(prob)
317
  labels.append(label)
318
+ # デバッグ出力は手牌のみ(捨て牌は出力しない)
319
+ # print(f"認識: {label} ({prob*100:.1f}%)", end=" ")
320
  return boxes, labels, probs
321
 
322
 
 
378
 
379
  translated_tensor = translate_boxes_to_tensors(all_boxes)
380
  # Stellen Sie sicher, dass Ihre make_prediction Funktion die Rohwahrscheinlichkeiten zurückgibt
381
+ probs = make_prediction(discard_model, translated_tensor)
382
 
383
  # Sortieren Sie die Wahrscheinlichkeiten in absteigender Reihenfolge und erhalten Sie die Indizes
384
  sorted_indices = probs.argsort(descending=True)
 
447
 
448
  PLAYER_PON_X_TEMP = PLAYER_PON_X
449
  PLAYER_PON_W_TEMP = PLAYER_PON_W
450
+
451
+ # PyQt5アプリケーション初期化
452
+ print("UIを初期化中...")
453
+ app = QApplication(sys.argv)
454
+ overlay = TransparentOverlay()
455
+ overlay.show()
456
+ print(" ✓ 透明オーバーレイウィンドウを表示")
457
+
458
+ # グローバル変数で推奨牌を保持
459
+ current_recommendation = None
460
+ previous_hand_count = 0 # 前回の手牌枚数を記憶
461
+
462
+ print("\n" + "="*60)
463
+ print("起動完了!")
464
+ print("="*60)
465
+ print("左上の透明ウィンドウに手牌と推奨牌を表示します")
466
+ print("Spaceキー: 推奨牌を自動クリック")
467
+ print("Dキー: デバッグ用に画面キャプチャを保存")
468
+ print("二値化閾値: 100 (lower) - 明るい牌を検出")
469
+ print("認識閾値: 85% - 高精度のみ採用")
470
+ print("ウィンドウを閉じる: 終了")
471
+ print("="*60 + "\n")
472
+
473
+ def process_frame():
474
+ """フレーム処理とUI更新"""
475
+ global PLAYER_PON_X_TEMP, PLAYER_PON_W_TEMP, PLAYER_HAND_W_TEMP, current_recommendation, previous_hand_count, window
476
+
477
+ # ウィンドウ位置を毎回更新(ウィンドウが移動しても追従)
478
+ try:
479
+ old_window = window
480
+ window = gw.getWindowsWithTitle(window_title)[0]
481
+ # デバッグ: ウィンドウ位置が変わったら通知
482
+ if old_window and (old_window.left != window.left or old_window.top != window.top):
483
+ print(f"\nウィンドウ移動検出: ({old_window.left}, {old_window.top}) → ({window.left}, {window.top})")
484
+ except IndexError:
485
+ print("\r雀魂ウィンドウが見つかりません ", end="", flush=True)
486
+ return
487
+
488
  screenshot = ImageGrab.grab(bbox=(window.left, window.top, window.right, window.bottom), all_screens=True)
489
  frame = np.array(screenshot)
490
  frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
491
+
492
+ # デバッグ: 'd'キーでキャプチャを保存
493
+ if keyboard.is_pressed('d'):
494
+ timestamp = int(time.time())
495
+ filename = f"debug_capture_{timestamp}.png"
496
+ cv2.imwrite(filename, frame)
497
+ print(f"\n📷 キャプチャ保存: {filename} (座標: left={window.left}, top={window.top}, right={window.right}, bottom={window.bottom})")
498
+
499
+ # 手牌領域も保存
500
+ roi = frame[PLAYER_HAND_Y:PLAYER_HAND_Y + PLAYER_HAND_H, PLAYER_HAND_X:PLAYER_HAND_X + PLAYER_HAND_W]
501
+ cv2.imwrite(f"debug_hand_{timestamp}.png", roi)
502
+ print(f"📷 手牌領域保存: debug_hand_{timestamp}.png")
503
+ time.sleep(0.5) # 連続保存を防ぐ
504
 
505
  # Analyze regions and get boxes, labels, and probabilities
506
  player_pon_boxes, player_pon_labels, player_pon_probs = analyze_region(frame, PLAYER_PON_X_TEMP, PLAYER_PON_Y,
 
512
 
513
  player_hand_boxes, player_hand_labels, player_hand_probs = analyze_region(frame, PLAYER_HAND_X, PLAYER_HAND_Y,
514
  PLAYER_HAND_W_TEMP, PLAYER_HAND_H)
515
+
516
+ # 手牌認識の詳細をコンソールに出力
517
+ if len(player_hand_labels) > 0:
518
+ print(f"\n手牌検出: ", end="")
519
+ for i, label in enumerate(player_hand_labels):
520
+ print(f"{label}({player_hand_probs[i]*100:.1f}%) ", end="")
521
+
522
  player_throw_boxes, player_throw_labels, player_throw_probs = analyze_region(frame, PLAYER_THROW_X, PLAYER_THROW_Y,
523
  PLAYER_THROW_W, PLAYER_THROW_H)
524
  right_player_throw_boxes, right_player_throw_labels, right_player_throw_probs = analyze_region(frame,
 
536
  OPPOSITE_PLAYER_THROW_Y,
537
  OPPOSITE_PLAYER_THROW_W,
538
  OPPOSITE_PLAYER_THROW_H)
539
+
540
+ # UI更新: 手牌
541
+ current_hand_count = len(player_hand_labels) + len(player_pon_labels)
542
+
543
+ if len(player_hand_labels) > 0:
544
+ overlay.update_hand(player_hand_labels)
545
+ # シンプルな表示(タイムスタンプ付き)
546
+ hand_str = " ".join(player_hand_labels)
547
+ current_time = time.strftime("%H:%M:%S")
548
+
549
+ # 手牌枚数が変化したら通知
550
+ if current_hand_count != previous_hand_count:
551
+ if current_hand_count == 14:
552
+ print(f"\n★ツモ! 14枚になりました", end="")
553
+ print(f"\n[{current_time}] [{len(player_hand_labels)}枚] {hand_str} ", end="", flush=True)
554
+ previous_hand_count = current_hand_count
555
+ # else:
556
+ # print(f"\r[{len(player_hand_labels)}枚] {hand_str} ", end="", flush=True)
557
+ else:
558
+ # 手牌が認識されていない場合のデバッグ情報
559
+ print(f"\r⚠ 手牌未検出 (座標: x={PLAYER_HAND_X}, y={PLAYER_HAND_Y}, w={PLAYER_HAND_W_TEMP}, h={PLAYER_HAND_H}) 二値化閾値=100 ", end="", flush=True)
560
+
561
+ previous_hand_count = 0
562
+
 
 
 
 
 
 
 
 
563
  all_boxes = {
564
  "player_hand": player_hand_boxes,
565
  "player_hand_labels": player_hand_labels,
 
575
  "opposite_player_throw_labels": opposite_player_throw_labels
576
  }
577
 
578
+ # 推奨牌の計算と表示
579
  if len(player_hand_labels) + len(player_pon_labels) >= 14:
580
+ print(f"\n自分の番 (手牌:{len(player_hand_labels)}+ポン:{len(player_pon_labels)}={len(player_hand_labels)+len(player_pon_labels)}枚)", end="")
581
+
582
+ # 推奨牌を計算
583
+ try:
584
+ translated_tensor = translate_boxes_to_tensors(all_boxes)
585
+ probs = make_prediction(discard_model, translated_tensor)
586
+
587
+ # 最も確率の高い牌を取得
588
+ sorted_indices = probs.argsort(descending=True).squeeze()
589
+
590
+ # デバッグ: モデルの上位推奨を表示
591
+ print(f"\nモデル推奨TOP5: ", end="")
592
+ for i, idx in enumerate(sorted_indices[:5]):
593
+ top_idx = int(idx.item())
594
+ tile = translate_to_vision(top_idx)
595
+ prob = probs[0][top_idx].item() * 100
596
+ in_hand = "✓" if tile in player_hand_labels else "✗"
597
+ print(f"{tile}({prob:.1f}%{in_hand}) ", end="")
598
+
599
+ # 手牌に存在する牌の中から最も確率の高いものを選択
600
+ found_recommendation = False
601
+ for idx in sorted_indices[:10]: # 上位10個をチェック
602
+ top_idx = int(idx.item())
603
+ recommended_tile = translate_to_vision(top_idx)
604
+
605
+ # 手牌に存在する牌のみを推奨
606
+ if recommended_tile in player_hand_labels:
607
+ current_recommendation = recommended_tile
608
+ overlay.update_recommendation(recommended_tile)
609
+ found_recommendation = True
610
+ print(f" → 推奨:{recommended_tile}", end="")
611
+
612
+ # Spaceキーで自動クリック
613
+ if keyboard.is_pressed('space'):
614
+ overlay.status_label.setText("クリック中...")
615
+ overlay.status_label.setStyleSheet("""
616
+ QLabel {
617
+ color: #FF0000;
618
+ background-color: rgba(0, 0, 0, 150);
619
+ padding: 5px;
620
+ border-radius: 3px;
621
+ }
622
+ """)
623
+ # 実際のクリック処理
624
+ for i, label in enumerate(player_hand_labels):
625
+ if label == recommended_tile:
626
+ box = player_hand_boxes[i]
627
+ x, y = box[0] + (box[2] - box[0]) // 2, box[1] + (box[3] - box[1]) // 2
628
+ abs_x = window.left + x
629
+ abs_y = window.top + y
630
+ pyautogui.click(abs_x, abs_y)
631
+ print(f" クリック!", end="")
632
+ break
633
+ time.sleep(0.5)
634
+ overlay.status_label.setText("✓ 起動完了 | Space: 自動クリック | 更新: 0.2秒毎")
635
+ overlay.status_label.setStyleSheet("""
636
+ QLabel {
637
+ color: #00FF00;
638
+ background-color: rgba(0, 0, 0, 150);
639
+ padding: 5px;
640
+ border-radius: 3px;
641
+ }
642
+ """)
643
+ break
644
+
645
+ if not found_recommendation:
646
+ overlay.update_recommendation(None)
647
+ print(f"\n→ 手牌に該当する推奨牌が見つかりません(TOP10に手牌の牌なし)", end="")
648
+ except Exception as e:
649
+ print(f"\n推奨計算エラー: {e}", end="")
650
+ import traceback
651
+ traceback.print_exc()
652
+ overlay.update_recommendation(None)
653
+ else:
654
+ overlay.update_recommendation(None)
655
 
656
+ # タイマーでフレーム処理を実行(200ms間隔 = より高頻度で更新)
657
+ timer = QTimer()
658
+ timer.timeout.connect(process_frame)
659
+ timer.start(200)
660
 
661
+ # アプリケーション実行
662
+ sys.exit(app.exec_())
 
 
 
 
 
 
 
 
 
 
663
 
664
  # %%
665
 
vision_transformer_local/config.json ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViTForImageClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "dtype": "float32",
7
+ "encoder_stride": 16,
8
+ "hidden_act": "gelu",
9
+ "hidden_dropout_prob": 0.0,
10
+ "hidden_size": 768,
11
+ "id2label": {
12
+ "0": "1b",
13
+ "1": "1n",
14
+ "2": "1p",
15
+ "3": "2b",
16
+ "4": "2n",
17
+ "5": "2p",
18
+ "6": "3b",
19
+ "7": "3n",
20
+ "8": "3p",
21
+ "9": "4b",
22
+ "10": "4n",
23
+ "11": "4p",
24
+ "12": "5b",
25
+ "13": "5n",
26
+ "14": "5p",
27
+ "15": "6b",
28
+ "16": "6n",
29
+ "17": "6p",
30
+ "18": "7b",
31
+ "19": "7n",
32
+ "20": "7p",
33
+ "21": "8b",
34
+ "22": "8n",
35
+ "23": "8p",
36
+ "24": "9b",
37
+ "25": "9n",
38
+ "26": "9p",
39
+ "27": "ew",
40
+ "28": "gd",
41
+ "29": "nw",
42
+ "30": "rd",
43
+ "31": "sw",
44
+ "32": "wd",
45
+ "33": "ww"
46
+ },
47
+ "image_size": 224,
48
+ "initializer_range": 0.02,
49
+ "intermediate_size": 3072,
50
+ "label2id": {
51
+ "1b": "0",
52
+ "1n": "1",
53
+ "1p": "2",
54
+ "2b": "3",
55
+ "2n": "4",
56
+ "2p": "5",
57
+ "3b": "6",
58
+ "3n": "7",
59
+ "3p": "8",
60
+ "4b": "9",
61
+ "4n": "10",
62
+ "4p": "11",
63
+ "5b": "12",
64
+ "5n": "13",
65
+ "5p": "14",
66
+ "6b": "15",
67
+ "6n": "16",
68
+ "6p": "17",
69
+ "7b": "18",
70
+ "7n": "19",
71
+ "7p": "20",
72
+ "8b": "21",
73
+ "8n": "22",
74
+ "8p": "23",
75
+ "9b": "24",
76
+ "9n": "25",
77
+ "9p": "26",
78
+ "ew": "27",
79
+ "gd": "28",
80
+ "nw": "29",
81
+ "rd": "30",
82
+ "sw": "31",
83
+ "wd": "32",
84
+ "ww": "33"
85
+ },
86
+ "layer_norm_eps": 1e-12,
87
+ "model_type": "vit",
88
+ "num_attention_heads": 12,
89
+ "num_channels": 3,
90
+ "num_hidden_layers": 12,
91
+ "patch_size": 16,
92
+ "pooler_act": "tanh",
93
+ "pooler_output_size": 768,
94
+ "problem_type": "single_label_classification",
95
+ "qkv_bias": true,
96
+ "transformers_version": "4.57.1"
97
+ }
vision_transformer_local/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d8e6c235fbcb30498788fac92f880c5c004b7861c3f90599dc724616ae09efd
3
+ size 343322416
vision_transformer_local/preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_convert_rgb": null,
3
+ "do_normalize": true,
4
+ "do_rescale": true,
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.5,
8
+ 0.5,
9
+ 0.5
10
+ ],
11
+ "image_processor_type": "ViTImageProcessor",
12
+ "image_std": [
13
+ 0.5,
14
+ 0.5,
15
+ 0.5
16
+ ],
17
+ "resample": 2,
18
+ "rescale_factor": 0.00392156862745098,
19
+ "size": {
20
+ "height": 224,
21
+ "width": 224
22
+ }
23
+ }