doinglean commited on
Commit
29e9e8c
·
verified ·
1 Parent(s): cb9b58d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -8
app.py CHANGED
@@ -4,21 +4,36 @@ from ultralytics import YOLO
4
  import cv2
5
  import numpy as np
6
  import easyocr
 
 
 
 
 
7
 
8
  # Lade das Modell
9
  model_path = hf_hub_download(repo_id="foduucom/stockmarket-pattern-detection-yolov8", filename="model.pt")
10
  model = YOLO(model_path)
11
 
12
  # OCR für Preise
13
- reader = easyocr.Reader(['en'])
14
 
15
  def analyze_image(image, prompt):
 
 
16
  # Konvertiere PIL-Bild zu OpenCV-Format
17
  image_np = np.array(image)
18
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
19
 
20
  # Führe Objekterkennung durch
21
- results = model.predict(source=image_np, save=False)
 
22
 
23
  # Extrahiere Kerzen
24
  detections = []
@@ -27,15 +42,23 @@ def analyze_image(image, prompt):
27
  label = result.names[int(box.cls)]
28
  confidence = float(box.conf)
29
  xmin, ymin, xmax, ymax = box.xyxy[0].tolist()
 
 
30
 
31
- # Extrahiere Farbe
32
  candle_roi = image_cv[int(ymin):int(ymax), int(xmin):int(xmax)]
 
 
 
33
  mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
34
  color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
35
 
36
- # OCR für Opening/Close-Preise (aus Achsen, anpassen an Chart)
37
- price_text = reader.readtext(image_cv[int(ymin):int(ymax), int(xmin):int(xmax)], detail=0)
 
 
38
  prices = ' '.join(price_text) if price_text else "No price detected"
 
39
 
40
  detections.append({
41
  "pattern": label,
@@ -47,13 +70,23 @@ def analyze_image(image, prompt):
47
 
48
  # Sortiere nach x-Position (rechts nach links = neueste Kerzen)
49
  detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)
 
50
 
51
  # Begrenze auf die letzten 8 Kerzen
52
  if "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
53
  detections = detections[:8]
54
 
55
- return detections
 
 
 
 
 
 
 
 
56
 
 
57
  iface = gr.Interface(
58
  fn=analyze_image,
59
  inputs=[
@@ -61,8 +94,8 @@ iface = gr.Interface(
61
  gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'List last 8 candles with their colors'")
62
  ],
63
  outputs="json",
64
- title="Stock Chart Analysis",
65
- description="Upload a screenshot and provide a prompt to analyze candlesticks."
66
  )
67
 
68
  iface.launch()
 
4
  import cv2
5
  import numpy as np
6
  import easyocr
7
+ import logging
8
+
9
+ # Logging einrichten
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
 
13
  # Lade das Modell
14
  model_path = hf_hub_download(repo_id="foduucom/stockmarket-pattern-detection-yolov8", filename="model.pt")
15
  model = YOLO(model_path)
16
 
17
  # OCR für Preise
18
+ reader = easyocr.Reader(['en'], gpu=False)
19
 
20
  def analyze_image(image, prompt):
21
+ logger.info("Starting image analysis with prompt: %s", prompt)
22
+
23
  # Konvertiere PIL-Bild zu OpenCV-Format
24
  image_np = np.array(image)
25
  image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
26
+
27
+ # Bildvorverarbeitung: Kontrast erhöhen
28
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
29
+ gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
30
+ enhanced = clahe.apply(gray)
31
+ image_cv = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
32
+ logger.info("Image preprocessed: shape=%s", image_np.shape)
33
 
34
  # Führe Objekterkennung durch
35
+ results = model.predict(source=image_np, conf=0.3, iou=0.5, save=False)
36
+ logger.info("YOLO predictions: %d boxes detected", len(results[0].boxes))
37
 
38
  # Extrahiere Kerzen
39
  detections = []
 
42
  label = result.names[int(box.cls)]
43
  confidence = float(box.conf)
44
  xmin, ymin, xmax, ymax = box.xyxy[0].tolist()
45
+ logger.info("Detected: %s, confidence=%.2f, box=(%.0f, %.0f, %.0f, %.0f)",
46
+ label, confidence, xmin, ymin, xmax, ymax)
47
 
48
+ # Extrahiere Farbe (Fokus auf Kerzenkörper)
49
  candle_roi = image_cv[int(ymin):int(ymax), int(xmin):int(xmax)]
50
+ if candle_roi.size == 0:
51
+ logger.warning("Empty ROI for box: (%.0f, %.0f, %.0f, %.0f)", xmin, ymin, xmax, ymax)
52
+ continue
53
  mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
54
  color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})"
55
 
56
+ # OCR für Preise (erweitere ROI für Achsen)
57
+ price_roi = image_cv[max(0, int(ymin)-50):min(image_np.shape[0], int(ymax)+50),
58
+ max(0, int(xmin)-50):min(image_np.shape[1], int(xmax)+50)]
59
+ price_text = reader.readtext(price_roi, detail=0, allowlist='0123456789.')
60
  prices = ' '.join(price_text) if price_text else "No price detected"
61
+ logger.info("OCR prices: %s", prices)
62
 
63
  detections.append({
64
  "pattern": label,
 
70
 
71
  # Sortiere nach x-Position (rechts nach links = neueste Kerzen)
72
  detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)
73
+ logger.info("Sorted detections: %d", len(detections))
74
 
75
  # Begrenze auf die letzten 8 Kerzen
76
  if "last 8 candles" in prompt.lower() or "letzte 8 kerzen" in prompt.lower():
77
  detections = detections[:8]
78
 
79
+ # Debugging: Wenn leer, gib Hinweis
80
+ if not detections:
81
+ logger.warning("No detections found. Check image quality or model configuration.")
82
+ return {"prompt": prompt, "description": "No candlesticks detected. Ensure clear image and visible candles."}
83
+
84
+ return {
85
+ "prompt": prompt,
86
+ "detections": detections
87
+ }
88
 
89
+ # Erstelle Gradio-Schnittstelle
90
  iface = gr.Interface(
91
  fn=analyze_image,
92
  inputs=[
 
94
  gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'List last 8 candles with their colors'")
95
  ],
96
  outputs="json",
97
+ title="Stock Chart Analysis with YOLOv8",
98
+ description="Upload a TradingView screenshot to detect the last 8 candlesticks, their colors, and prices."
99
  )
100
 
101
  iface.launch()