Mohansai2004 committed · Commit ebffcc9 · verified · 1 Parent(s): 0290b84

Update app/caption_model.py

Files changed (1):
  1. app/caption_model.py +3 -44
app/caption_model.py CHANGED
@@ -1,8 +1,9 @@
+ # Track highest score per object
  from transformers import pipeline
  from PIL import Image

  # Load object detection model
- MODEL_NAME = "hustvl/yolos-small"
+ MODEL_NAME = "facebook/detr-resnet-50"
  detector = pipeline("object-detection", model=MODEL_NAME)

  def caption_image(image: Image.Image):
@@ -11,49 +12,7 @@ def caption_image(image: Image.Image):
          raise ValueError("Input must be a valid PIL Image in RGB or grayscale format")

      # Run object detection
- from transformers import pipeline
- from PIL import Image
-
- # Load object detection model
- MODEL_NAME = "hustvl/yolos-small"
- detector = pipeline("object-detection", model=MODEL_NAME)
-
- def caption_image(image: Image.Image):
-     # Validate input
-     if not isinstance(image, Image.Image) or image.mode not in ('RGB', 'L'):
-         raise ValueError("Input must be a valid PIL Image in RGB or grayscale format")
-
-     # Run object detection with custom parameters
-     results = detector(image, top_k=20, threshold=0.2)
-
-     # Track highest score per object
-     objects_dict = {}
-     for result in results:
-         label = result['label']
-         score = result['score']
-         if label in objects_dict:
-             objects_dict[label] = max(objects_dict[label], score)
-         else:
-             objects_dict[label] = score
-
-     # Build structured list of objects
-     objects_list = [
-         {"label": label, "score": round(score, 2)}
-         for label, score in sorted(objects_dict.items(), key=lambda x: x[1], reverse=True)
-     ]
-
-     # Create readable caption
-     detected_objects = [f"{obj['label']} ({obj['score']:.2f})" for obj in objects_list]
-     caption = "Detected objects: " + ", ".join(detected_objects) if detected_objects else "No objects detected."
-
-     # Highest confidence score
-     max_confidence = max(objects_dict.values()) if objects_dict else 0.0
-
-     return {
-         "caption": caption,
-         "objects": objects_list,
-         "confidence": round(max_confidence, 2)
-     }
+     results = detector(image)

      # Track highest score per object
      objects_dict = {}
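
For reference, a minimal usage sketch of the module after this commit. It is not part of the diff; it assumes the remainder of caption_image (the aggregation loop and the returned dict, visible in the removed duplicate block above) is unchanged outside these hunks, that app is importable as a package, and that "example.jpg" is a placeholder path.

# Usage sketch (assumptions: the rest of caption_image is unchanged outside the
# hunks above; "example.jpg" is a placeholder image path).
from PIL import Image

from app.caption_model import caption_image

# Load a local image and normalize it to RGB so it passes the input validation
# in caption_image (only RGB or grayscale PIL images are accepted).
image = Image.open("example.jpg").convert("RGB")

result = caption_image(image)
print(result["caption"])

# Based on the function body shown in the diff, the result is expected to look like:
# {
#     "caption": "Detected objects: person (0.99), dog (0.87)",
#     "objects": [{"label": "person", "score": 0.99}, {"label": "dog", "score": 0.87}],
#     "confidence": 0.99
# }

Note that the new call results = detector(image) drops the explicit top_k and threshold arguments, so the facebook/detr-resnet-50 pipeline now runs with its default detection parameters instead of the previous threshold=0.2.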