Skip to content

Commit

Permalink
feat(artifact): implement search chunks and sources
Browse files Browse the repository at this point in the history
  • Loading branch information
Yougigun committed Nov 29, 2024
1 parent 4ed0edd commit 4b490d4
Show file tree
Hide file tree
Showing 8 changed files with 507 additions and 28 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/grpc-ecosystem/go-grpc-middleware v1.3.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.19.1
github.com/influxdata/influxdb-client-go/v2 v2.12.3
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0
github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61
github.com/instill-ai/x v0.3.0-alpha.0.20231219052200-6230a89e386c
github.com/knadh/koanf v1.5.0
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,8 @@ github.com/influxdata/influxdb-client-go/v2 v2.12.3 h1:28nRlNMRIV4QbtIUvxhWqaxn0
github.com/influxdata/influxdb-client-go/v2 v2.12.3/go.mod h1:IrrLUbCjjfkmRuaCiGQg4m2GbkaeJDcuWoxiWdQEbA0=
github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839 h1:W9WBk7wlPfJLvMCdtV4zPulc4uCPrlywQOmbFOhgQNU=
github.com/influxdata/line-protocol v0.0.0-20200327222509-2487e7298839/go.mod h1:xaLFMmpvUxqXtVkUJfg9QmT88cDaCJ3ZKgdZ78oO8Qo=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4 h1:k8X9gMiCwHWShB1FITaWwmlzthFnor1Jj0tSaFG+9x8=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241125163328-c29704e47ba4/go.mod h1:rf0UY7VpEgpaLudYEcjx5rnbuwlBaaLyD4FQmWLtgAY=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0 h1:Fok/s7GQoNMUA++1WbDdiZ6Ut8AXSuWNkTU2Q/0G9QA=
github.com/instill-ai/protogen-go v0.3.3-alpha.0.20241129082755-59b3c0c34fe0/go.mod h1:rf0UY7VpEgpaLudYEcjx5rnbuwlBaaLyD4FQmWLtgAY=
github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61 h1:smPTvmXDhn/QC7y/TPXyMTqbbRd0gvzmFgWBChwTfhE=
github.com/instill-ai/usage-client v0.3.0-alpha.0.20240319060111-4a3a39f2fd61/go.mod h1:/TAHs4ybuylk5icuy+MQtHRc4XUnIyXzeNKxX9qDFhw=
github.com/instill-ai/x v0.3.0-alpha.0.20231219052200-6230a89e386c h1:a2RVkpIV2QcrGnSHAou+t/L+vBsaIfFvk5inVg5Uh4s=
Expand Down
164 changes: 146 additions & 18 deletions pkg/handler/chunks.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,19 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
)

// convertToProtoChunk
func convertToProtoChunk(chunk repository.TextChunk) *artifactpb.Chunk {
return &artifactpb.Chunk{
ChunkUid: chunk.UID.String(),
Retrievable: chunk.Retrievable,
StartPos: uint32(chunk.StartPos),
EndPos: uint32(chunk.EndPos),
Tokens: uint32(chunk.Tokens),
CreateTime: timestamppb.New(*chunk.CreateTime),
OriginalFileUid: chunk.KbFileUID.String(),
}
}

