Skip to content

Commit

Permalink
remove duplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
raspawar committed Jun 7, 2024
1 parent c97a31e commit 5eba3a6
Showing 1 changed file with 18 additions and 28 deletions.
46 changes: 18 additions & 28 deletions libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,13 @@ def _validate_model(cls, values: Dict[str, Any]) -> Dict[str, Any]:
)
return values

def _add_authorization(self) -> Dict:
return {
"Authorization": f"Bearer {self.api_key.get_secret_value()}"
if self.api_key
else None,
}

@property
def available_models(self) -> list[Model]:
"""List the available models that can be invoked."""
Expand Down Expand Up @@ -187,16 +194,11 @@ def _post(
"""Method for posting to the AI Foundation Model Function API."""
self.last_inputs = {
"url": invoke_url,
# "headers": self.headers["call"],
"json": self.payload_fn(payload),
"stream": False,
}
headers = {
"Authorization": f"Bearer {self.api_key.get_secret_value()}"
if self.api_key
else None,
**self.headers["call"],
}
headers = self._add_authorization()
headers.update(**self.headers["call"])
session = self.get_session_fn()
self.last_response = response = session.post(
headers=headers, **self.last_inputs
Expand All @@ -212,17 +214,13 @@ def _get(
"""Method for getting from the AI Foundation Model Function API."""
self.last_inputs = {
"url": invoke_url,
# "headers": self.headers["call"],
"stream": False,
}
if payload:
self.last_inputs["json"] = self.payload_fn(payload)
headers = {
"Authorization": f"Bearer {self.api_key.get_secret_value()}"
if self.api_key
else None,
**self.headers["call"],
}

headers = self._add_authorization()
headers.update(**self.headers["call"])
session = self.get_session_fn()
self.last_response = response = session.get(headers=headers, **self.last_inputs)
self._try_raise(response)
Expand Down Expand Up @@ -442,16 +440,12 @@ def get_req_stream(
payload = {**payload, "stream": True}
self.last_inputs = {
"url": invoke_url,
# "headers": self.headers["stream"],
"json": self.payload_fn(payload),
"stream": True,
}
headers = {
"Authorization": f"Bearer {self.api_key.get_secret_value()}"
if self.api_key
else None,
**self.headers["stream"],
}

headers = self._add_authorization()
headers.update(**self.headers["stream"])
response = self.get_session_fn().post(headers=headers, **self.last_inputs)
self._try_raise(response)
call = self.copy()
Expand Down Expand Up @@ -483,15 +477,11 @@ async def get_req_astream(
payload = {**payload, "stream": True}
self.last_inputs = {
"url": invoke_url,
# "headers": self.headers["stream"],
"json": self.payload_fn(payload),
}
headers = {
"Authorization": f"Bearer {self.api_key.get_secret_value()}"
if self.api_key
else None,
**self.headers["stream"],
}

headers = self._add_authorization()
headers.update(**self.headers["stream"])
async with self.get_asession_fn() as session:
async with session.post(headers=headers, **self.last_inputs) as response:
self._try_raise(response)
Expand Down

0 comments on commit 5eba3a6

Please sign in to comment.