Skip to content

Commit

Permalink
Update Handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
csenet committed Nov 22, 2023
1 parent 5383854 commit 20dff9e
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 31 deletions.
120 changes: 97 additions & 23 deletions backend/state-manager/pkg/connect/connect_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,29 @@ func (s *StateManagerServer) GetPointStates(
ctx context.Context,
req *connect.Request[statev1.GetPointStatesRequest],
) (*connect.Response[statev1.GetPointStatesResponse], error) {
err := connect.NewError(
connect.CodeUnknown,
errors.New("not implemented"),
)
return nil, err
blockStates, err := s.DBHandler.GetPoints()
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
errors.New("db error"),
)
return nil, err
}

var response []*statev1.PointAndState

for _, pointState := range blockStates {
response = append(response, &statev1.PointAndState{
Id: pointState.Id,
State: pointState.State,
})
}

res := connect.NewResponse(&statev1.GetPointStatesResponse{
States: response,
})

return res, nil
}

/*
Expand Down Expand Up @@ -126,11 +144,30 @@ func (s *StateManagerServer) GetStopStates(
ctx context.Context,
req *connect.Request[statev1.GetStopStatesRequest],
) (*connect.Response[statev1.GetStopStatesResponse], error) {
err := connect.NewError(
connect.CodeUnknown,
errors.New("not implemented"),
)
return nil, err
stopStates, err := s.DBHandler.GetStops()
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
errors.New("db error"),
)
slog.Default().Error("db connection error", err)
return nil, err
}

var response []*statev1.StopAndState

for _, stopState := range stopStates {
response = append(response, &statev1.StopAndState{
Id: stopState.Id,
State: stopState.State,
})
}

res := connect.NewResponse(&statev1.GetStopStatesResponse{
States: response,
})

return res, nil
}

/*
Expand All @@ -141,20 +178,57 @@ func (s *StateManagerServer) GetTrains(
ctx context.Context,
req *connect.Request[statev1.GetTrainsRequest],
) (*connect.Response[statev1.GetTrainsResponse], error) {
err := connect.NewError(
connect.CodeUnknown,
errors.New("not implemented"),
)
return nil, err
trains, err := s.DBHandler.GetTrains()
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
errors.New("db error"),
)
slog.Default().Error("db connection error", err)
}
var response []*statev1.Train

for _, train := range trains {
response = append(response, &statev1.Train{
TrainId: train.TrainId,
PositionId: train.PositionId,
Priority: train.Priority,
})
}

res := connect.NewResponse(&statev1.GetTrainsResponse{
Trains: response,
})

return res, err
}

func (s *StateManagerServer) UpdateTrainUUID(
func (s *StateManagerServer) AddTrain(
ctx context.Context,
req *connect.Request[statev1.UpdateTrainUUIDRequest],
) (*connect.Response[statev1.UpdateTrainUUIDResponse], error) {
err := connect.NewError(
connect.CodeUnknown,
errors.New("not implemented"),
)
return nil, err
req *connect.Request[statev1.AddTrainRequest],
) (*connect.Response[statev1.AddTrainResponse], error) {
err := s.DBHandler.AddTrain(req.Msg.Train)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
errors.New("db error"),
)
slog.Default().Error("db connection error", err)
}
return connect.NewResponse(&statev1.AddTrainResponse{}), err
}

func (s *StateManagerServer) UpdateTrain(
ctx context.Context,
req *connect.Request[statev1.UpdateTrainRequest],
) (*connect.Response[statev1.UpdateTrainResponse], error) {
err := s.DBHandler.UpdateTrain(req.Msg.Train)
if err != nil {
err = connect.NewError(
connect.CodeUnknown,
errors.New("db error"),
)
slog.Default().Error("db connection error", err)
}
return connect.NewResponse(&statev1.UpdateTrainResponse{}), err
}
65 changes: 57 additions & 8 deletions backend/state-manager/pkg/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,18 @@ func (db *DBHandler) GetPoint(pointId string) (*statev1.PointAndState, error) {
return result, nil
}

func (db *DBHandler) GetPoints() []*statev1.PointAndState {
func (db *DBHandler) GetPoints() ([]*statev1.PointAndState, error) {
collection := db.stateManagerDB.Collection("points")
cursor, err := collection.Find(context.Background(), bson.M{})
if err != nil {
slog.Default().Warn("Get Points failed", slog.Any("err", err))
panic(err)
return nil, err
}
var result []*statev1.PointAndState
if err = cursor.All(context.Background(), &result); err != nil {
panic(err)
return nil, err
}
return result
return result, nil
}

/*
Expand Down Expand Up @@ -138,17 +138,17 @@ func (db *DBHandler) GetStop(stopId string) (*statev1.StopAndState, error) {
return result, nil
}

func (db *DBHandler) GetStops() []*statev1.StopAndState {
func (db *DBHandler) GetStops() ([]*statev1.StopAndState, error) {
collection := db.stateManagerDB.Collection("stops")
cursor, err := collection.Find(context.Background(), bson.M{})
if err != nil {
panic(err)
return nil, err
}
var result []*statev1.StopAndState
if err = cursor.All(context.Background(), &result); err != nil {
panic(err)
return nil, err
}
return result
return result, nil
}

/*
Expand Down Expand Up @@ -199,3 +199,52 @@ func (db *DBHandler) GetBlocks() ([]*statev1.BlockState, error) {
}
return result, nil
}

/*
Train
*/

func (db *DBHandler) AddTrain(train *statev1.Train) error {
collection := db.stateManagerDB.Collection("trains")
_, err := collection.InsertOne(context.Background(), train)
if err != nil {
return err
}
return nil
}

func (db *DBHandler) UpdateTrain(train *statev1.Train) error {
collection := db.stateManagerDB.Collection("trains")
_, err := collection.UpdateOne(
context.Background(),
bson.M{"trainid": train.TrainId},
bson.M{"$set": bson.M{"state": train}},
)
if err != nil {
return err
}
return nil
}

func (db *DBHandler) GetTrain(trainId string) (*statev1.Train, error) {
collection := db.stateManagerDB.Collection("trains")
var result *statev1.Train
err := collection.FindOne(context.Background(), bson.M{"trainid": trainId}).Decode(&result)
if err != nil {
return nil, err
}
return result, nil
}

func (db *DBHandler) GetTrains() ([]*statev1.Train, error) {
collection := db.stateManagerDB.Collection("trains")
cursor, err := collection.Find(context.Background(), bson.M{})
if err != nil {
return nil, err
}
var result []*statev1.Train
if err = cursor.All(context.Background(), &result); err != nil {
return nil, err
}
return result, nil
}

0 comments on commit 20dff9e

Please sign in to comment.