func (ph *PublicHandler) ListChunks(ctx context.Context, req *artifactpb.ListChunksRequest) (*artifactpb.ListChunksResponse, error) {
log, _ := logger.GetZapLogger(ctx)
authUID, err := getUserUIDFromContext(ctx)
Expand Down Expand Up @@ -73,22 +86,77 @@ func (ph *PublicHandler) ListChunks(ctx context.Context, req *artifactpb.ListChu

res := make([]*artifactpb.Chunk, 0, len(chunks))
for _, chunk := range chunks {
res = append(res, &artifactpb.Chunk{
ChunkUid: chunk.UID.String(),
Retrievable: chunk.Retrievable,
StartPos: uint32(chunk.StartPos),
EndPos: uint32(chunk.EndPos),
Tokens: uint32(chunk.Tokens),
CreateTime: timestamppb.New(*chunk.CreateTime),
OriginalFileUid: kbf.UID.String(),
})
res = append(res, convertToProtoChunk(chunk))
}

return &artifactpb.ListChunksResponse{
Chunks: res,
}, nil
}

func (ph *PublicHandler) SearchChunks(ctx context.Context, req *artifactpb.SearchChunksRequest) (*artifactpb.SearchChunksResponse, error) {
log, _ := logger.GetZapLogger(ctx)
_, err := getUserUIDFromContext(ctx)
if err != nil {
log.Error("failed to get user id from header", zap.Error(err))
return nil, fmt.Errorf("failed to get user id from header: %v. err: %w", err, customerror.ErrUnauthenticated)
}
// check if user can access the namespace
ns, err := ph.service.GetNamespaceAndCheckPermission(ctx, req.NamespaceId)
if err != nil {
log.Error("failed to get namespace and check permission", zap.Error(err))
return nil, fmt.Errorf("failed to get namespace and check permission: %w", err)
}

chunkUIDs := make([]uuid.UUID, 0, len(req.ChunkUids))
for _, chunkUID := range req.ChunkUids {
chunkUID, err := uuid.FromString(chunkUID)
if err != nil {
log.Error("failed to parse chunk uid", zap.Error(err))
return nil, fmt.Errorf("failed to parse chunk uid: %w", err)
}
chunkUIDs = append(chunkUIDs, chunkUID)
}
// check if the chunkUIs is more than 20
if len(chunkUIDs) > 25 {
log.Error("chunk uids is more than 20", zap.Int("chunk_uids_count", len(chunkUIDs)))
return nil, fmt.Errorf("chunk uids is more than 20")
}
chunks, err := ph.service.Repository.GetChunksByUIDs(ctx, chunkUIDs)
if err != nil {
log.Error("failed to get chunks by uids", zap.Error(err))
return nil, fmt.Errorf("failed to get chunks by uids: %w", err)
}

// get the kbUIDs from chunks
kbUIDs := make([]uuid.UUID, 0, len(chunks))
for _, chunk := range chunks {
kbUIDs = append(kbUIDs, chunk.KbUID)
}
// use kbUIDs to get the knowledge bases
knowledgeBases, err := ph.service.Repository.GetKnowledgeBasesByUIDs(ctx, kbUIDs)
if err != nil {
log.Error("failed to get knowledge bases by uids", zap.Error(err))
return nil, fmt.Errorf("failed to get knowledge bases by uids: %w", err)
}
// check if the chunks's knowledge base's owner(namespace uid) is the same as namespace uuid in path
for _, knowledgeBase := range knowledgeBases {
if knowledgeBase.Owner != ns.NsUID.String() {
log.Error("chunks's namespace is not the same as namespace in path", zap.String("namespace_id_in_path", ns.NsUID.String()), zap.String("namespace_id_in_chunks", knowledgeBase.Owner))
return nil, fmt.Errorf("chunks's namespace is not the same as namespace in path")
}
}

// populate the response
protoChunks := make([]*artifactpb.Chunk, 0, len(chunks))
for _, chunk := range chunks {
protoChunks = append(protoChunks, convertToProtoChunk(chunk))
}
return &artifactpb.SearchChunksResponse{
Chunks: protoChunks,
}, nil
}

func (ph *PublicHandler) UpdateChunk(ctx context.Context, req *artifactpb.UpdateChunkRequest) (*artifactpb.UpdateChunkResponse, error) {
log, _ := logger.GetZapLogger(ctx)
authUID, err := getUserUIDFromContext(ctx)
Expand Down Expand Up @@ -131,15 +199,7 @@ func (ph *PublicHandler) UpdateChunk(ctx context.Context, req *artifactpb.Update

return &artifactpb.UpdateChunkResponse{
// Populate the response fields appropriately
Chunk: &artifactpb.Chunk{
ChunkUid: chunk.UID.String(),
Retrievable: chunk.Retrievable,
StartPos: uint32(chunk.StartPos),
EndPos: uint32(chunk.EndPos),
Tokens: uint32(chunk.Tokens),
CreateTime: timestamppb.New(*chunk.CreateTime),
// OriginalFileUid: chunk.FileUID.String(),
},
Chunk: convertToProtoChunk(*chunk),
}, nil
}

Expand Down Expand Up @@ -189,3 +249,71 @@ func (ph *PublicHandler) GetSourceFile(ctx context.Context, req *artifactpb.GetS
},
}, nil
}

