diff --git a/internal/core/src/storage/MinioChunkManager.cpp b/internal/core/src/storage/MinioChunkManager.cpp index badddb5fa7977..a77d4d4bba0b5 100644 --- a/internal/core/src/storage/MinioChunkManager.cpp +++ b/internal/core/src/storage/MinioChunkManager.cpp @@ -347,7 +347,7 @@ MinioChunkManager::Remove(const std::string& filepath) { std::vector MinioChunkManager::ListWithPrefix(const std::string& filepath) { - return ListObjects(default_bucket_name_.c_str(), filepath.c_str()); + return ListObjects(default_bucket_name_, filepath); } uint64_t @@ -393,7 +393,7 @@ MinioChunkManager::ListBuckets() { ThrowS3Error("ListBuckets", err, "params"); } for (auto&& b : outcome.GetResult().GetBuckets()) { - buckets.emplace_back(b.GetName().c_str()); + buckets.emplace_back(b.GetName()); } return buckets; } @@ -623,7 +623,7 @@ MinioChunkManager::ListObjects(const std::string& bucket_name, } auto objects = outcome.GetResult().GetContents(); for (auto& obj : objects) { - objects_vec.emplace_back(obj.GetKey().c_str()); + objects_vec.emplace_back(obj.GetKey()); } return objects_vec; } diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 03e33c41b859c..93e56393aa89d 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -19,15 +19,15 @@ const ( EnableAutoID = true DisableAutoID = false - HTTPCollectionName = "collectionName" - HTTPDbName = "dbName" - DefaultDbName = "default" - DefaultIndexName = "vector_idx" - DefaultOutputFields = "*" - - HTTPReturnCode = "code" - HTTPReturnMessage = "message" - HTTPReturnData = "data" + HTTPCollectionName = "collectionName" + HTTPDbName = "dbName" + DefaultDbName = "default" + DefaultIndexName = "vector_idx" + DefaultOutputFields = "*" + HTTPHeaderAllowInt64 = "Accept-Type-Allow-Int64" + HTTPReturnCode = "code" + HTTPReturnMessage = "message" + HTTPReturnData = "data" HTTPReturnFieldName = "name" HTTPReturnFieldType = "type" diff --git a/internal/distributed/proxy/httpserver/handler_v1.go b/internal/distributed/proxy/httpserver/handler_v1.go index 431e0ee634238..790d942790efb 100644 --- a/internal/distributed/proxy/httpserver/handler_v1.go +++ b/internal/distributed/proxy/httpserver/handler_v1.go @@ -54,7 +54,10 @@ func (h *Handlers) checkDatabase(ctx context.Context, c *gin.Context, dbName str return true } } - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), HTTPReturnMessage: merr.ErrDatabaseNotFound.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrDatabaseNotFound), + HTTPReturnMessage: merr.ErrDatabaseNotFound.Error() + ", database: " + dbName, + }) return false } @@ -152,12 +155,18 @@ func (h *Handlers) createCollection(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of create collection is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } if httpReq.CollectionName == "" || httpReq.Dimension == 0 { log.Warn("high level restful api, create collection require parameters: [collectionName, dimension], but miss", zap.Any("request", httpReq)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, dimension]", + }) return } schema, err := proto.Marshal(&schemapb.CollectionSchema{ @@ -189,7 +198,10 @@ func (h *Handlers) createCollection(c *gin.Context) { }) if err != nil { log.Warn("high level restful api, marshal collection schema fail", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMarshalCollectionSchema), HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMarshalCollectionSchema), + HTTPReturnMessage: merr.ErrMarshalCollectionSchema.Error() + ", error: " + err.Error(), + }) return } req := milvuspb.CreateCollectionRequest{ @@ -248,7 +260,10 @@ func (h *Handlers) getCollectionDetails(c *gin.Context) { collectionName := c.Query(HTTPCollectionName) if collectionName == "" { log.Warn("high level restful api, desc collection require parameter: [collectionName], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName]", + }) return } dbName := c.DefaultQuery(HTTPDbName, DefaultDbName) @@ -320,12 +335,18 @@ func (h *Handlers) dropCollection(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of drop collection is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } if httpReq.CollectionName == "" { log.Warn("high level restful api, drop collection require parameter: [collectionName], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName]", + }) return } req := milvuspb.DropCollectionRequest{ @@ -345,7 +366,10 @@ func (h *Handlers) dropCollection(c *gin.Context) { return } if !has { - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCollectionNotFound), HTTPReturnMessage: merr.ErrCollectionNotFound.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCollectionNotFound), + HTTPReturnMessage: merr.ErrCollectionNotFound.Error() + ", database: " + httpReq.DbName + ", collection: " + httpReq.CollectionName, + }) return } response, err := h.proxy.DropCollection(ctx, &req) @@ -367,12 +391,18 @@ func (h *Handlers) query(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of query is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } if httpReq.CollectionName == "" || httpReq.Filter == "" { log.Warn("high level restful api, query require parameter: [collectionName, filter], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, filter]", + }) return } req := milvuspb.QueryRequest{ @@ -404,10 +434,14 @@ func (h *Handlers) query(c *gin.Context) { if err != nil { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { - outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil) + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil, allowJS) if err != nil { log.Warn("high level restful api, fail to deal with query result", zap.Any("response", response), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), + HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), + }) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) } @@ -421,12 +455,18 @@ func (h *Handlers) get(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of get is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } if httpReq.CollectionName == "" || httpReq.ID == nil { log.Warn("high level restful api, get require parameter: [collectionName, id], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, id]", + }) return } req := milvuspb.QueryRequest{ @@ -450,7 +490,10 @@ func (h *Handlers) get(c *gin.Context) { body, _ := c.Get(gin.BodyBytesKey) filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), + }) return } req.Expr = filter @@ -461,13 +504,16 @@ func (h *Handlers) get(c *gin.Context) { if err != nil { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(err), HTTPReturnMessage: err.Error()}) } else { - outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil) + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + outputData, err := buildQueryResp(int64(0), response.OutputFields, response.FieldsData, nil, nil, allowJS) if err != nil { log.Warn("high level restful api, fail to deal with get result", zap.Any("response", response), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), + HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), + }) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) - log.Error("get resultIS: ", zap.Any("res", outputData)) } } } @@ -478,12 +524,18 @@ func (h *Handlers) delete(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of delete is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } if httpReq.CollectionName == "" || (httpReq.ID == nil && httpReq.Filter == "") { log.Warn("high level restful api, delete require parameter: [collectionName, id/filter], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, id/filter]", + }) return } req := milvuspb.DeleteRequest{ @@ -507,7 +559,10 @@ func (h *Handlers) delete(c *gin.Context) { body, _ := c.Get(gin.BodyBytesKey) filter, err := checkGetPrimaryKey(coll.Schema, gjson.Get(string(body.([]byte)), DefaultPrimaryFieldName)) if err != nil { - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: " + err.Error(), + }) return } req.Expr = filter @@ -533,7 +588,10 @@ func (h *Handlers) insert(c *gin.Context) { } if err = c.ShouldBindBodyWith(&singleInsertReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } httpReq.DbName = singleInsertReq.DbName @@ -542,7 +600,10 @@ func (h *Handlers) insert(c *gin.Context) { } if httpReq.CollectionName == "" || httpReq.Data == nil { log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, data]", + }) return } req := milvuspb.InsertRequest{ @@ -567,13 +628,19 @@ func (h *Handlers) insert(c *gin.Context) { err, httpReq.Data = checkAndSetData(string(body.([]byte)), coll) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) return } req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema) if err != nil { log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) return } response, err := h.proxy.Insert(ctx, &req) @@ -585,11 +652,19 @@ func (h *Handlers) insert(c *gin.Context) { } else { switch response.IDs.GetIdField().(type) { case *schemapb.IDs_IntId: - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + if allowJS { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": formatInt64(response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) + } case *schemapb.IDs_StrId: c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"insertCount": response.InsertCnt, "insertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) default: - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", + }) } } } @@ -603,8 +678,11 @@ func (h *Handlers) upsert(c *gin.Context) { DbName: DefaultDbName, } if err = c.ShouldBindBodyWith(&singleUpsertReq, binding.JSON); err != nil { - log.Warn("high level restful api, the parameter of insert is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + log.Warn("high level restful api, the parameter of upsert is incorrect", zap.Any("request", httpReq), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } httpReq.DbName = singleUpsertReq.DbName @@ -612,8 +690,11 @@ func (h *Handlers) upsert(c *gin.Context) { httpReq.Data = []map[string]interface{}{singleUpsertReq.Data} } if httpReq.CollectionName == "" || httpReq.Data == nil { - log.Warn("high level restful api, insert require parameter: [collectionName, data], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + log.Warn("high level restful api, upsert require parameter: [collectionName, data], but miss") + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, data]", + }) return } req := milvuspb.UpsertRequest{ @@ -642,14 +723,20 @@ func (h *Handlers) upsert(c *gin.Context) { body, _ := c.Get(gin.BodyBytesKey) err, httpReq.Data = checkAndSetData(string(body.([]byte)), coll) if err != nil { - log.Warn("high level restful api, fail to deal with insert data", zap.Any("body", body), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + log.Warn("high level restful api, fail to deal with upsert data", zap.Any("body", body), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) return } req.FieldsData, err = anyToColumns(httpReq.Data, coll.Schema) if err != nil { - log.Warn("high level restful api, fail to deal with insert data", zap.Any("data", httpReq.Data), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), HTTPReturnMessage: merr.ErrInvalidInsertData.Error()}) + log.Warn("high level restful api, fail to deal with upsert data", zap.Any("data", httpReq.Data), zap.Error(err)) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidInsertData), + HTTPReturnMessage: merr.ErrInvalidInsertData.Error() + ", error: " + err.Error(), + }) return } response, err := h.proxy.Upsert(ctx, &req) @@ -661,11 +748,19 @@ func (h *Handlers) upsert(c *gin.Context) { } else { switch response.IDs.GetIdField().(type) { case *schemapb.IDs_IntId: - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + if allowJS { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data}}) + } else { + c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": formatInt64(response.IDs.IdField.(*schemapb.IDs_IntId).IntId.Data)}}) + } case *schemapb.IDs_StrId: c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: gin.H{"upsertCount": response.UpsertCnt, "upsertIds": response.IDs.IdField.(*schemapb.IDs_StrId).StrId.Data}}) default: - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error()}) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrCheckPrimaryKey), + HTTPReturnMessage: merr.ErrCheckPrimaryKey.Error() + ", error: unsupported primary key data type", + }) } } } @@ -677,12 +772,18 @@ func (h *Handlers) search(c *gin.Context) { } if err := c.ShouldBindBodyWith(&httpReq, binding.JSON); err != nil { log.Warn("high level restful api, the parameter of search is incorrect", zap.Any("request", httpReq), zap.Error(err)) - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrIncorrectParameterFormat), + HTTPReturnMessage: merr.ErrIncorrectParameterFormat.Error() + ", error: " + err.Error(), + }) return } if httpReq.CollectionName == "" || httpReq.Vector == nil { log.Warn("high level restful api, search require parameter: [collectionName, vector], but miss") - c.AbortWithStatusJSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error()}) + c.AbortWithStatusJSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrMissingRequiredParameters), + HTTPReturnMessage: merr.ErrMissingRequiredParameters.Error() + ", required parameters: [collectionName, vector]", + }) return } params := map[string]interface{}{ // auto generated mapping @@ -724,10 +825,14 @@ func (h *Handlers) search(c *gin.Context) { if response.Results.TopK == int64(0) { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: []interface{}{}}) } else { - outputData, err := buildQueryResp(response.Results.TopK, response.Results.OutputFields, response.Results.FieldsData, response.Results.Ids, response.Results.Scores) + allowJS, _ := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + outputData, err := buildQueryResp(response.Results.TopK, response.Results.OutputFields, response.Results.FieldsData, response.Results.Ids, response.Results.Scores, allowJS) if err != nil { log.Warn("high level restful api, fail to deal with search result", zap.Any("result", response.Results), zap.Error(err)) - c.JSON(http.StatusOK, gin.H{HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), HTTPReturnMessage: merr.ErrInvalidSearchResult.Error()}) + c.JSON(http.StatusOK, gin.H{ + HTTPReturnCode: merr.Code(merr.ErrInvalidSearchResult), + HTTPReturnMessage: merr.ErrInvalidSearchResult.Error() + ", error: " + err.Error(), + }) } else { c.JSON(http.StatusOK, gin.H{HTTPReturnCode: http.StatusOK, HTTPReturnData: outputData}) } diff --git a/internal/distributed/proxy/httpserver/handler_v1_test.go b/internal/distributed/proxy/httpserver/handler_v1_test.go index facbdb71832db..5d2d109f13ef4 100644 --- a/internal/distributed/proxy/httpserver/handler_v1_test.go +++ b/internal/distributed/proxy/httpserver/handler_v1_test.go @@ -6,6 +6,8 @@ import ( "fmt" "net/http" "net/http/httptest" + "strconv" + "strings" "testing" "github.com/cockroachdb/errors" @@ -51,7 +53,7 @@ var DefaultShowCollectionsResp = milvuspb.ShowCollectionsResponse{ var DefaultDescCollectionResp = milvuspb.DescribeCollectionResponse{ CollectionName: DefaultCollectionName, - Schema: generateCollectionSchema(false), + Schema: generateCollectionSchema(schemapb.DataType_Int64, false), ShardsNum: ShardNumDefault, Status: &StatusSuccess, } @@ -83,6 +85,18 @@ func versional(path string) string { func initHTTPServer(proxy types.ProxyComponent, needAuth bool) *gin.Engine { h := NewHandlers(proxy) ginHandler := gin.Default() + ginHandler.Use(func(c *gin.Context) { + _, err := strconv.ParseBool(c.Request.Header.Get(HTTPHeaderAllowInt64)) + if err != nil { + httpParams := ¶mtable.Get().HTTPCfg + if httpParams.AcceptTypeAllowInt64.GetAsBool() { + c.Request.Header.Set(HTTPHeaderAllowInt64, "true") + } else { + c.Request.Header.Set(HTTPHeaderAllowInt64, "false") + } + } + c.Next() + }) app := ginHandler.Group(URIPrefixV1, genAuthMiddleWare(needAuth)) NewHandlers(h.proxy).RegisterRoutesToV1(app) return ginHandler @@ -131,6 +145,11 @@ func PrintErr(err error) string { return Print(merr.Code(err), err.Error()) } +func CheckErrCode(errorStr string, err error) bool { + prefix := fmt.Sprintf("{\"%s\":%d,\"%s\":\"%s", HTTPReturnCode, merr.Code(err), HTTPReturnMessage, err.Error()) + return strings.HasPrefix(errorStr, prefix) +} + func TestVectorAuthenticate(t *testing.T) { paramtable.Init() @@ -146,8 +165,8 @@ func TestVectorAuthenticate(t *testing.T) { req := httptest.NewRequest(http.MethodGet, versional(VectorCollectionsPath), nil) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusUnauthorized) - assert.Equal(t, w.Body.String(), PrintErr(merr.ErrNeedAuthenticate)) + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), merr.ErrNeedAuthenticate)) }) t.Run("username or password incorrect", func(t *testing.T) { @@ -155,8 +174,8 @@ func TestVectorAuthenticate(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.UserRoot) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusUnauthorized) - assert.Equal(t, w.Body.String(), PrintErr(merr.ErrNeedAuthenticate)) + assert.Equal(t, http.StatusUnauthorized, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), merr.ErrNeedAuthenticate)) }) t.Run("root's password correct", func(t *testing.T) { @@ -164,8 +183,8 @@ func TestVectorAuthenticate(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":[\""+DefaultCollectionName+"\"]}") + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":[\""+DefaultCollectionName+"\"]}", w.Body.String()) }) t.Run("username and password both provided", func(t *testing.T) { @@ -173,8 +192,8 @@ func TestVectorAuthenticate(t *testing.T) { req.SetBasicAuth("test", util.UserRoot) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":[\""+DefaultCollectionName+"\"]}") + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":[\""+DefaultCollectionName+"\"]}", w.Body.String()) }) } @@ -218,8 +237,12 @@ func TestVectorListCollection(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } }) } } @@ -229,6 +252,7 @@ type testCase struct { mp *mocks.MockProxy exceptCode int expectedBody string + expectedErr error } func TestVectorCollectionsDescribe(t *testing.T) { @@ -277,8 +301,12 @@ func TestVectorCollectionsDescribe(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } }) } t.Run("need collectionName", func(t *testing.T) { @@ -287,8 +315,8 @@ func TestVectorCollectionsDescribe(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, 200) - assert.Equal(t, w.Body.String(), PrintErr(merr.ErrMissingRequiredParameters)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), merr.ErrMissingRequiredParameters)) }) } @@ -356,8 +384,12 @@ func TestVectorCreateCollection(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } }) } } @@ -409,14 +441,19 @@ func TestVectorDropCollection(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } }) } } func TestQuery(t *testing.T) { paramtable.Init() + paramtable.Get().Save(proxy.Params.HTTPCfg.AcceptTypeAllowInt64.Key, "true") testCases := []testCase{} mp2 := mocks.NewMockProxy(t) @@ -445,6 +482,21 @@ func TestQuery(t *testing.T) { mp4 := mocks.NewMockProxy(t) mp4, _ = wrapWithDescribeColl(t, mp4, ReturnSuccess, 1, nil) mp4.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{ + Status: &StatusSuccess, + FieldsData: newFieldData([]*schemapb.FieldData{}, 1000), + CollectionName: DefaultCollectionName, + OutputFields: []string{FieldBookID, FieldWordCount, FieldBookIntro}, + }, nil).Twice() + testCases = append(testCases, testCase{ + name: "query fail", + mp: mp4, + exceptCode: 200, + expectedErr: merr.ErrInvalidSearchResult, + }) + + mp5 := mocks.NewMockProxy(t) + mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil) + mp5.EXPECT().Query(mock.Anything, mock.Anything).Return(&milvuspb.QueryResults{ Status: &StatusSuccess, FieldsData: generateFieldData(), CollectionName: DefaultCollectionName, @@ -452,7 +504,7 @@ func TestQuery(t *testing.T) { }, nil).Twice() testCases = append(testCases, testCase{ name: "query success", - mp: mp4, + mp: mp5, exceptCode: 200, expectedBody: "{\"code\":200,\"data\":[{\"book_id\":1,\"book_intro\":[0.1,0.11],\"word_count\":1000},{\"book_id\":2,\"book_intro\":[0.2,0.22],\"word_count\":2000},{\"book_id\":3,\"book_intro\":[0.3,0.33],\"word_count\":3000}]}", }) @@ -465,16 +517,20 @@ func TestQuery(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) if resp[HTTPReturnCode] == float64(200) { data := resp[HTTPReturnData].([]interface{}) rows := generateQueryResult64(false) for i, row := range data { - assert.Equal(t, compareRow64(row.(map[string]interface{}), rows[i]), true) + assert.Equal(t, true, compareRow64(row.(map[string]interface{}), rows[i])) } } } @@ -546,11 +602,15 @@ func TestDelete(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) } } @@ -574,17 +634,18 @@ func TestDeleteForFilter(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, 200) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":{}}") + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{}}", w.Body.String()) resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) } } func TestInsert(t *testing.T) { paramtable.Init() + paramtable.Get().Save(proxy.Params.HTTPCfg.AcceptTypeAllowInt64.Key, "true") testCases := []testCase{} _, testCases = wrapWithDescribeColl(t, nil, ReturnFail, 1, testCases) _, testCases = wrapWithDescribeColl(t, nil, ReturnWrongStatus, 1, testCases) @@ -618,17 +679,17 @@ func TestInsert(t *testing.T) { Status: &StatusSuccess, }, nil).Once() testCases = append(testCases, testCase{ - name: "id type invalid", - mp: mp4, - exceptCode: 200, - expectedBody: PrintErr(merr.ErrCheckPrimaryKey), + name: "id type invalid", + mp: mp4, + exceptCode: 200, + expectedErr: merr.ErrCheckPrimaryKey, }) mp5 := mocks.NewMockProxy(t) mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil) mp5.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ Status: &StatusSuccess, - IDs: getIntIds(), + IDs: genIds(schemapb.DataType_Int64), InsertCnt: 3, }, nil).Once() testCases = append(testCases, testCase{ @@ -642,7 +703,7 @@ func TestInsert(t *testing.T) { mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil) mp6.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ Status: &StatusSuccess, - IDs: getStrIds(), + IDs: genIds(schemapb.DataType_VarChar), InsertCnt: 3, }, nil).Once() testCases = append(testCases, testCase{ @@ -652,7 +713,7 @@ func TestInsert(t *testing.T) { expectedBody: "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":[\"1\",\"2\",\"3\"]}}", }) - rows := generateSearchResult() + rows := generateSearchResult(schemapb.DataType_Int64) data, _ := json.Marshal(map[string]interface{}{ HTTPCollectionName: DefaultCollectionName, HTTPReturnData: rows[0], @@ -665,11 +726,15 @@ func TestInsert(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) } @@ -682,20 +747,21 @@ func TestInsert(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, 200) - assert.Equal(t, w.Body.String(), PrintErr(merr.ErrInvalidInsertData)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), merr.ErrInvalidInsertData)) resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) } func TestInsertForDataType(t *testing.T) { paramtable.Init() + paramtable.Get().Save(proxy.Params.HTTPCfg.AcceptTypeAllowInt64.Key, "true") schemas := map[string]*schemapb.CollectionSchema{ - "[success]kinds of data type": newCollectionSchema(generateCollectionSchema(false)), - "[success]use binary vector": newCollectionSchema(generateCollectionSchema(true)), - "[success]with dynamic field": withDynamicField(newCollectionSchema(generateCollectionSchema(false))), + "[success]kinds of data type": newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64, false)), + "[success]use binary vector": newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64, true)), + "[success]with dynamic field": withDynamicField(newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64, false))), } for name, schema := range schemas { t.Run(name, func(t *testing.T) { @@ -708,11 +774,11 @@ func TestInsertForDataType(t *testing.T) { }, nil).Once() mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ Status: &StatusSuccess, - IDs: getIntIds(), + IDs: genIds(schemapb.DataType_Int64), InsertCnt: 3, }, nil).Once() testEngine := initHTTPServer(mp, true) - rows := newSearchResult(generateSearchResult()) + rows := newSearchResult(generateSearchResult(schemapb.DataType_Int64)) data, _ := json.Marshal(map[string]interface{}{ HTTPCollectionName: DefaultCollectionName, HTTPReturnData: rows, @@ -722,12 +788,12 @@ func TestInsertForDataType(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, 200) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":[1,2,3]}}") + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":[1,2,3]}}", w.Body.String()) }) } schemas = map[string]*schemapb.CollectionSchema{ - "with unsupport field type": withUnsupportField(newCollectionSchema(generateCollectionSchema(false))), + "with unsupport field type": withUnsupportField(newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64, false))), } for name, schema := range schemas { t.Run(name, func(t *testing.T) { @@ -739,7 +805,239 @@ func TestInsertForDataType(t *testing.T) { Status: &StatusSuccess, }, nil).Once() testEngine := initHTTPServer(mp, true) - rows := newSearchResult(generateSearchResult()) + rows := newSearchResult(generateSearchResult(schemapb.DataType_Int64)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), merr.ErrInvalidInsertData)) + }) + } +} + +func TestReturnInt64(t *testing.T) { + paramtable.Init() + paramtable.Get().Save(proxy.Params.HTTPCfg.AcceptTypeAllowInt64.Key, "false") + schemas := []schemapb.DataType{ + schemapb.DataType_Int64, + schemapb.DataType_VarChar, + } + idStrs := map[schemapb.DataType]string{ + schemapb.DataType_Int64: "1,2,3", + schemapb.DataType_VarChar: "\"1\",\"2\",\"3\"", + } + for _, dataType := range schemas { + t.Run("[insert]httpCfg.allow: false", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + InsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":[\"1\",\"2\",\"3\"]}}", w.Body.String()) + }) + } + + for _, dataType := range schemas { + t.Run("[upsert]httpCfg.allow: false", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + UpsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":[\"1\",\"2\",\"3\"]}}", w.Body.String()) + }) + } + + for _, dataType := range schemas { + t.Run("[insert]httpCfg.allow: false, Accept-Type-Allow-Int64: true", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + InsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + req.Header.Set(HTTPHeaderAllowInt64, "true") + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":["+idStrs[dataType]+"]}}", w.Body.String()) + }) + } + + for _, dataType := range schemas { + t.Run("[upsert]httpCfg.allow: false, Accept-Type-Allow-Int64: true", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + UpsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + req.Header.Set(HTTPHeaderAllowInt64, "true") + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":["+idStrs[dataType]+"]}}", w.Body.String()) + }) + } + + paramtable.Get().Save(proxy.Params.HTTPCfg.AcceptTypeAllowInt64.Key, "true") + for _, dataType := range schemas { + t.Run("[insert]httpCfg.allow: true", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + InsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":["+idStrs[dataType]+"]}}", w.Body.String()) + }) + } + + for _, dataType := range schemas { + t.Run("[upsert]httpCfg.allow: true", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + UpsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":["+idStrs[dataType]+"]}}", w.Body.String()) + }) + } + + for _, dataType := range schemas { + t.Run("[insert]httpCfg.allow: true, Accept-Type-Allow-Int64: false", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Insert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + InsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) data, _ := json.Marshal(map[string]interface{}{ HTTPCollectionName: DefaultCollectionName, HTTPReturnData: rows, @@ -747,16 +1045,50 @@ func TestInsertForDataType(t *testing.T) { bodyReader := bytes.NewReader(data) req := httptest.NewRequest(http.MethodPost, versional(VectorInsertPath), bodyReader) req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + req.Header.Set(HTTPHeaderAllowInt64, "false") + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"insertCount\":3,\"insertIds\":[\"1\",\"2\",\"3\"]}}", w.Body.String()) + }) + } + + for _, dataType := range schemas { + t.Run("[upsert]httpCfg.allow: true, Accept-Type-Allow-Int64: false", func(t *testing.T) { + schema := newCollectionSchema(generateCollectionSchema(dataType, false)) + mp := mocks.NewMockProxy(t) + mp.EXPECT().DescribeCollection(mock.Anything, mock.Anything).Return(&milvuspb.DescribeCollectionResponse{ + CollectionName: DefaultCollectionName, + Schema: schema, + ShardsNum: ShardNumDefault, + Status: &StatusSuccess, + }, nil).Once() + mp.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ + Status: &StatusSuccess, + IDs: genIds(dataType), + UpsertCnt: 3, + }, nil).Once() + testEngine := initHTTPServer(mp, true) + rows := newSearchResult(generateSearchResult(dataType)) + data, _ := json.Marshal(map[string]interface{}{ + HTTPCollectionName: DefaultCollectionName, + HTTPReturnData: rows, + }) + bodyReader := bytes.NewReader(data) + req := httptest.NewRequest(http.MethodPost, versional(VectorUpsertPath), bodyReader) + req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) + req.Header.Set(HTTPHeaderAllowInt64, "false") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, 200) - assert.Equal(t, w.Body.String(), PrintErr(merr.ErrInvalidInsertData)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":[\"1\",\"2\",\"3\"]}}", w.Body.String()) }) } } func TestUpsert(t *testing.T) { paramtable.Init() + paramtable.Get().Save(proxy.Params.HTTPCfg.AcceptTypeAllowInt64.Key, "true") testCases := []testCase{} _, testCases = wrapWithDescribeColl(t, nil, ReturnFail, 1, testCases) _, testCases = wrapWithDescribeColl(t, nil, ReturnWrongStatus, 1, testCases) @@ -765,7 +1097,7 @@ func TestUpsert(t *testing.T) { mp2, _ = wrapWithDescribeColl(t, mp2, ReturnSuccess, 1, nil) mp2.EXPECT().Upsert(mock.Anything, mock.Anything).Return(nil, ErrDefault).Once() testCases = append(testCases, testCase{ - name: "insert fail", + name: "upsert fail", mp: mp2, exceptCode: 200, expectedBody: PrintErr(ErrDefault), @@ -778,7 +1110,7 @@ func TestUpsert(t *testing.T) { Status: merr.Status(err), }, nil).Once() testCases = append(testCases, testCase{ - name: "insert fail", + name: "upsert fail", mp: mp3, exceptCode: 200, expectedBody: PrintErr(err), @@ -790,17 +1122,17 @@ func TestUpsert(t *testing.T) { Status: &StatusSuccess, }, nil).Once() testCases = append(testCases, testCase{ - name: "id type invalid", - mp: mp4, - exceptCode: 200, - expectedBody: PrintErr(merr.ErrCheckPrimaryKey), + name: "id type invalid", + mp: mp4, + exceptCode: 200, + expectedErr: merr.ErrCheckPrimaryKey, }) mp5 := mocks.NewMockProxy(t) mp5, _ = wrapWithDescribeColl(t, mp5, ReturnSuccess, 1, nil) mp5.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ Status: &StatusSuccess, - IDs: getIntIds(), + IDs: genIds(schemapb.DataType_Int64), UpsertCnt: 3, }, nil).Once() testCases = append(testCases, testCase{ @@ -814,7 +1146,7 @@ func TestUpsert(t *testing.T) { mp6, _ = wrapWithDescribeColl(t, mp6, ReturnSuccess, 1, nil) mp6.EXPECT().Upsert(mock.Anything, mock.Anything).Return(&milvuspb.MutationResult{ Status: &StatusSuccess, - IDs: getStrIds(), + IDs: genIds(schemapb.DataType_VarChar), UpsertCnt: 3, }, nil).Once() testCases = append(testCases, testCase{ @@ -824,7 +1156,7 @@ func TestUpsert(t *testing.T) { expectedBody: "{\"code\":200,\"data\":{\"upsertCount\":3,\"upsertIds\":[\"1\",\"2\",\"3\"]}}", }) - rows := generateSearchResult() + rows := generateSearchResult(schemapb.DataType_Int64) data, _ := json.Marshal(map[string]interface{}{ HTTPCollectionName: DefaultCollectionName, HTTPReturnData: rows[0], @@ -837,11 +1169,15 @@ func TestUpsert(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) } @@ -854,38 +1190,21 @@ func TestUpsert(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, 200) - assert.Equal(t, w.Body.String(), PrintErr(merr.ErrInvalidInsertData)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), merr.ErrInvalidInsertData)) resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) }) } -func getIntIds() *schemapb.IDs { - ids := schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: []int64{1, 2, 3}, - }, - }, - } - return &ids -} - -func getStrIds() *schemapb.IDs { - ids := schemapb.IDs{ - IdField: &schemapb.IDs_StrId{ - StrId: &schemapb.StringArray{ - Data: []string{"1", "2", "3"}, - }, - }, - } - return &ids +func genIds(dataType schemapb.DataType) *schemapb.IDs { + return generateIds(dataType, 3) } func TestSearch(t *testing.T) { paramtable.Init() + paramtable.Get().Save(proxy.Params.HTTPCfg.AcceptTypeAllowInt64.Key, "true") testCases := []testCase{} mp2 := mocks.NewMockProxy(t) @@ -911,6 +1230,22 @@ func TestSearch(t *testing.T) { mp4 := mocks.NewMockProxy(t) mp4.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{ + Status: &StatusSuccess, + Results: &schemapb.SearchResultData{ + FieldsData: []*schemapb.FieldData{}, + Scores: []float32{}, + TopK: 0, + }, + }, nil).Once() + testCases = append(testCases, testCase{ + name: "search success", + mp: mp4, + exceptCode: 200, + expectedBody: "{\"code\":200,\"data\":[]}", + }) + + mp5 := mocks.NewMockProxy(t) + mp5.EXPECT().Search(mock.Anything, mock.Anything).Return(&milvuspb.SearchResults{ Status: &StatusSuccess, Results: &schemapb.SearchResultData{ FieldsData: generateFieldData(), @@ -920,7 +1255,7 @@ func TestSearch(t *testing.T) { }, nil).Once() testCases = append(testCases, testCase{ name: "search success", - mp: mp4, + mp: mp5, exceptCode: 200, expectedBody: "{\"code\":200,\"data\":[{\"book_id\":1,\"book_intro\":[0.1,0.11],\"distance\":0.01,\"word_count\":1000},{\"book_id\":2,\"book_intro\":[0.2,0.22],\"distance\":0.04,\"word_count\":2000},{\"book_id\":3,\"book_intro\":[0.3,0.33],\"distance\":0.09,\"word_count\":3000}]}", }) @@ -938,16 +1273,20 @@ func TestSearch(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, tt.exceptCode) - assert.Equal(t, w.Body.String(), tt.expectedBody) + assert.Equal(t, tt.exceptCode, w.Code) + if tt.expectedErr != nil { + assert.Equal(t, true, CheckErrCode(w.Body.String(), tt.expectedErr)) + } else { + assert.Equal(t, tt.expectedBody, w.Body.String()) + } resp := map[string]interface{}{} err := json.Unmarshal(w.Body.Bytes(), &resp) - assert.Equal(t, err, nil) + assert.Equal(t, nil, err) if resp[HTTPReturnCode] == float64(200) { data := resp[HTTPReturnData].([]interface{}) rows := generateQueryResult64(true) for i, row := range data { - assert.Equal(t, compareRow64(row.(map[string]interface{}), rows[i]), true) + assert.Equal(t, true, compareRow64(row.(map[string]interface{}), rows[i])) } } }) @@ -1020,10 +1359,10 @@ func wrapWithHasCollection(t *testing.T, mp *mocks.MockProxy, returnType ReturnT case ReturnFalse: call = mp.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(&DefaultFalseResp, nil) testcase = testCase{ - name: "[share] collection not found", - mp: mp, - exceptCode: 200, - expectedBody: PrintErr(merr.ErrCollectionNotFound), + name: "[share] collection not found", + mp: mp, + exceptCode: 200, + expectedErr: merr.ErrCollectionNotFound, } case ReturnFail: call = mp.EXPECT().HasCollection(mock.Anything, mock.Anything).Return(nil, ErrDefault) @@ -1058,11 +1397,11 @@ func wrapWithHasCollection(t *testing.T, mp *mocks.MockProxy, returnType ReturnT func TestHttpRequestFormat(t *testing.T) { paramtable.Init() - errStrs := []string{ - PrintErr(merr.ErrIncorrectParameterFormat), - PrintErr(merr.ErrMissingRequiredParameters), - PrintErr(merr.ErrMissingRequiredParameters), - PrintErr(merr.ErrMissingRequiredParameters), + errStrs := []error{ + merr.ErrIncorrectParameterFormat, + merr.ErrMissingRequiredParameters, + merr.ErrMissingRequiredParameters, + merr.ErrMissingRequiredParameters, } requestJsons := [][]byte{ []byte(`{"collectionName": {"` + DefaultCollectionName + `", "dimension": 2}`), @@ -1108,8 +1447,8 @@ func TestHttpRequestFormat(t *testing.T) { req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, 200) - assert.Equal(t, w.Body.String(), errStrs[i]) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), errStrs[i])) }) } } @@ -1140,8 +1479,8 @@ func TestAuthorization(t *testing.T) { req.Header.Set("authorization", "Bearer test:test") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusForbidden) - assert.Equal(t, w.Body.String(), res) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, res, w.Body.String()) }) } } @@ -1161,8 +1500,8 @@ func TestAuthorization(t *testing.T) { req.Header.Set("authorization", "Bearer test:test") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusForbidden) - assert.Equal(t, w.Body.String(), res) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, res, w.Body.String()) }) } } @@ -1182,8 +1521,8 @@ func TestAuthorization(t *testing.T) { req.Header.Set("authorization", "Bearer test:test") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusForbidden) - assert.Equal(t, w.Body.String(), res) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, res, w.Body.String()) }) } } @@ -1203,8 +1542,8 @@ func TestAuthorization(t *testing.T) { req.Header.Set("authorization", "Bearer test:test") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusForbidden) - assert.Equal(t, w.Body.String(), res) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, res, w.Body.String()) }) } } @@ -1224,8 +1563,8 @@ func TestAuthorization(t *testing.T) { req.Header.Set("authorization", "Bearer test:test") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusForbidden) - assert.Equal(t, w.Body.String(), res) + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Equal(t, res, w.Body.String()) }) } } @@ -1242,8 +1581,8 @@ func TestDatabaseNotFound(t *testing.T) { req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), PrintErr(ErrDefault)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), ErrDefault)) }) t.Run("list database without success code", func(t *testing.T) { @@ -1257,8 +1596,8 @@ func TestDatabaseNotFound(t *testing.T) { req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), PrintErr(err)) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), err)) }) t.Run("list database success", func(t *testing.T) { @@ -1278,13 +1617,13 @@ func TestDatabaseNotFound(t *testing.T) { req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":[]}") + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "{\"code\":200,\"data\":[]}", w.Body.String()) }) - errorStr := PrintErr(merr.ErrDatabaseNotFound) - paths := map[string][]string{ - errorStr: { + theError := merr.ErrDatabaseNotFound + paths := map[error][]string{ + theError: { versional(VectorCollectionsPath) + "?dbName=test", versional(VectorCollectionsDescribePath) + "?dbName=test&collectionName=" + DefaultCollectionName, }, @@ -1302,14 +1641,14 @@ func TestDatabaseNotFound(t *testing.T) { req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), res) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), res)) }) } } requestBody := `{"dbName": "test", "collectionName": "` + DefaultCollectionName + `", "vector": [0.1, 0.2], "filter": "id in [2]", "id": [2], "dimension": 2, "data":[{"book_id":1,"book_intro":[0.1,0.11],"distance":0.01,"word_count":1000},{"book_id":2,"book_intro":[0.2,0.22],"distance":0.04,"word_count":2000},{"book_id":3,"book_intro":[0.3,0.33],"distance":0.09,"word_count":3000}]}` - paths = map[string][]string{ + pathArray := map[string][]string{ requestBody: { versional(VectorCollectionsCreatePath), versional(VectorCollectionsDropPath), @@ -1321,7 +1660,7 @@ func TestDatabaseNotFound(t *testing.T) { versional(VectorSearchPath), }, } - for request, pathArr := range paths { + for request, pathArr := range pathArray { for _, path := range pathArr { t.Run("POST dbName", func(t *testing.T) { mp := mocks.NewMockProxy(t) @@ -1335,8 +1674,8 @@ func TestDatabaseNotFound(t *testing.T) { req.Header.Set("authorization", "Bearer root:Milvus") w := httptest.NewRecorder() testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), errorStr) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, true, CheckErrCode(w.Body.String(), theError)) }) } } @@ -1455,159 +1794,3 @@ func getFieldSchema() []*schemapb.FieldSchema { return fields } - -func Test_Handles_VectorCollectionsDescribe(t *testing.T) { - paramtable.Init() - mp := mocks.NewMockProxy(t) - h := NewHandlers(mp) - testEngine := gin.New() - app := testEngine.Group("/", func(c *gin.Context) { - c.Set(ContextUsername, "") - username, _, ok := ParseUsernamePassword(c) - if ok { - c.Set(ContextUsername, username) - } - }) - h.RegisterRoutesToV1(app) - - t.Run("hasn't authenticate", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) - w := httptest.NewRecorder() - testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusUnauthorized) - }) - - t.Run("auth fail", func(t *testing.T) { - paramtable.Get().Save(proxy.Params.CommonCfg.AuthorizationEnabled.Key, "true") - req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) - req.SetBasicAuth("test", util.DefaultRootPassword) - w := httptest.NewRecorder() - testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusForbidden) - assert.Equal(t, w.Body.String(), Print(merr.Code(merr.ErrServiceUnavailable), "internal: Milvus Proxy is not ready yet. please wait: service unavailable")) - }) - - t.Run("describe collection fail with error", func(t *testing.T) { - paramtable.Get().Save(proxy.Params.CommonCfg.AuthorizationEnabled.Key, "false") - mp.EXPECT(). - DescribeCollection(mock.Anything, mock.Anything). - Return(nil, ErrDefault). - Once() - req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) - req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) - w := httptest.NewRecorder() - testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), PrintErr(ErrDefault)) - }) - - t.Run("describe collection fail with status code", func(t *testing.T) { - err := merr.WrapErrDatabaseNotFound(DefaultDbName) - paramtable.Get().Save(proxy.Params.CommonCfg.AuthorizationEnabled.Key, "false") - mp.EXPECT(). - DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Status: merr.Status(err), - }, nil). - Once() - req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) - req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) - w := httptest.NewRecorder() - testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), PrintErr(err)) - }) - - t.Run("get load state and describe index fail with error", func(t *testing.T) { - mp.EXPECT(). - DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Schema: getCollectionSchema("collectionName"), - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - }, nil). - Once() - mp.EXPECT(). - GetLoadState(mock.Anything, mock.Anything). - Return(nil, errors.New("error")). - Once() - mp.EXPECT(). - DescribeIndex(mock.Anything, mock.Anything). - Return(nil, errors.New("error")). - Once() - req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) - req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) - w := httptest.NewRecorder() - testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":{\"collectionName\":\"\",\"description\":\"\",\"enableDynamic\":false,\"fields\":[{\"autoId\":false,\"description\":\"RowID field\",\"name\":\"RowID\",\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"Timestamp field\",\"name\":\"Timestamp\",\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"field 100\",\"name\":\"float_vector_field\",\"primaryKey\":false,\"type\":\"FloatVector(2)\"},{\"autoId\":false,\"description\":\"field 106\",\"name\":\"int64_field\",\"primaryKey\":true,\"type\":\"Int64\"}],\"indexes\":[],\"load\":\"\",\"shardsNum\":0}}") - }) - - t.Run("get load state and describe index fail with status code", func(t *testing.T) { - mp.EXPECT(). - DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Schema: getCollectionSchema("collectionName"), - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - }, nil). - Once() - mp.EXPECT(). - GetLoadState(mock.Anything, mock.Anything). - Return(&milvuspb.GetLoadStateResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, - }, nil). - Once() - mp.EXPECT(). - DescribeIndex(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeIndexResponse{ - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, - }, nil). - Once() - req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) - req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) - w := httptest.NewRecorder() - testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":{\"collectionName\":\"\",\"description\":\"\",\"enableDynamic\":false,\"fields\":[{\"autoId\":false,\"description\":\"RowID field\",\"name\":\"RowID\",\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"Timestamp field\",\"name\":\"Timestamp\",\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"field 100\",\"name\":\"float_vector_field\",\"primaryKey\":false,\"type\":\"FloatVector(2)\"},{\"autoId\":false,\"description\":\"field 106\",\"name\":\"int64_field\",\"primaryKey\":true,\"type\":\"Int64\"}],\"indexes\":[],\"load\":\"\",\"shardsNum\":0}}") - }) - - t.Run("ok", func(t *testing.T) { - mp.EXPECT(). - DescribeCollection(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeCollectionResponse{ - Schema: getCollectionSchema("collectionName"), - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - }, nil). - Once() - mp.EXPECT(). - GetLoadState(mock.Anything, mock.Anything). - Return(&milvuspb.GetLoadStateResponse{ - State: commonpb.LoadState_LoadStateLoaded, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, - }, nil). - Once() - mp.EXPECT(). - DescribeIndex(mock.Anything, mock.Anything). - Return(&milvuspb.DescribeIndexResponse{ - IndexDescriptions: []*milvuspb.IndexDescription{ - { - IndexName: "in", - FieldName: "fn", - Params: []*commonpb.KeyValuePair{ - { - Key: "metric_type", - Value: "L2", - }, - }, - }, - }, - Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_UnexpectedError}, - }, nil). - Once() - req := httptest.NewRequest(http.MethodGet, "/vector/collections/describe?collectionName=book", nil) - req.SetBasicAuth(util.UserRoot, util.DefaultRootPassword) - w := httptest.NewRecorder() - testEngine.ServeHTTP(w, req) - assert.Equal(t, w.Code, http.StatusOK) - assert.Equal(t, w.Body.String(), "{\"code\":200,\"data\":{\"collectionName\":\"\",\"description\":\"\",\"enableDynamic\":false,\"fields\":[{\"autoId\":false,\"description\":\"RowID field\",\"name\":\"RowID\",\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"Timestamp field\",\"name\":\"Timestamp\",\"primaryKey\":false,\"type\":\"Int64\"},{\"autoId\":false,\"description\":\"field 100\",\"name\":\"float_vector_field\",\"primaryKey\":false,\"type\":\"FloatVector(2)\"},{\"autoId\":false,\"description\":\"field 106\",\"name\":\"int64_field\",\"primaryKey\":true,\"type\":\"Int64\"}],\"indexes\":[],\"load\":\"LoadStateLoaded\",\"shardsNum\":0}}") - }) -} diff --git a/internal/distributed/proxy/httpserver/utils.go b/internal/distributed/proxy/httpserver/utils.go index 0a502d3344cbd..5789fc5a79766 100644 --- a/internal/distributed/proxy/httpserver/utils.go +++ b/internal/distributed/proxy/httpserver/utils.go @@ -724,7 +724,7 @@ func genDynamicFields(fields []string, list []*schemapb.FieldData) []string { return dynamicFields } -func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemapb.FieldData, ids *schemapb.IDs, scores []float32) ([]map[string]interface{}, error) { +func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemapb.FieldData, ids *schemapb.IDs, scores []float32, enableInt64 bool) ([]map[string]interface{}, error) { var queryResp []map[string]interface{} columnNum := len(fieldDataList) @@ -791,7 +791,11 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap case schemapb.DataType_Int32: row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetIntData().Data[i] case schemapb.DataType_Int64: - row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetLongData().Data[i] + if enableInt64 { + row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetLongData().Data[i] + } else { + row[fieldDataList[j].FieldName] = strconv.FormatInt(fieldDataList[j].GetScalars().GetLongData().Data[i], 10) + } case schemapb.DataType_Float: row[fieldDataList[j].FieldName] = fieldDataList[j].GetScalars().GetFloatData().Data[i] case schemapb.DataType_Double: @@ -840,7 +844,11 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap switch ids.IdField.(type) { case *schemapb.IDs_IntId: int64Pks := ids.GetIntId().GetData() - row[DefaultPrimaryFieldName] = int64Pks[i] + if enableInt64 { + row[DefaultPrimaryFieldName] = int64Pks[i] + } else { + row[DefaultPrimaryFieldName] = strconv.FormatInt(int64Pks[i], 10) + } case *schemapb.IDs_StrId: stringPks := ids.GetStrId().GetData() row[DefaultPrimaryFieldName] = stringPks[i] @@ -856,3 +864,11 @@ func buildQueryResp(rowsNum int64, needFields []string, fieldDataList []*schemap return queryResp, nil } + +func formatInt64(intArray []int64) []string { + stringArray := make([]string, 0) + for _, i := range intArray { + stringArray = append(stringArray, strconv.FormatInt(i, 10)) + } + return stringArray +} diff --git a/internal/distributed/proxy/httpserver/utils_test.go b/internal/distributed/proxy/httpserver/utils_test.go index abf61fe170846..5f512bbeac3f1 100644 --- a/internal/distributed/proxy/httpserver/utils_test.go +++ b/internal/distributed/proxy/httpserver/utils_test.go @@ -31,7 +31,7 @@ func generatePrimaryField(datatype schemapb.DataType) schemapb.FieldSchema { } } -func generateIds(num int) *schemapb.IDs { +func generateIds(dataType schemapb.DataType, num int) *schemapb.IDs { var intArray []int64 if num == 0 { intArray = []int64{} @@ -40,13 +40,26 @@ func generateIds(num int) *schemapb.IDs { intArray = append(intArray, i) } } - return &schemapb.IDs{ - IdField: &schemapb.IDs_IntId{ - IntId: &schemapb.LongArray{ - Data: intArray, + switch dataType { + case schemapb.DataType_Int64: + return &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: intArray, + }, }, - }, + } + case schemapb.DataType_VarChar: + stringArray := formatInt64(intArray) + return &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: stringArray, + }, + }, + } } + return nil } func generateVectorFieldSchema(useBinary bool) schemapb.FieldSchema { @@ -82,8 +95,8 @@ func generateVectorFieldSchema(useBinary bool) schemapb.FieldSchema { } } -func generateCollectionSchema(useBinary bool) *schemapb.CollectionSchema { - primaryField := generatePrimaryField(schemapb.DataType_Int64) +func generateCollectionSchema(datatype schemapb.DataType, useBinary bool) *schemapb.CollectionSchema { + primaryField := generatePrimaryField(datatype) vectorField := generateVectorFieldSchema(useBinary) return &schemapb.CollectionSchema{ Name: DefaultCollectionName, @@ -196,7 +209,7 @@ func generateFieldData() []*schemapb.FieldData { return []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3} } -func generateSearchResult() []map[string]interface{} { +func generateSearchResult(dataType schemapb.DataType) []map[string]interface{} { row1 := map[string]interface{}{ DefaultPrimaryFieldName: int64(1), FieldBookID: int64(1), @@ -218,6 +231,11 @@ func generateSearchResult() []map[string]interface{} { FieldBookIntro: []float32{0.3, 0.33}, HTTPReturnDistance: float32(0.09), } + if dataType == schemapb.DataType_String { + row1[DefaultPrimaryFieldName] = "1" + row2[DefaultPrimaryFieldName] = "2" + row3[DefaultPrimaryFieldName] = "3" + } return []map[string]interface{}{row1, row2, row3} } @@ -246,9 +264,9 @@ func generateQueryResult64(withDistance bool) []map[string]interface{} { } func TestPrintCollectionDetails(t *testing.T) { - coll := generateCollectionSchema(false) + coll := generateCollectionSchema(schemapb.DataType_Int64, false) indexes := generateIndexes() - assert.Equal(t, printFields(coll.Fields), []gin.H{ + assert.Equal(t, []gin.H{ { HTTPReturnFieldName: FieldBookID, HTTPReturnFieldType: "Int64", @@ -270,23 +288,23 @@ func TestPrintCollectionDetails(t *testing.T) { HTTPReturnFieldAutoID: false, HTTPReturnDescription: "", }, - }) - assert.Equal(t, printIndexes(indexes), []gin.H{ + }, printFields(coll.Fields)) + assert.Equal(t, []gin.H{ { HTTPReturnIndexName: DefaultIndexName, HTTPReturnIndexField: FieldBookIntro, HTTPReturnIndexMetricsType: DefaultMetricType, }, - }) - assert.Equal(t, getMetricType(indexes[0].Params), DefaultMetricType) - assert.Equal(t, getMetricType(nil), DefaultMetricType) + }, printIndexes(indexes)) + assert.Equal(t, DefaultMetricType, getMetricType(indexes[0].Params)) + assert.Equal(t, DefaultMetricType, getMetricType(nil)) fields := []*schemapb.FieldSchema{} for _, field := range newCollectionSchema(coll).Fields { if field.DataType == schemapb.DataType_VarChar { fields = append(fields, field) } } - assert.Equal(t, printFields(fields), []gin.H{ + assert.Equal(t, []gin.H{ { HTTPReturnFieldName: "field-varchar", HTTPReturnFieldType: "VarChar(10)", @@ -294,65 +312,65 @@ func TestPrintCollectionDetails(t *testing.T) { HTTPReturnFieldAutoID: false, HTTPReturnDescription: "", }, - }) + }, printFields(fields)) } func TestPrimaryField(t *testing.T) { - coll := generateCollectionSchema(false) + coll := generateCollectionSchema(schemapb.DataType_Int64, false) primaryField := generatePrimaryField(schemapb.DataType_Int64) field, ok := getPrimaryField(coll) - assert.Equal(t, ok, true) - assert.Equal(t, *field, primaryField) + assert.Equal(t, true, ok) + assert.Equal(t, primaryField, *field) - assert.Equal(t, joinArray([]int64{1, 2, 3}), "1,2,3") - assert.Equal(t, joinArray([]string{"1", "2", "3"}), "1,2,3") + assert.Equal(t, "1,2,3", joinArray([]int64{1, 2, 3})) + assert.Equal(t, "1,2,3", joinArray([]string{"1", "2", "3"})) jsonStr := "{\"id\": [1, 2, 3]}" idStr := gjson.Get(jsonStr, "id") rangeStr, err := convertRange(&primaryField, idStr) - assert.Equal(t, err, nil) - assert.Equal(t, rangeStr, "1,2,3") + assert.Equal(t, nil, err) + assert.Equal(t, "1,2,3", rangeStr) filter, err := checkGetPrimaryKey(coll, idStr) - assert.Equal(t, err, nil) - assert.Equal(t, filter, "book_id in [1,2,3]") + assert.Equal(t, nil, err) + assert.Equal(t, "book_id in [1,2,3]", filter) primaryField = generatePrimaryField(schemapb.DataType_VarChar) jsonStr = "{\"id\": [\"1\", \"2\", \"3\"]}" idStr = gjson.Get(jsonStr, "id") rangeStr, err = convertRange(&primaryField, idStr) - assert.Equal(t, err, nil) - assert.Equal(t, rangeStr, "1,2,3") + assert.Equal(t, nil, err) + assert.Equal(t, "1,2,3", rangeStr) filter, err = checkGetPrimaryKey(coll, idStr) - assert.Equal(t, err, nil) - assert.Equal(t, filter, "book_id in [1,2,3]") + assert.Equal(t, nil, err) + assert.Equal(t, "book_id in [1,2,3]", filter) } func TestInsertWithDynamicFields(t *testing.T) { body := "{\"data\": {\"id\": 0, \"book_id\": 1, \"book_intro\": [0.1, 0.2], \"word_count\": 2, \"classified\": false, \"databaseID\": null}}" req := InsertReq{} - coll := generateCollectionSchema(false) + coll := generateCollectionSchema(schemapb.DataType_Int64, false) var err error err, req.Data = checkAndSetData(body, &milvuspb.DescribeCollectionResponse{ Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success}, Schema: coll, }) - assert.Equal(t, err, nil) - assert.Equal(t, req.Data[0]["id"], int64(0)) - assert.Equal(t, req.Data[0]["book_id"], int64(1)) - assert.Equal(t, req.Data[0]["word_count"], int64(2)) + assert.Equal(t, nil, err) + assert.Equal(t, int64(0), req.Data[0]["id"]) + assert.Equal(t, int64(1), req.Data[0]["book_id"]) + assert.Equal(t, int64(2), req.Data[0]["word_count"]) fieldsData, err := anyToColumns(req.Data, coll) - assert.Equal(t, err, nil) - assert.Equal(t, fieldsData[len(fieldsData)-1].IsDynamic, true) - assert.Equal(t, fieldsData[len(fieldsData)-1].Type, schemapb.DataType_JSON) - assert.Equal(t, string(fieldsData[len(fieldsData)-1].GetScalars().GetJsonData().GetData()[0]), "{\"classified\":false,\"id\":0}") + assert.Equal(t, nil, err) + assert.Equal(t, true, fieldsData[len(fieldsData)-1].IsDynamic) + assert.Equal(t, schemapb.DataType_JSON, fieldsData[len(fieldsData)-1].Type) + assert.Equal(t, "{\"classified\":false,\"id\":0}", string(fieldsData[len(fieldsData)-1].GetScalars().GetJsonData().GetData()[0])) } func TestSerialize(t *testing.T) { parameters := []float32{0.11111, 0.22222} - // assert.Equal(t, string(serialize(parameters)), "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e") - // assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "vector2PlaceholderGroupBytes") // todo - assert.Equal(t, string(serialize(parameters)), "\xa4\x8d\xe3=\xa4\x8dc>") - assert.Equal(t, string(vector2PlaceholderGroupBytes(parameters)), "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>") // todo + // assert.Equal(t, "\ufffd\ufffd\ufffd=\ufffd\ufffdc\u003e", string(serialize(parameters))) + // assert.Equal(t, "vector2PlaceholderGroupBytes", string(vector2PlaceholderGroupBytes(parameters))) // todo + assert.Equal(t, "\xa4\x8d\xe3=\xa4\x8dc>", string(serialize(parameters))) + assert.Equal(t, "\n\x10\n\x02$0\x10e\x1a\b\xa4\x8d\xe3=\xa4\x8dc>", string(vector2PlaceholderGroupBytes(parameters))) // todo } func compareRow64(m1 map[string]interface{}, m2 map[string]interface{}) bool { @@ -438,10 +456,10 @@ func compareRows(row1 []map[string]interface{}, row2 []map[string]interface{}, c func TestBuildQueryResp(t *testing.T) { outputFields := []string{FieldBookID, FieldWordCount, "author", "date"} - rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(3), []float32{0.01, 0.04, 0.09}) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3} - assert.Equal(t, err, nil) - exceptRows := generateSearchResult() - assert.Equal(t, compareRows(rows, exceptRows, compareRow), true) + rows, err := buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04, 0.09}, true) // []*schemapb.FieldData{&fieldData1, &fieldData2, &fieldData3} + assert.Equal(t, nil, err) + exceptRows := generateSearchResult(schemapb.DataType_Int64) + assert.Equal(t, true, compareRows(rows, exceptRows, compareRow)) } func newCollectionSchema(coll *schemapb.CollectionSchema) *schemapb.CollectionSchema { @@ -776,19 +794,19 @@ func newSearchResult(results []map[string]interface{}) []map[string]interface{} } func TestAnyToColumn(t *testing.T) { - data, err := anyToColumns(newSearchResult(generateSearchResult()), newCollectionSchema(generateCollectionSchema(false))) - assert.Equal(t, err, nil) - assert.Equal(t, len(data), 13) + data, err := anyToColumns(newSearchResult(generateSearchResult(schemapb.DataType_Int64)), newCollectionSchema(generateCollectionSchema(schemapb.DataType_Int64, false))) + assert.Equal(t, nil, err) + assert.Equal(t, 13, len(data)) } func TestBuildQueryResps(t *testing.T) { outputFields := []string{"XXX", "YYY"} outputFieldsList := [][]string{outputFields, {"$meta"}, {"$meta", FieldBookID, FieldBookIntro, "YYY"}} for _, theOutputFields := range outputFieldsList { - rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(3), []float32{0.01, 0.04, 0.09}) - assert.Equal(t, err, nil) - exceptRows := newSearchResult(generateSearchResult()) - assert.Equal(t, compareRows(rows, exceptRows, compareRow), true) + rows, err := buildQueryResp(int64(0), theOutputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04, 0.09}, true) + assert.Equal(t, nil, err) + exceptRows := newSearchResult(generateSearchResult(schemapb.DataType_Int64)) + assert.Equal(t, true, compareRows(rows, exceptRows, compareRow)) } dataTypes := []schemapb.DataType{ @@ -799,18 +817,29 @@ func TestBuildQueryResps(t *testing.T) { schemapb.DataType_JSON, schemapb.DataType_Array, } for _, dateType := range dataTypes { - _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIds(3), []float32{0.01, 0.04, 0.09}) - assert.Equal(t, err, nil) + _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, dateType), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04, 0.09}, true) + assert.Equal(t, nil, err) } - _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIds(3), []float32{0.01, 0.04, 0.09}) - assert.Equal(t, err.Error(), "the type(1000) of field(wrong-field-type) is not supported, use other sdk please") + _, err := buildQueryResp(int64(0), outputFields, newFieldData([]*schemapb.FieldData{}, 1000), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04, 0.09}, true) + assert.Equal(t, "the type(1000) of field(wrong-field-type) is not supported, use other sdk please", err.Error()) + + res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04, 0.09}, true) + assert.Equal(t, 3, len(res)) + assert.Equal(t, nil, err) + + res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04, 0.09}, false) + assert.Equal(t, 3, len(res)) + assert.Equal(t, nil, err) + + res, err = buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(schemapb.DataType_VarChar, 3), []float32{0.01, 0.04, 0.09}, true) + assert.Equal(t, 3, len(res)) + assert.Equal(t, nil, err) - res, err := buildQueryResp(int64(0), outputFields, []*schemapb.FieldData{}, generateIds(3), []float32{0.01, 0.04, 0.09}) - assert.Equal(t, len(res), 3) - assert.Equal(t, err, nil) + _, err = buildQueryResp(int64(0), outputFields, generateFieldData(), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04, 0.09}, false) + assert.Equal(t, nil, err) // len(rows) != len(scores), didn't show distance - _, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(3), []float32{0.01, 0.04}) - assert.Equal(t, err, nil) + _, err = buildQueryResp(int64(0), outputFields, newFieldData(generateFieldData(), schemapb.DataType_None), generateIds(schemapb.DataType_Int64, 3), []float32{0.01, 0.04}, true) + assert.Equal(t, nil, err) } diff --git a/internal/distributed/proxy/service.go b/internal/distributed/proxy/service.go index 594c3010e7692..94a0aab37390d 100644 --- a/internal/distributed/proxy/service.go +++ b/internal/distributed/proxy/service.go @@ -169,6 +169,15 @@ func (s *Server) startHTTPServer(errChan chan error) { defer s.wg.Done() ginHandler := gin.Default() ginHandler.Use(func(c *gin.Context) { + _, err := strconv.ParseBool(c.Request.Header.Get(httpserver.HTTPHeaderAllowInt64)) + if err != nil { + httpParams := ¶mtable.Get().HTTPCfg + if httpParams.AcceptTypeAllowInt64.GetAsBool() { + c.Request.Header.Set(httpserver.HTTPHeaderAllowInt64, "true") + } else { + c.Request.Header.Set(httpserver.HTTPHeaderAllowInt64, "false") + } + } c.Writer.Header().Set("Access-Control-Allow-Origin", "*") c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") c.Writer.Header().Set("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With") diff --git a/pkg/util/paramtable/http_param.go b/pkg/util/paramtable/http_param.go index ea04befbab209..829a3bfca63b0 100644 --- a/pkg/util/paramtable/http_param.go +++ b/pkg/util/paramtable/http_param.go @@ -1,9 +1,10 @@ package paramtable type httpConfig struct { - Enabled ParamItem `refreshable:"false"` - DebugMode ParamItem `refreshable:"false"` - Port ParamItem `refreshable:"false"` + Enabled ParamItem `refreshable:"false"` + DebugMode ParamItem `refreshable:"false"` + Port ParamItem `refreshable:"false"` + AcceptTypeAllowInt64 ParamItem `refreshable:"false"` } func (p *httpConfig) init(base *BaseTable) { @@ -33,4 +34,14 @@ func (p *httpConfig) init(base *BaseTable) { Export: true, } p.Port.Init(base.mgr) + + p.AcceptTypeAllowInt64 = ParamItem{ + Key: "proxy.http.acceptTypeAllowInt64", + DefaultValue: "false", + Version: "2.3.2", + Doc: "high-level restful api, whether http client can deal with int64", + PanicIfEmpty: false, + Export: true, + } + p.AcceptTypeAllowInt64.Init(base.mgr) } diff --git a/pkg/util/paramtable/http_param_test.go b/pkg/util/paramtable/http_param_test.go index 8e32450f34c7e..2e26ff339286c 100644 --- a/pkg/util/paramtable/http_param_test.go +++ b/pkg/util/paramtable/http_param_test.go @@ -13,4 +13,5 @@ func TestHTTPConfig_Init(t *testing.T) { assert.Equal(t, cfg.Enabled.GetAsBool(), true) assert.Equal(t, cfg.DebugMode.GetAsBool(), false) assert.Equal(t, cfg.Port.GetValue(), "") + assert.Equal(t, cfg.AcceptTypeAllowInt64.GetValue(), "false") }