Larrytech committed on
Commit
2d9d105
·
1 Parent(s): 4f6b66b

Build Update

Browse files
Files changed (1) hide show
  1. main.py +3 -3
main.py CHANGED
@@ -7,7 +7,7 @@ app = FastAPI()
7
 
8
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
9
 
10
- # Load tokenzier and model
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
@@ -15,8 +15,8 @@ model = AutoModelForCausalLM.from_pretrained(
15
  device_map="auto"
16
  )
17
 
18
- # Use pipeline for easier text generation
19
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
20
 
21
  @app.get("/", response_class=HTMLResponse)
22
  def index():
 
7
 
8
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
9
 
10
+ # Load tokenizer and model
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
 
15
  device_map="auto"
16
  )
17
 
18
+ # Use pipeline for easier text generation (device argument removed!)
19
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
20
 
21
  @app.get("/", response_class=HTMLResponse)
22
  def index():