// SearchSourceFiles
func (ph *PublicHandler) SearchSourceFiles(ctx context.Context, req *artifactpb.SearchSourceFilesRequest) (*artifactpb.SearchSourceFilesResponse, error) {
log, _ := logger.GetZapLogger(ctx)
authUID, err := getUserUIDFromContext(ctx)
if err != nil {
log.Error("failed to get user id from header", zap.Error(err))
return nil, fmt.Errorf("failed to get user id from header: %v. err: %w", err, customerror.ErrUnauthenticated)
}

// Check if user can access the namespace
_, err = ph.service.GetNamespaceAndCheckPermission(ctx, req.NamespaceId)
if err != nil {
log.Error("failed to get namespace and check permission", zap.Error(err))
return nil, fmt.Errorf("failed to get namespace and check permission: %w", err)
}

fileUIDs := make([]uuid.UUID, 0, len(req.FileUids))
for _, fileUID := range req.FileUids {
uid, err := uuid.FromString(fileUID)
if err != nil {
log.Error("failed to parse file uid", zap.Error(err))
return nil, fmt.Errorf("failed to parse file uid: %v. err: %w", err, customerror.ErrInvalidArgument)
}
fileUIDs = append(fileUIDs, uid)
}

sources := make([]*artifactpb.SourceFile, 0, len(fileUIDs))
for _, fileUID := range fileUIDs {
source, err := ph.service.Repository.GetTruthSourceByFileUID(ctx, fileUID)
if err != nil {
log.Error("failed to get truth source by file uid", zap.Error(err))
return nil, fmt.Errorf("failed to get truth source by file uid. err: %w", err)
}

// ACL check for each source file
granted, err := ph.service.ACLClient.CheckPermission(ctx, "knowledgebase", source.KbUID, "reader")
if err != nil {
log.Error("failed to check permission", zap.Error(err))
return nil, fmt.Errorf("failed to check permission. err: %w", err)
}
if !granted {
log.Error("no permission to access source file",
zap.String("user_uid", authUID),
zap.String("kb_uid", source.KbUID.String()))
return nil, fmt.Errorf("no permission to access source file. err: %w. user_uid: %s. kb_uid: %s", customerror.ErrNoPermission, authUID, source.KbUID.String())
}

// Get file content from MinIO
content, err := ph.service.MinIO.GetFile(ctx, minio.KnowledgeBaseBucketName, source.Dest)
if err != nil {
log.Error("failed to get file from minio", zap.Error(err))
continue
}

sources = append(sources, &artifactpb.SourceFile{
OriginalFileUid: source.OriginalFileUID.String(),
OriginalFileName: source.OriginalFileName,
Content: string(content),
CreateTime: timestamppb.New(source.CreateTime),
UpdateTime: timestamppb.New(source.UpdateTime),
})
}

return &artifactpb.SearchSourceFilesResponse{
SourceFiles: sources,
}, nil
}
2 changes: 1 addition & 1 deletion pkg/handler/knowledgebase.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func (ph *PublicHandler) CreateCatalog(ctx context.Context, req *artifactpb.Crea
ns, err := ph.service.GetNamespaceByNsID(ctx, req.GetNamespaceId())
if err != nil {
log.Error(
"failed to check namespace permission",
"failed to get namespace",
zap.Error(err),
zap.String("owner_id(ns_id)", req.GetNamespaceId()),
zap.String("auth_uid", authUID))
Expand Down
Loading

0 comments on commit 4b490d4

Please sign in to comment.