dots-ocr-space / app.py
wanifuck's picture
fix: image processing logic for PIL Image input
71e3f3f
"""
HuggingFace Space for dots.ocr (GOT-OCR2_0)
高精度OCRモデルをAPIとして提供
"""
import gradio as gr
import torch
import os
import io
import base64
import json
import time
from PIL import Image
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
import logging
# ロギング設定
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# GPU使用可能性チェック
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"使用デバイス: {device}")
# グローバル変数
model = None
tokenizer = None
def load_model():
"""dots.ocrモデルを読み込み"""
global model, tokenizer
try:
logger.info("dots.ocr (GOT-OCR2_0) モデルを読み込み中...")
# 8bit量子化設定
quantization_config = BitsAndBytesConfig(
load_in_8bit=True,
bnb_8bit_compute_dtype=torch.float16
)
# モデルとトークナイザーを読み込み(最大メモリ効率化)
model = AutoModel.from_pretrained(
'ucaslcl/GOT-OCR2_0',
trust_remote_code=True,
low_cpu_mem_usage=True,
device_map='auto',
use_safetensors=True,
torch_dtype=torch.float16, # メモリ使用量を半減
quantization_config=quantization_config, # 現代的な量子化設定
pad_token_id=151643
).eval()
tokenizer = AutoTokenizer.from_pretrained(
'ucaslcl/GOT-OCR2_0',
trust_remote_code=True
)
logger.info("モデル読み込み完了")
return True
except Exception as e:
logger.error(f"モデル読み込みエラー: {e}")
return False
def process_image(image, ocr_type="ocr", ocr_box="", ocr_color=""):
"""
画像をOCR処理
Args:
image: PIL Image または画像パス
ocr_type: OCRタイプ("ocr", "format", "fine-grained")
ocr_box: OCRボックス座標(オプション)
ocr_color: OCR色指定(オプション)
Returns:
dict: OCR結果
"""
global model, tokenizer
start_time = time.time()
try:
# モデル未読み込みの場合は読み込み
if model is None or tokenizer is None:
if not load_model():
raise Exception("モデルの読み込みに失敗しました")
# 画像処理
if isinstance(image, str):
# Base64文字列の場合
if image.startswith('data:image'):
image = image.split(',')[1]
image_data = base64.b64decode(image)
image = Image.open(io.BytesIO(image_data))
elif not isinstance(image, Image.Image):
# その他の形式の場合はPIL Imageに変換
image = Image.open(image)
# PIL ImageをRGB形式に変換
if image.mode != 'RGB':
image = image.convert('RGB')
logger.info(f"画像サイズ: {image.size}")
# OCR処理実行
with torch.no_grad():
result = model.chat(
tokenizer,
image,
ocr_type=ocr_type,
ocr_box=ocr_box,
ocr_color=ocr_color
)
processing_time = time.time() - start_time
logger.info(f"OCR処理完了: {processing_time:.2f}秒, 結果長: {len(result)}文字")
return {
"text": result,
"confidence": 0.95, # dots.ocrは高精度なので固定値
"processing_time": processing_time,
"model_used": "ucaslcl/GOT-OCR2_0",
"device": str(device),
"image_size": list(image.size)
}
except Exception as e:
logger.error(f"OCR処理エラー: {e}")
processing_time = time.time() - start_time
return {
"text": f"[エラー] OCR処理でエラーが発生しました: {str(e)}",
"confidence": 0.0,
"processing_time": processing_time,
"model_used": "error",
"device": str(device),
"error": str(e)
}
def gradio_interface(image, ocr_type="ocr"):
"""Gradio用のインターフェース関数"""
result = process_image(image, ocr_type=ocr_type)
# 結果を整形して返す
output_text = result["text"]
# メタデータ情報を追加
metadata = f"""
処理時間: {result['processing_time']:.2f}
信頼度: {result['confidence']:.1%}
使用モデル: {result['model_used']}
デバイス: {result['device']}
"""
if 'image_size' in result:
metadata += f"画像サイズ: {result['image_size'][0]}x{result['image_size'][1]}"
return output_text, metadata, json.dumps(result, ensure_ascii=False, indent=2)
def api_interface(image):
"""API用のインターフェース関数(JSON返却)"""
result = process_image(image)
return result
# Gradio インターフェース設定
with gr.Blocks(
title="dots.ocr (GOT-OCR2_0) - 高精度OCR API",
theme=gr.themes.Soft(),
css="""
.gradio-container {
max-width: 1200px !important;
}
"""
) as demo:
gr.Markdown("""
# 🔍 dots.ocr (GOT-OCR2_0) - 高精度OCR API
最先端の視覚言語モデルによる高精度OCR処理
- **多言語対応**: 日本語、英語、中国語など80以上の言語
- **レイアウト検出**: テキスト、テーブル、図表の構造認識
- **高精度**: 95%以上の認識精度
## 使用方法
1. 画像をアップロード
2. OCRタイプを選択
3. 「処理開始」ボタンをクリック
""")
with gr.Row():
with gr.Column(scale=1):
# 入力部分
image_input = gr.Image(
type="pil",
label="📷 画像をアップロード",
height=400
)
ocr_type = gr.Dropdown(
choices=["ocr", "format", "fine-grained"],
value="ocr",
label="🔧 OCRタイプ",
info="ocr: 基本OCR, format: フォーマット保持, fine-grained: 詳細解析"
)
process_btn = gr.Button("🚀 処理開始", variant="primary")
with gr.Column(scale=2):
# 出力部分
with gr.Tab("📄 テキスト結果"):
text_output = gr.Textbox(
label="抽出されたテキスト",
lines=15,
placeholder="ここに抽出されたテキストが表示されます..."
)
with gr.Tab("📊 処理情報"):
metadata_output = gr.Textbox(
label="処理メタデータ",
lines=8,
placeholder="処理時間、信頼度などの情報が表示されます..."
)
with gr.Tab("🔧 JSON結果"):
json_output = gr.Code(
label="完全なJSON結果",
language="json"
)
# 処理ボタンのイベント設定
process_btn.click(
fn=gradio_interface,
inputs=[image_input, ocr_type],
outputs=[text_output, metadata_output, json_output]
)
# API用のシンプルなエンドポイント(独立したInterface)
with gr.Row():
gr.Markdown("# API Endpoint")
with gr.Row():
gr.Markdown("このエンドポイントはプログラムからの呼び出し用です")
# API専用のInterface
api_image = gr.Image(type="pil", label="image")
api_submit = gr.Button("Submit")
api_output = gr.JSON(label="output")
# API用の関数
api_submit.click(
fn=api_interface,
inputs=[api_image],
outputs=[api_output],
api_name="predict"
)
# アプリケーション起動時にモデルを読み込み
if __name__ == "__main__":
logger.info("アプリケーション起動中...")
# 環境情報表示
logger.info(f"PyTorch version: {torch.__version__}")
logger.info(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
logger.info(f"CUDA version: {torch.version.cuda}")
logger.info(f"GPU count: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()):
logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
# モデル事前読み込み
load_model()
# Gradioアプリ起動
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=True,
show_api=True,
show_error=True # エラー詳細表示を有効化
)