umair-imran commited on
Commit
c03d371
·
verified ·
1 Parent(s): 7225dae

Create chain.py

Browse files
Files changed (1) hide show
  1. chain.py +53 -0
chain.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Any
2
+ from types import SimpleNamespace
3
+
4
+ from langchain.prompts import PromptTemplate
5
+ from langchain.chains.llm import LLMChain
6
+ from langchain_core.outputs import ChatGeneration
7
+
8
+ from langchain_google_genai import ChatGoogleGenerativeAI
9
+ from langchain_openai import ChatOpenAI
10
+
11
+ class GeminiChain(LLMChain):
12
+ def __init__(self, **kwargs) -> None:
13
+ template = kwargs.pop('template')
14
+ model_name = kwargs.pop('model_name')
15
+ temperature = kwargs.pop('temperature', 0.3)
16
+ api_key = kwargs.pop('api_key')
17
+ verbose = kwargs.pop('verbose', False)
18
+ input_variables = kwargs.pop('input_variables', ['query',
19
+ "history"])
20
+
21
+ output_parser = kwargs.pop('output_parser', None)
22
+ # create Gemini LLM object
23
+ # llm = ChatGoogleGenerativeAI(model=model_name,
24
+ # temperature=temperature,
25
+ # google_api_key=api_key)
26
+ llm = ChatOpenAI(api_key=api_key, model=model_name, temperature=temperature)
27
+ llm._get_llm_string()
28
+ # Define a prompt template
29
+ prompt_template = PromptTemplate(input_variables=input_variables,
30
+ template=template,
31
+ output_parser=output_parser)
32
+
33
+ super().__init__(llm=llm,
34
+ prompt=prompt_template,
35
+ verbose=verbose, **kwargs)
36
+
37
+ @classmethod
38
+ def from_system_prompt(self, template: str,
39
+ input_variables: List[str],
40
+ output_parser=None, **kwargs) -> LLMChain:
41
+ # create Gemini LLM object
42
+ llm_chain = self(template=template,
43
+ input_variables=input_variables,
44
+ output_parser=output_parser, **kwargs)
45
+ return llm_chain
46
+
47
+ def _get_llm_string(self, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
48
+ return self._get_llm_string(stop=stop, **kwargs)
49
+
50
+ def _infer(self, inputs: dict) -> ChatGeneration:
51
+ # Run LLM chain
52
+ llm_output = self.generate([inputs]).generations[0][0]
53
+ return llm_output