Skip to content

Commit

Permalink
Wfh/share dataset (#261)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Oct 19, 2023
1 parent 81100a2 commit 2ac0e62
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 1 deletion.
86 changes: 86 additions & 0 deletions js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { AsyncCaller, AsyncCallerParams } from "./utils/async_caller.js";
import {
DataType,
Dataset,
DatasetShareSchema,
Example,
ExampleCreate,
ExampleUpdate,
Expand Down Expand Up @@ -584,6 +585,91 @@ export class Client {
return runs as Run[];
}

/**
 * Fetch the sharing metadata (share token + public URL) for a dataset
 * that has already been shared.
 *
 * Exactly one of `datasetId` or `datasetName` must be provided; when only
 * the name is given, the dataset is first resolved to its id via
 * `readDataset`.
 *
 * @param datasetId - UUID of the dataset.
 * @param datasetName - Name of the dataset, used when `datasetId` is omitted.
 * @returns The share schema with a browser-ready `url` field added client-side.
 * @throws Error if neither `datasetId` nor `datasetName` is supplied.
 */
public async readDatasetSharedSchema(
  datasetId?: string,
  datasetName?: string
): Promise<DatasetShareSchema> {
  if (!datasetId && !datasetName) {
    throw new Error("Either datasetId or datasetName must be given");
  }
  if (!datasetId) {
    const dataset = await this.readDataset({ datasetName });
    datasetId = dataset.id;
  }
  const response = await this.caller.call(
    fetch,
    `${this.apiUrl}/datasets/${datasetId}/share`,
    {
      method: "GET",
      headers: this.headers,
      signal: AbortSignal.timeout(this.timeout_ms),
    }
  );
  // Fail loudly on HTTP errors instead of casting an error payload to a
  // share schema (mirrors unshareDataset's error handling).
  await raiseForStatus(response, "read dataset shared schema");
  const shareSchema = await response.json();
  // The API returns only the token; the public URL is assembled here.
  shareSchema.url = `${this.getHostUrl()}/public/${
    shareSchema.share_token
  }/d`;
  return shareSchema as DatasetShareSchema;
}

/**
 * Create (or refresh) a public share link for a dataset.
 *
 * Exactly one of `datasetId` or `datasetName` must be provided; when only
 * the name is given, the dataset is first resolved to its id.
 *
 * @param datasetId - UUID of the dataset to share.
 * @param datasetName - Name of the dataset, used when `datasetId` is omitted.
 * @returns The share schema with a browser-ready `url` field added client-side.
 * @throws Error if neither `datasetId` nor `datasetName` is supplied.
 */
public async shareDataset(
  datasetId?: string,
  datasetName?: string
): Promise<DatasetShareSchema> {
  if (!datasetId && !datasetName) {
    throw new Error("Either datasetId or datasetName must be given");
  }
  if (!datasetId) {
    const dataset = await this.readDataset({ datasetName });
    datasetId = dataset.id;
  }
  const data = {
    dataset_id: datasetId,
  };
  const response = await this.caller.call(
    fetch,
    `${this.apiUrl}/datasets/${datasetId}/share`,
    {
      method: "PUT",
      headers: this.headers,
      body: JSON.stringify(data),
      signal: AbortSignal.timeout(this.timeout_ms),
    }
  );
  // Surface HTTP failures before attempting to parse the body as a
  // share schema (mirrors unshareDataset's error handling).
  await raiseForStatus(response, "share dataset");
  const shareSchema = await response.json();
  // The API returns only the token; the public URL is assembled here.
  shareSchema.url = `${this.getHostUrl()}/public/${
    shareSchema.share_token
  }/d`;
  return shareSchema as DatasetShareSchema;
}

/**
 * Revoke the public share link for a dataset.
 *
 * @param datasetId - UUID of the dataset to unshare.
 * @throws Error (via raiseForStatus) if the server rejects the request.
 */
public async unshareDataset(datasetId: string): Promise<void> {
  const shareEndpoint = `${this.apiUrl}/datasets/${datasetId}/share`;
  const response = await this.caller.call(fetch, shareEndpoint, {
    method: "DELETE",
    headers: this.headers,
    signal: AbortSignal.timeout(this.timeout_ms),
  });
  await raiseForStatus(response, "unshare dataset");
}

public async readSharedDataset(shareToken: string): Promise<Dataset> {
const response = await this.caller.call(
fetch,
`${this.apiUrl}/public/${shareToken}/datasets`,
{
method: "GET",
headers: this.headers,
signal: AbortSignal.timeout(this.timeout_ms),
}
);
const dataset = await response.json();
return dataset as Dataset;
}

public async createProject({
projectName,
projectExtra,
Expand Down
8 changes: 8 additions & 0 deletions js/src/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,14 @@ export interface Dataset extends BaseDataset {
id: string;
created_at: string;
modified_at: string;
example_count?: number;
session_count?: number;
last_session_start_time?: number;
}
/** Sharing metadata for a publicly shared dataset. */
export interface DatasetShareSchema {
/** UUID of the dataset that was shared. */
dataset_id: string;
/** Token embedded in the dataset's public URL. */
share_token: string;
/** Browser-ready public URL; assembled client-side from the share token. */
url: string;
}

export interface FeedbackSourceBase {
Expand Down
85 changes: 85 additions & 0 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,91 @@ def list_shared_runs(
ls_schemas.Run(**run, _host_url=self._host_url) for run in response.json()
]

def read_dataset_shared_schema(
    self,
    dataset_id: Optional[ID_TYPE] = None,
    *,
    dataset_name: Optional[str] = None,
) -> ls_schemas.DatasetShareSchema:
    """Retrieve the shared schema of a dataset.

    Either ``dataset_id`` or ``dataset_name`` must be given; the name is
    resolved to an id via ``read_dataset`` when needed.

    Parameters
    ----------
    dataset_id : Optional[ID_TYPE]
        The ID of the dataset.
    dataset_name : Optional[str]
        The name of the dataset, used when ``dataset_id`` is omitted.

    Returns
    -------
    ls_schemas.DatasetShareSchema
        The shared schema, with a public ``url`` added client-side.

    Raises
    ------
    ValueError
        If neither ``dataset_id`` nor ``dataset_name`` is given.
    """
    if dataset_id is None and dataset_name is None:
        raise ValueError("Either dataset_id or dataset_name must be given")
    if dataset_id is None:
        dataset_id = self.read_dataset(dataset_name=dataset_name).id
    response = self.session.get(
        f"{self.api_url}/datasets/{dataset_id}/share",
        headers=self._headers,
    )
    ls_utils.raise_for_status_with_text(response)
    d = response.json()
    # The API returns only the token; compose the public URL here.
    return cast(
        ls_schemas.DatasetShareSchema,
        {**d, "url": f"{self._host_url}/public/{d['share_token']}/d"},
    )

def share_dataset(
    self,
    dataset_id: Optional[ID_TYPE] = None,
    *,
    dataset_name: Optional[str] = None,
) -> ls_schemas.DatasetShareSchema:
    """Get a share link for a dataset.

    Either ``dataset_id`` or ``dataset_name`` must be given; the name is
    resolved to an id via ``read_dataset`` when needed.  Returns the share
    schema with a public ``url`` composed client-side.
    """
    if dataset_id is None and dataset_name is None:
        raise ValueError("Either dataset_id or dataset_name must be given")
    if dataset_id is None:
        dataset_id = self.read_dataset(dataset_name=dataset_name).id
    response = self.session.put(
        f"{self.api_url}/datasets/{dataset_id}/share",
        headers=self._headers,
        json={"dataset_id": str(dataset_id)},
    )
    ls_utils.raise_for_status_with_text(response)
    payload: dict = response.json()
    share_url = f"{self._host_url}/public/{payload['share_token']}/d"
    return cast(
        ls_schemas.DatasetShareSchema,
        {**payload, "url": share_url},
    )

def unshare_dataset(self, dataset_id: ID_TYPE) -> None:
    """Delete share link for a dataset."""
    share_endpoint = f"{self.api_url}/datasets/{dataset_id}/share"
    response = self.session.delete(share_endpoint, headers=self._headers)
    ls_utils.raise_for_status_with_text(response)

def read_shared_dataset(
    self,
    share_token: str,
) -> ls_schemas.Dataset:
    """Get shared datasets."""
    public_url = f"{self.api_url}/public/{share_token}/datasets"
    response = self.session.get(public_url, headers=self._headers)
    ls_utils.raise_for_status_with_text(response)
    payload = response.json()
    return ls_schemas.Dataset(**payload, _host_url=self._host_url)

def list_shared_examples(
    self, share_token: str, *, example_ids: Optional[List[ID_TYPE]] = None
) -> List[ls_schemas.Example]:
    """Get shared examples."""
    params: dict = {}
    if example_ids is not None:
        # The API filters by repeated "id" query parameters.
        params["id"] = [str(example_id) for example_id in example_ids]
    response = self.session.get(
        f"{self.api_url}/public/{share_token}/examples",
        headers=self._headers,
        params=params,
    )
    ls_utils.raise_for_status_with_text(response)
    return [
        ls_schemas.Example(**item, _host_url=self._host_url)
        for item in response.json()
    ]

def create_project(
self,
project_name: str,
Expand Down
20 changes: 19 additions & 1 deletion python/langsmith/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@

from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union, runtime_checkable
from typing import (
Any,
Dict,
List,
Optional,
Protocol,
TypedDict,
Union,
runtime_checkable,
)
from uuid import UUID

try:
Expand Down Expand Up @@ -101,6 +110,9 @@ class Dataset(DatasetBase):
id: UUID
created_at: datetime
modified_at: Optional[datetime] = Field(default=None)
example_count: Optional[int] = None
session_count: Optional[int] = None
last_session_start_time: Optional[datetime] = None
_host_url: Optional[str] = PrivateAttr(default=None)

def __init__(self, _host_url: Optional[str] = None, **kwargs: Any) -> None:
Expand Down Expand Up @@ -381,3 +393,9 @@ class BaseMessageLike(Protocol):
@property
def type(self) -> str:
"""Type of the Message, used for serialization."""


class DatasetShareSchema(TypedDict, total=False):
    """Sharing metadata for a publicly shared dataset (all keys optional)."""

    # UUID of the dataset that was shared.
    dataset_id: UUID
    # Token embedded in the dataset's public URL.
    share_token: UUID
    # Browser-ready public URL; assembled client-side from the share token.
    url: str

0 comments on commit 2ac0e62

Please sign in to comment.