Skip to content

Commit

Permalink
[feature] validate uploaded images with the result on the Blockchain
Browse files Browse the repository at this point in the history
  • Loading branch information
lukewwww committed Aug 24, 2023
1 parent 8027a74 commit b6f9f15
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 101 deletions.
20 changes: 10 additions & 10 deletions api/v1/inference_tasks/create_task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ import (

func TestCreateTaskBeforeBlockchainConfirmation(t *testing.T) {

task := v1.PrepareRandomTask()
task := tests.PrepareRandomTask()

_, privateKeys, err := v1.PrepareAccounts()
_, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare account error")

timestamp, signature, err := v1.SignData(task, privateKeys[0])
Expand All @@ -31,10 +31,10 @@ func TestCreateTaskBeforeBlockchainConfirmation(t *testing.T) {

func TestCreateTaskAfterBlockchainConfirmation(t *testing.T) {

addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare account error")

taskInput, task, err := v1.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
taskInput, task, err := tests.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

timestamp, signature, err := v1.SignData(taskInput, privateKeys[0])
Expand All @@ -52,10 +52,10 @@ func TestCreateTaskAfterBlockchainConfirmation(t *testing.T) {
}

func TestCreateTaskUsingUnauthorizedAccount(t *testing.T) {
addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare account error")

taskInput, _, err := v1.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
taskInput, _, err := tests.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

timestamp, signature, err := v1.SignData(taskInput, privateKeys[1])
Expand All @@ -68,10 +68,10 @@ func TestCreateTaskUsingUnauthorizedAccount(t *testing.T) {

func TestCreateDuplicateTask(t *testing.T) {

addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare account error")

taskInput, task, err := v1.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
taskInput, task, err := tests.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

timestamp, signature, err := v1.SignData(taskInput, privateKeys[0])
Expand All @@ -94,10 +94,10 @@ func TestCreateDuplicateTask(t *testing.T) {
}

func TestCreateTaskWithMismatchedParamHash(t *testing.T) {
addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare account error")

taskInput, _, err := v1.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
taskInput, _, err := tests.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

oldPrompt := taskInput.Prompt
Expand Down
7 changes: 3 additions & 4 deletions api/v1/inference_tasks/get_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ import (
)

type GetResultInput struct {
ImageNum string `path:"image_num" json:"image_num" description:"Image number" validate:"required"`
SelectedNode string `path:"selected_node" json:"selected_node" description:"Selected nodes" validate:"required"`
TaskId uint64 `path:"task_id" json:"task_id" description:"Task id" validate:"required"`
ImageNum string `path:"image_num" json:"image_num" description:"Image number" validate:"required"`
TaskId uint64 `path:"task_id" json:"task_id" description:"Task id" validate:"required"`
}

type GetResultInputWithSignature struct {
Expand Down Expand Up @@ -55,7 +54,7 @@ func GetResult(ctx *gin.Context, in *GetResultInputWithSignature) error {
imageFile := filepath.Join(
appConfig.DataDir.InferenceTasks,
task.GetTaskIdAsString(),
in.SelectedNode,
"results",
in.ImageNum+".png",
)

Expand Down
64 changes: 10 additions & 54 deletions api/v1/inference_tasks/get_result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"h_relay/config"
"h_relay/tests"
v1 "h_relay/tests/api/v1"
"image/png"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -17,27 +16,22 @@ import (
)

func TestUnauthorizedGetImage(t *testing.T) {
addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")

_, task, err := v1.PrepareParamsUploadedTask(addresses, config.GetDB())
_, task, err := tests.PrepareResultUploadedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

err = prepareImagesForNode(task.GetTaskIdAsString(), addresses[1])
assert.Equal(t, nil, err, "create image error")

getResultInput := &inference_tasks.GetResultInput{
TaskId: task.TaskId,
SelectedNode: addresses[1],
ImageNum: "0",
TaskId: task.TaskId,
ImageNum: "0",
}

timestamp, signature, err := v1.SignData(getResultInput, privateKeys[1])
assert.Equal(t, nil, err, "sign data error")

r := callGetImageApi(
task.GetTaskIdAsString(),
addresses[1],
"0",
timestamp,
signature)
Expand All @@ -54,27 +48,22 @@ func TestUnauthorizedGetImage(t *testing.T) {

func TestGetImage(t *testing.T) {

addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "prepare accounts error")

_, task, err := v1.PrepareParamsUploadedTask(addresses, config.GetDB())
_, task, err := tests.PrepareResultUploadedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "prepare task error")

err = prepareImagesForNode(task.GetTaskIdAsString(), addresses[1])
assert.Equal(t, nil, err, "create image error")

getResultInput := &inference_tasks.GetResultInput{
TaskId: task.TaskId,
SelectedNode: addresses[1],
ImageNum: "2",
TaskId: task.TaskId,
ImageNum: "2",
}

timestamp, signature, err := v1.SignData(getResultInput, privateKeys[0])
assert.Equal(t, nil, err, "sign data error")

r := callGetImageApi(
task.GetTaskIdAsString(),
addresses[1],
"2",
timestamp,
signature)
Expand All @@ -85,7 +74,7 @@ func TestGetImage(t *testing.T) {
imageFolder := filepath.Join(
appConfig.DataDir.InferenceTasks,
task.GetTaskIdAsString(),
addresses[1],
"results",
)

out, err := os.Create(filepath.Join(imageFolder, "downloaded.png"))
Expand Down Expand Up @@ -115,12 +104,11 @@ func TestGetImage(t *testing.T) {

func callGetImageApi(
taskIdStr string,
nodeAddress string,
imageNum string,
timestamp int64,
signature string) *httptest.ResponseRecorder {

endpoint := "/v1/inference_tasks/" + taskIdStr + "/results/" + nodeAddress + "/" + imageNum
endpoint := "/v1/inference_tasks/" + taskIdStr + "/results/" + imageNum
query := "?timestamp=" + strconv.FormatInt(timestamp, 10) + "&signature=" + signature

req, _ := http.NewRequest("GET", endpoint+query, nil)
Expand All @@ -129,35 +117,3 @@ func callGetImageApi(

return w
}

func prepareImagesForNode(taskIdStr, nodeAddress string) error {
appConfig := config.GetConfig()

imageFolder := filepath.Join(
appConfig.DataDir.InferenceTasks,
taskIdStr,
nodeAddress,
)

if err := os.MkdirAll(imageFolder, os.ModeDir); err != nil {
return err
}

for i := 0; i < 5; i++ {
filename := filepath.Join(imageFolder, strconv.Itoa(i)+".png")
img := tests.CreateImage()
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE, 0777)
if err != nil {
return err
}

if err := png.Encode(f, img); err != nil {
return err
}

if err := f.Close(); err != nil {
return err
}
}
return nil
}
12 changes: 6 additions & 6 deletions api/v1/inference_tasks/get_task_by_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ import (
)

func TestGetBlockchainConfirmedTask(t *testing.T) {
addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "error preparing accounts")

taskInput, _, err := v1.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
taskInput, _, err := tests.PrepareBlockchainConfirmedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "error preparing task")

getResultInput := inference_tasks.GetTaskInput{TaskId: taskInput.TaskId}
Expand All @@ -31,10 +31,10 @@ func TestGetBlockchainConfirmedTask(t *testing.T) {
}

func TestGetParamsUploadedTask(t *testing.T) {
addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "error preparing accounts")

taskInput, task, err := v1.PrepareParamsUploadedTask(addresses, config.GetDB())
taskInput, task, err := tests.PrepareParamsUploadedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "error preparing task")

getResultInput := inference_tasks.GetTaskInput{TaskId: taskInput.TaskId}
Expand All @@ -48,10 +48,10 @@ func TestGetParamsUploadedTask(t *testing.T) {
}

func TestGetUnauthorizedTask(t *testing.T) {
addresses, privateKeys, err := v1.PrepareAccounts()
addresses, privateKeys, err := tests.PrepareAccounts()
assert.Equal(t, nil, err, "error preparing accounts")

taskInput, _, err := v1.PrepareParamsUploadedTask(addresses, config.GetDB())
taskInput, _, err := tests.PrepareParamsUploadedTask(addresses, config.GetDB())
assert.Equal(t, nil, err, "error preparing task")

getResultInput := inference_tasks.GetTaskInput{TaskId: taskInput.TaskId}
Expand Down
60 changes: 51 additions & 9 deletions api/v1/inference_tasks/upload_result.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ package inference_tasks

import (
"errors"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"gorm.io/gorm"
"h_relay/api/v1/response"
"h_relay/blockchain"
"h_relay/config"
"h_relay/models"
"os"
Expand Down Expand Up @@ -37,7 +40,7 @@ func UploadResult(ctx *gin.Context, in *ResultInputWithSignature) (*response.Res

var task models.InferenceTask

if result := config.GetDB().Where(&models.InferenceTask{TaskId: in.TaskId}).Preload("SelectedNodes").First(&task); result.Error != nil {
if result := config.GetDB().Where(&models.InferenceTask{TaskId: in.TaskId}).First(&task); result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
validationErr := response.NewValidationErrorResponse("task_id", "Task not found")
return nil, validationErr
Expand All @@ -46,37 +49,76 @@ func UploadResult(ctx *gin.Context, in *ResultInputWithSignature) (*response.Res
}
}

var selectedNodeAddress string
resultNode := &models.SelectedNode{
InferenceTaskID: task.ID,
IsResultSelected: true,
}

if err := config.GetDB().Where(resultNode).First(resultNode).Error; err != nil {

for _, selectedNode := range task.SelectedNodes {
if selectedNode.NodeAddress == address {
selectedNodeAddress = address
break
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, response.NewValidationErrorResponse("task_id", "Task not ready")
} else {
return nil, response.NewExceptionResponse(err)
}
}

if selectedNodeAddress == "" {
if resultNode.NodeAddress != address {
validationErr := response.NewValidationErrorResponse("signature", "Signer not allowed")
return nil, validationErr
}

form, _ := ctx.MultipartForm()
files := form.File["images"]

// Check whether the images are correct
var pHashBytes []byte

for _, file := range files {

imageFile, err := file.Open()

if err != nil {
return nil, response.NewExceptionResponse(err)
}

pHash, err := blockchain.GetPHashForImage(imageFile)

if err != nil {
return nil, response.NewExceptionResponse(err)
}

pHashBytes = append(pHashBytes, pHash...)

err = imageFile.Close()
if err != nil {
return nil, response.NewExceptionResponse(err)
}
}

uploadedResult := hexutil.Encode(pHashBytes)

log.Debugln("image compare: result from the blockchain: " + resultNode.Result)
log.Debugln("image compare: result from the uploaded file: " + uploadedResult)

if resultNode.Result != uploadedResult {
validationErr := response.NewValidationErrorResponse("images", "Wrong images uploaded")
return nil, validationErr
}

appConfig := config.GetConfig()

taskWorkspace := appConfig.DataDir.InferenceTasks
taskIdStr := task.GetTaskIdAsString()

taskDir := filepath.Join(taskWorkspace, taskIdStr, selectedNodeAddress)
taskDir := filepath.Join(taskWorkspace, taskIdStr, "results")
if err = os.MkdirAll(taskDir, os.ModeDir); err != nil {
return nil, response.NewExceptionResponse(err)
}

fileNum := 0

for _, file := range files {

filename := filepath.Join(taskDir, strconv.Itoa(fileNum)+".png")
if err := ctx.SaveUploadedFile(file, filename); err != nil {
return nil, response.NewExceptionResponse(err)
Expand Down
Loading

0 comments on commit b6f9f15

Please sign in to comment.