doinglean commited on
Commit
d5e4a59
·
verified ·
1 Parent(s): d2e2940

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -11
app.py CHANGED
@@ -1,41 +1,64 @@
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
  from ultralytics import YOLO
4
- import torch
 
 
5
 
6
  # Lade das Modell
7
  model_path = hf_hub_download(repo_id="foduucom/stockmarket-pattern-detection-yolov8", filename="model.pt")
8
  model = YOLO(model_path)
9
 
10
  def analyze_image(image, prompt):
11
- # Verwende den Prompt (falls nötig, hier als Kontext für die Verarbeitung)
12
- # YOLOv8 ignoriert den Prompt direkt, daher speichern wir ihn für die Logik
13
- results = model.predict(source=image, save=False)
 
 
 
 
 
14
  detections = []
15
  for result in results:
16
  for box in result.boxes:
17
  label = result.names[int(box.cls)]
18
  confidence = float(box.conf)
19
- # Farben basierend auf Label oder Prompt (z. B. "bullish" für grün)
20
- color = "green" if "bullish" in prompt.lower() or "Bullish" in label else "red"
 
 
 
 
 
 
 
21
  detections.append({
22
  "pattern": label,
23
  "confidence": confidence,
24
- "color": color,
25
- "prompt_used": prompt # Rückgabe des Prompts zur Überprüfung
 
26
  })
 
 
 
 
 
 
 
 
27
  return detections
28
 
29
- # Erstelle Gradio-Schnittstelle mit Bild- und Text-Eingabe
30
  iface = gr.Interface(
31
  fn=analyze_image,
32
  inputs=[
33
  gr.Image(type="pil", label="Upload TradingView Screenshot"),
34
- gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'Detect candlestick patterns and colors'")
35
  ],
36
  outputs="json",
37
  title="Candlestick Pattern Detection",
38
- description="Upload a TradingView screenshot and provide a prompt to detect candlestick patterns and colors."
39
  )
40
 
41
  # Starte die App
 
1
  import gradio as gr
2
  from huggingface_hub import hf_hub_download
3
  from ultralytics import YOLO
4
+ import cv2
5
+ import numpy as np
6
+ import re
7
 
8
  # Lade das Modell
9
  model_path = hf_hub_download(repo_id="foduucom/stockmarket-pattern-detection-yolov8", filename="model.pt")
10
  model = YOLO(model_path)
11
 
12
  def analyze_image(image, prompt):
13
+ # Konvertiere PIL-Bild zu OpenCV-Format
14
+ image_np = np.array(image)
15
+ image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
16
+
17
+ # Führe Objekterkennung durch
18
+ results = model.predict(source=image_np, save=False)
19
+
20
+ # Extrahiere Kerzen
21
  detections = []
22
  for result in results:
23
  for box in result.boxes:
24
  label = result.names[int(box.cls)]
25
  confidence = float(box.conf)
26
+ # Extrahiere Bounding-Box-Koordinaten
27
+ xmin, ymin, xmax, ymax = box.xyxy[0].tolist()
28
+
29
+ # Schneide die Kerze aus dem Bild für Farbanalyse
30
+ candle_roi = image_cv[int(ymin):int(ymax), int(xmin):int(xmax)]
31
+ # Berechne dominante Farbe (RGB)
32
+ mean_color = np.mean(candle_roi, axis=(0, 1)).astype(int)
33
+ color_rgb = f"RGB({mean_color[2]},{mean_color[1]},{mean_color[0]})" # RGB-Format
34
+
35
  detections.append({
36
  "pattern": label,
37
  "confidence": confidence,
38
+ "color": color_rgb,
39
+ "x_center": (xmin + xmax) / 2, # Für Sortierung nach Position
40
+ "prompt_used": prompt
41
  })
42
+
43
+ # Sortiere Kerzen nach x-Position (von rechts nach links = neueste zuerst)
44
+ detections = sorted(detections, key=lambda x: x["x_center"], reverse=True)
45
+
46
+ # Begrenze auf die letzten 10 Kerzen, wenn im Prompt gefordert
47
+ if "last 10 candles" in prompt.lower():
48
+ detections = detections[:10]
49
+
50
  return detections
51
 
52
+ # Erstelle Gradio-Schnittstelle
53
  iface = gr.Interface(
54
  fn=analyze_image,
55
  inputs=[
56
  gr.Image(type="pil", label="Upload TradingView Screenshot"),
57
+ gr.Textbox(label="Prompt", placeholder="Enter your prompt, e.g., 'List last 10 candles with their colors'")
58
  ],
59
  outputs="json",
60
  title="Candlestick Pattern Detection",
61
+ description="Upload a TradingView screenshot and provide a prompt to detect candlestick patterns and their colors."
62
  )
63
 
64
  # Starte die App