Custom interface to the generativelanguage.googleapis.com API using
HTTPX and Pydantic.
The Google SDK for interacting with the generativelanguage.googleapis.com API,
`google-generativeai`, reads like it was written by a
Java developer who thought they knew everything about OOP, spent 30 minutes trying to learn Python,
gave up and decided to build the library to prove how horrible Python is. It also doesn't use httpx for HTTP requests,
and it implements tool calling itself without using Pydantic (or an equivalent) for validation.
We therefore implement support for the API directly.
Despite these shortcomings, the Gemini model is actually quite powerful and very fast.
Since Gemini supports a variety of date-stamped models, we explicitly list the latest models but
allow any name in the type hints.
See the Gemini API docs for a full list.
@dataclass(init=False)
class GeminiModel(Model):
    """A model that uses Gemini via `generativelanguage.googleapis.com` API.

    This is implemented from scratch rather than using a dedicated SDK, good API documentation is
    available [here](https://ai.google.dev/api).

    Apart from `__init__`, all methods are private or match those of the base class.
    """

    client: httpx.AsyncClient = field(repr=False)

    _model_name: GeminiModelName = field(repr=False)
    _provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] | None = field(repr=False)
    _auth: AuthProtocol | None = field(repr=False)
    _url: str | None = field(repr=False)
    _system: str = field(default='gemini', repr=False)

    def __init__(
        self,
        model_name: GeminiModelName,
        *,
        provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
    ):
        """Initialize a Gemini model.

        Args:
            model_name: The name of the model to use.
            provider: The provider to use for authentication and API access. Can be either the string
                'google-gla' or 'google-vertex' or an instance of `Provider[httpx.AsyncClient]`.
                A string is resolved to a provider instance with `infer_provider`.
                Defaults to 'google-gla'.
        """
        self._model_name = model_name
        # Keep the value exactly as the caller passed it (string or provider instance).
        self._provider = provider
        # `_auth` is a declared dataclass field; it must exist on the instance because the
        # dataclass-generated `__eq__` reads every field.
        self._auth = None
        if isinstance(provider, str):
            provider = infer_provider(provider)
        self._system = provider.name
        self.client = provider.client
        self._url = str(self.client.base_url)

    @property
    def base_url(self) -> str:
        """The base URL of the HTTP client, captured at construction time."""
        assert self._url is not None, 'URL not initialized'
        return self._url

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, usage.Usage]:
        """Make a non-streamed request and return the parsed response plus its usage."""
        check_allow_model_requests()
        async with self._make_request(
            messages, False, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
        ) as http_response:
            response = _gemini_response_ta.validate_json(await http_response.aread())
        return self._process_response(response), _metadata_as_usage(response)

    @asynccontextmanager
    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Make a streamed request and yield a `StreamedResponse` while the HTTP stream is open."""
        check_allow_model_requests()
        async with self._make_request(
            messages, True, cast(GeminiModelSettings, model_settings or {}), model_request_parameters
        ) as http_response:
            yield await self._process_streamed_response(http_response)

    @property
    def model_name(self) -> GeminiModelName:
        """The model name."""
        return self._model_name

    @property
    def system(self) -> str:
        """The system / model provider."""
        return self._system

    def _get_tools(self, model_request_parameters: ModelRequestParameters) -> _GeminiTools | None:
        """Convert function tools and result tools into Gemini function declarations (None if there are none)."""
        tools = [_function_from_abstract_tool(t) for t in model_request_parameters.function_tools]
        if model_request_parameters.result_tools:
            tools += [_function_from_abstract_tool(t) for t in model_request_parameters.result_tools]
        return _GeminiTools(function_declarations=tools) if tools else None

    def _get_tool_config(
        self, model_request_parameters: ModelRequestParameters, tools: _GeminiTools | None
    ) -> _GeminiToolConfig | None:
        """Build the tool config for the request.

        Returns None when plain-text results are allowed. Otherwise delegates to `_tool_config` with the
        declared function names (or an empty list when there are no tools). The exact semantics are
        defined by `_tool_config`, which is not shown here.
        """
        if model_request_parameters.allow_text_result:
            return None
        elif tools:
            return _tool_config([t['name'] for t in tools['function_declarations']])
        else:
            return _tool_config([])

    @asynccontextmanager
    async def _make_request(
        self,
        messages: list[ModelMessage],
        streamed: bool,
        model_settings: GeminiModelSettings,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[HTTPResponse]:
        """Build the request body, POST it, and yield the open HTTP response.

        Raises:
            ModelHTTPError: If the server returns a status code >= 400.
            UnexpectedModelBehavior: If the server returns any other non-200 status code.
        """
        tools = self._get_tools(model_request_parameters)
        tool_config = self._get_tool_config(model_request_parameters, tools)
        sys_prompt_parts, contents = await self._message_to_gemini_content(messages)

        request_data = _GeminiRequest(contents=contents)
        if sys_prompt_parts:
            request_data['system_instruction'] = _GeminiTextContent(role='user', parts=sys_prompt_parts)
        if tools is not None:
            request_data['tools'] = tools
        if tool_config is not None:
            request_data['tool_config'] = tool_config

        generation_config: _GeminiGenerationConfig = {}
        if model_settings:
            if (max_tokens := model_settings.get('max_tokens')) is not None:
                generation_config['max_output_tokens'] = max_tokens
            if (temperature := model_settings.get('temperature')) is not None:
                generation_config['temperature'] = temperature
            if (top_p := model_settings.get('top_p')) is not None:
                generation_config['top_p'] = top_p
            if (presence_penalty := model_settings.get('presence_penalty')) is not None:
                generation_config['presence_penalty'] = presence_penalty
            if (frequency_penalty := model_settings.get('frequency_penalty')) is not None:
                generation_config['frequency_penalty'] = frequency_penalty
            # Only send safety settings when a non-empty list was supplied. The previous `!= []` check
            # was also true for a missing key (None), which sent `"safety_settings": null`.
            if gemini_safety_settings := model_settings.get('gemini_safety_settings'):
                request_data['safety_settings'] = gemini_safety_settings
        if generation_config:
            request_data['generation_config'] = generation_config

        headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()}
        url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}'

        request_json = _gemini_request_ta.dump_json(request_data, by_alias=True)
        async with self.client.stream(
            'POST',
            url,
            content=request_json,
            headers=headers,
            timeout=model_settings.get('timeout', USE_CLIENT_DEFAULT),
        ) as r:
            if (status_code := r.status_code) != 200:
                # Read the body so `r.text` is available for the error message.
                await r.aread()
                if status_code >= 400:
                    raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=r.text)
                raise UnexpectedModelBehavior(f'Unexpected response from gemini {status_code}', r.text)
            yield r

    def _process_response(self, response: _GeminiResponse) -> ModelResponse:
        """Convert a single-candidate Gemini response into a `ModelResponse`.

        Raises:
            UnexpectedModelBehavior: If there is not exactly one candidate, or the candidate has no
                content (with a dedicated message when `finish_reason` is 'SAFETY').
        """
        if len(response['candidates']) != 1:
            raise UnexpectedModelBehavior('Expected exactly one candidate in Gemini response')
        if 'content' not in response['candidates'][0]:
            if response['candidates'][0].get('finish_reason') == 'SAFETY':
                raise UnexpectedModelBehavior('Safety settings triggered', str(response))
            else:
                raise UnexpectedModelBehavior('Content field missing from Gemini response', str(response))
        parts = response['candidates'][0]['content']['parts']
        return _process_response_from_parts(parts, model_name=response.get('model_version', self._model_name))

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
        aiter_bytes = http_response.aiter_bytes()
        start_response: _GeminiResponse | None = None
        content = bytearray()

        # Buffer chunks until the latest (possibly partial) response carries content parts.
        async for chunk in aiter_bytes:
            content.extend(chunk)
            responses = _gemini_streamed_response_ta.validate_json(
                _ensure_decodeable(content),
                experimental_allow_partial='trailing-strings',
            )
            if responses:
                last = responses[-1]
                if last['candidates'] and last['candidates'][0].get('content', {}).get('parts'):
                    start_response = last
                    break

        if start_response is None:
            raise UnexpectedModelBehavior('Streamed response ended without content or tool calls')

        # Hand over both the buffered bytes and the partially consumed iterator.
        return GeminiStreamedResponse(_model_name=self._model_name, _content=content, _stream=aiter_bytes)

    @classmethod
    async def _message_to_gemini_content(
        cls, messages: list[ModelMessage]
    ) -> tuple[list[_GeminiTextPart], list[_GeminiContent]]:
        """Split messages into system-prompt parts and Gemini `contents` entries."""
        sys_prompt_parts: list[_GeminiTextPart] = []
        contents: list[_GeminiContent] = []
        for m in messages:
            if isinstance(m, ModelRequest):
                message_parts: list[_GeminiPartUnion] = []

                for part in m.parts:
                    if isinstance(part, SystemPromptPart):
                        sys_prompt_parts.append(_GeminiTextPart(text=part.content))
                    elif isinstance(part, UserPromptPart):
                        message_parts.extend(await cls._map_user_prompt(part))
                    elif isinstance(part, ToolReturnPart):
                        message_parts.append(_response_part_from_response(part.tool_name, part.model_response_object()))
                    elif isinstance(part, RetryPromptPart):
                        if part.tool_name is None:
                            message_parts.append(_GeminiTextPart(text=part.model_response()))
                        else:
                            response = {'call_error': part.model_response()}
                            message_parts.append(_response_part_from_response(part.tool_name, response))
                    else:
                        assert_never(part)

                # A request made only of system prompts produces no 'user' content entry.
                if message_parts:
                    contents.append(_GeminiContent(role='user', parts=message_parts))
            elif isinstance(m, ModelResponse):
                contents.append(_content_model_response(m))
            else:
                assert_never(m)

        return sys_prompt_parts, contents

    @staticmethod
    async def _map_user_prompt(part: UserPromptPart) -> list[_GeminiPartUnion]:
        """Convert a user prompt into text parts and base64 inline-data parts.

        URL items are downloaded with the cached async HTTP client, and their MIME type is taken from
        the response's Content-Type header.
        """
        if isinstance(part.content, str):
            return [{'text': part.content}]
        else:
            content: list[_GeminiPartUnion] = []

            for item in part.content:
                if isinstance(item, str):
                    content.append({'text': item})
                elif isinstance(item, BinaryContent):
                    base64_encoded = base64.b64encode(item.data).decode('utf-8')
                    content.append(
                        _GeminiInlineDataPart(inline_data={'data': base64_encoded, 'mime_type': item.media_type})
                    )
                elif isinstance(item, (AudioUrl, ImageUrl, DocumentUrl)):
                    client = cached_async_http_client()
                    response = await client.get(item.url, follow_redirects=True)
                    response.raise_for_status()
                    # Drop parameters such as '; charset=...' from the header value.
                    mime_type = response.headers['Content-Type'].split(';')[0]
                    inline_data = _GeminiInlineDataPart(
                        inline_data={'data': base64.b64encode(response.content).decode('utf-8'), 'mime_type': mime_type}
                    )
                    content.append(inline_data)
                else:
                    assert_never(item)
            return content
The provider to use for authentication and API access. Can be either the string
`'google-gla'` or `'google-vertex'`, or an instance of `Provider[httpx.AsyncClient]`.
A string is resolved to a provider instance via `infer_provider`.
Default: `'google-gla'`
Source code in pydantic_ai_slim/pydantic_ai/models/gemini.py
def __init__(
    self,
    model_name: GeminiModelName,
    *,
    provider: Literal['google-gla', 'google-vertex'] | Provider[httpx.AsyncClient] = 'google-gla',
):
    """Create a Gemini model bound to a provider.

    Args:
        model_name: Name of the Gemini model to call.
        provider: Either the string 'google-gla' / 'google-vertex' (resolved through
            `infer_provider`) or a ready-made `Provider[httpx.AsyncClient]`. Defaults to 'google-gla'.
    """
    self._model_name = model_name
    # Remember exactly what the caller supplied, before any resolution.
    self._provider = provider

    resolved = infer_provider(provider) if isinstance(provider, str) else provider
    self._system = resolved.name
    self.client = resolved.client
    self._url = str(self.client.base_url)
Source code in pydantic_ai_slim/pydantic_ai/models/gemini.py
Lines 328–331
class AuthProtocol(Protocol):
    """Structural interface for anything that can supply Gemini authentication headers."""

    async def headers(self) -> dict[str, str]:
        """Return the HTTP headers that authenticate a request."""
`ApiKeyAuth` (dataclass)
Authentication using an API key for the X-Goog-Api-Key header.
Source code in pydantic_ai_slim/pydantic_ai/models/gemini.py
Lines 334–342
@dataclass
class ApiKeyAuth:
    """Authenticate requests by sending an API key in the `X-Goog-Api-Key` header."""

    api_key: str

    async def headers(self) -> dict[str, str]:
        """Build the single authentication header carrying the configured API key."""
        # https://cloud.google.com/docs/authentication/api-keys-use#using-with-rest
        header_name = 'X-Goog-Api-Key'
        return {header_name: self.api_key}
@dataclass
class GeminiStreamedResponse(StreamedResponse):
    """Implementation of `StreamedResponse` for the Gemini model."""

    _model_name: GeminiModelName
    # Bytes already read from the HTTP stream before this object was created.
    _content: bytearray
    # The remainder of the (possibly partially consumed) HTTP byte stream.
    _stream: AsyncIterator[bytes]
    _timestamp: datetime = field(default_factory=_utils.now_utc, init=False)

    async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
        """Translate Gemini responses into text-delta and tool-call events.

        Raises:
            UnexpectedModelBehavior: If a response's first candidate has no `content` field.
        """
        async for gemini_response in self._get_gemini_responses():
            # Some responses can carry no candidates; they contribute usage only, so skip them.
            if not gemini_response['candidates']:
                continue
            candidate = gemini_response['candidates'][0]
            if 'content' not in candidate:
                raise UnexpectedModelBehavior('Streamed response has no content field')
            gemini_part: _GeminiPartUnion
            for gemini_part in candidate['content']['parts']:
                if 'text' in gemini_part:
                    # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled
                    # amongst the tool call deltas
                    yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text'])

                elif 'function_call' in gemini_part:
                    # Here, we assume all function_call parts are complete and don't have deltas.
                    # We do this by assigning a unique randomly generated "vendor_part_id".
                    # We need to confirm whether this is actually true, but if it isn't, we can still handle it properly
                    # it would just be a bit more complicated. And we'd need to confirm the intended semantics.
                    maybe_event = self._parts_manager.handle_tool_call_delta(
                        vendor_part_id=uuid4(),
                        tool_name=gemini_part['function_call']['name'],
                        args=gemini_part['function_call']['args'],
                        tool_call_id=None,
                    )
                    if maybe_event is not None:
                        yield maybe_event
                else:
                    assert 'function_response' in gemini_part, f'Unexpected part: {gemini_part}'

    def _parse_buffered_responses(self) -> list[_GeminiResponse]:
        """Parse everything buffered in `_content`, tolerating a trailing partial response."""
        return _gemini_streamed_response_ta.validate_json(
            _ensure_decodeable(self._content),
            experimental_allow_partial='trailing-strings',
        )

    async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
        """Yield each Gemini response once it is complete, accumulating usage as it goes."""
        # This method exists to ensure we only yield completed items, so we don't need to worry about
        # partial gemini responses, which would make everything more complicated.

        # `_content` can already contain every remaining response: the caller stops reading the stream as
        # soon as content appears, which may be on the final chunk. Parse it up front so those responses
        # are not lost when `_stream` yields nothing more.
        gemini_responses: list[_GeminiResponse] = self._parse_buffered_responses() if self._content else []
        current_gemini_response_index = 0
        # Right now, there are some circumstances where we will have information that could be yielded sooner than it is
        # But changing that would make things a lot more complicated.
        async for chunk in self._stream:
            self._content.extend(chunk)
            gemini_responses = self._parse_buffered_responses()

            # The idea: yield only up to the latest response, which might still be partial.
            # Note that if the latest response is complete, we could yield it immediately, but there's not a good
            # allow_partial API to determine if the last item in the list is complete.
            responses_to_yield = gemini_responses[:-1]
            for r in responses_to_yield[current_gemini_response_index:]:
                current_gemini_response_index += 1
                self._usage += _metadata_as_usage(r)
                yield r

        # The stream is exhausted, so every remaining response is complete. Normally this is only the
        # last one; if no further chunks arrived, it is everything that was buffered.
        for r in gemini_responses[current_gemini_response_index:]:
            self._usage += _metadata_as_usage(r)
            yield r

    @property
    def model_name(self) -> GeminiModelName:
        """Get the model name of the response."""
        return self._model_name

    @property
    def timestamp(self) -> datetime:
        """Get the timestamp of the response."""
        return self._timestamp
class GeminiSafetySettings(TypedDict):
    """Safety settings options for a Gemini model request.

    Each entry pairs a harm category with a blocking threshold. A list of these can be supplied via
    the `gemini_safety_settings` model setting.

    See [Gemini API docs](https://ai.google.dev/gemini-api/docs/safety-settings) for safety category and threshold descriptions.
    For an example on how to use `GeminiSafetySettings`, see [here](../../agents.md#model-specific-settings).
    """

    category: Literal[
        'HARM_CATEGORY_UNSPECIFIED',
        'HARM_CATEGORY_HARASSMENT',
        'HARM_CATEGORY_HATE_SPEECH',
        'HARM_CATEGORY_SEXUALLY_EXPLICIT',
        'HARM_CATEGORY_DANGEROUS_CONTENT',
        'HARM_CATEGORY_CIVIC_INTEGRITY',
    ]
    """
    Safety settings category.
    """

    threshold: Literal[
        'HARM_BLOCK_THRESHOLD_UNSPECIFIED',
        'BLOCK_LOW_AND_ABOVE',
        'BLOCK_MEDIUM_AND_ABOVE',
        'BLOCK_ONLY_HIGH',
        'BLOCK_NONE',
        'OFF',
    ]
    """
    Safety settings threshold.
    """