Create chain.py
chain.py
ADDED
from typing import Any, List, Optional

from langchain.prompts import PromptTemplate
from langchain.chains.llm import LLMChain
from langchain_core.outputs import ChatGeneration

from langchain_google_genai import ChatGoogleGenerativeAI  # used by the commented-out Gemini path below
from langchain_openai import ChatOpenAI


class GeminiChain(LLMChain):
    def __init__(self, **kwargs) -> None:
        # Pull chain-specific settings out of kwargs before passing the
        # remainder through to LLMChain.
        template = kwargs.pop('template')
        model_name = kwargs.pop('model_name')
        temperature = kwargs.pop('temperature', 0.3)
        api_key = kwargs.pop('api_key')
        verbose = kwargs.pop('verbose', False)
        input_variables = kwargs.pop('input_variables', ['query', 'history'])

        output_parser = kwargs.pop('output_parser', None)

        # Create the LLM object. The native Gemini client is kept for
        # reference; the chain currently uses an OpenAI-compatible client.
        # llm = ChatGoogleGenerativeAI(model=model_name,
        #                              temperature=temperature,
        #                              google_api_key=api_key)
        llm = ChatOpenAI(api_key=api_key, model=model_name, temperature=temperature)
        llm._get_llm_string()  # build the LLM's cache-key string once (return value unused)

        # Define a prompt template.
        prompt_template = PromptTemplate(input_variables=input_variables,
                                         template=template,
                                         output_parser=output_parser)

        super().__init__(llm=llm,
                         prompt=prompt_template,
                         verbose=verbose, **kwargs)

    @classmethod
    def from_system_prompt(cls, template: str,
                           input_variables: List[str],
                           output_parser=None, **kwargs) -> "GeminiChain":
        # Build the chain; all remaining settings are routed through __init__.
        llm_chain = cls(template=template,
                        input_variables=input_variables,
                        output_parser=output_parser, **kwargs)
        return llm_chain

    def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Delegate to the wrapped LLM's cache-key string (delegating to
        # self here would recurse forever).
        return self.llm._get_llm_string(stop=stop, **kwargs)

    def _infer(self, inputs: dict) -> ChatGeneration:
        # Run the chain and return the first generation for the first input.
        llm_output = self.generate([inputs]).generations[0][0]
        return llm_output
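
A minimal usage sketch, not part of the commit: it assumes chain.py is importable from the working directory, and the template, model_name, and key placeholder below are illustrative values, not settings taken from this repo.

from chain import GeminiChain

# Hypothetical settings; substitute a real model name and API key.
chain = GeminiChain.from_system_prompt(
    template='History:\n{history}\n\nQuestion: {query}\nAnswer:',
    input_variables=['query', 'history'],
    model_name='gpt-4o-mini',
    api_key='YOUR_API_KEY',
)

# LLMChain places its completion under the default 'text' output key.
result = chain.invoke({'query': 'What does this chain do?', 'history': ''})
print(result['text'])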