diff --git a/backend/state-manager/pkg/connect/connect_handler.go b/backend/state-manager/pkg/connect/connect_handler.go index 1100e736..ff705297 100644 --- a/backend/state-manager/pkg/connect/connect_handler.go +++ b/backend/state-manager/pkg/connect/connect_handler.go @@ -94,11 +94,29 @@ func (s *StateManagerServer) GetPointStates( ctx context.Context, req *connect.Request[statev1.GetPointStatesRequest], ) (*connect.Response[statev1.GetPointStatesResponse], error) { - err := connect.NewError( - connect.CodeUnknown, - errors.New("not implemented"), - ) - return nil, err + blockStates, err := s.DBHandler.GetPoints() + if err != nil { + err = connect.NewError( + connect.CodeUnknown, + errors.New("db error"), + ) + return nil, err + } + + var response []*statev1.PointAndState + + for _, pointState := range blockStates { + response = append(response, &statev1.PointAndState{ + Id: pointState.Id, + State: pointState.State, + }) + } + + res := connect.NewResponse(&statev1.GetPointStatesResponse{ + States: response, + }) + + return res, nil } /* @@ -126,11 +144,30 @@ func (s *StateManagerServer) GetStopStates( ctx context.Context, req *connect.Request[statev1.GetStopStatesRequest], ) (*connect.Response[statev1.GetStopStatesResponse], error) { - err := connect.NewError( - connect.CodeUnknown, - errors.New("not implemented"), - ) - return nil, err + stopStates, err := s.DBHandler.GetStops() + if err != nil { + err = connect.NewError( + connect.CodeUnknown, + errors.New("db error"), + ) + slog.Default().Error("db connection error", err) + return nil, err + } + + var response []*statev1.StopAndState + + for _, stopState := range stopStates { + response = append(response, &statev1.StopAndState{ + Id: stopState.Id, + State: stopState.State, + }) + } + + res := connect.NewResponse(&statev1.GetStopStatesResponse{ + States: response, + }) + + return res, nil } /* @@ -141,20 +178,57 @@ func (s *StateManagerServer) GetTrains( ctx context.Context, req *connect.Request[statev1.GetTrainsRequest], ) (*connect.Response[statev1.GetTrainsResponse], error) { - err := connect.NewError( - connect.CodeUnknown, - errors.New("not implemented"), - ) - return nil, err + trains, err := s.DBHandler.GetTrains() + if err != nil { + err = connect.NewError( + connect.CodeUnknown, + errors.New("db error"), + ) + slog.Default().Error("db connection error", err) + } + var response []*statev1.Train + + for _, train := range trains { + response = append(response, &statev1.Train{ + TrainId: train.TrainId, + PositionId: train.PositionId, + Priority: train.Priority, + }) + } + + res := connect.NewResponse(&statev1.GetTrainsResponse{ + Trains: response, + }) + + return res, err } -func (s *StateManagerServer) UpdateTrainUUID( +func (s *StateManagerServer) AddTrain( ctx context.Context, - req *connect.Request[statev1.UpdateTrainUUIDRequest], -) (*connect.Response[statev1.UpdateTrainUUIDResponse], error) { - err := connect.NewError( - connect.CodeUnknown, - errors.New("not implemented"), - ) - return nil, err + req *connect.Request[statev1.AddTrainRequest], +) (*connect.Response[statev1.AddTrainResponse], error) { + err := s.DBHandler.AddTrain(req.Msg.Train) + if err != nil { + err = connect.NewError( + connect.CodeUnknown, + errors.New("db error"), + ) + slog.Default().Error("db connection error", err) + } + return connect.NewResponse(&statev1.AddTrainResponse{}), err +} + +func (s *StateManagerServer) UpdateTrain( + ctx context.Context, + req *connect.Request[statev1.UpdateTrainRequest], +) (*connect.Response[statev1.UpdateTrainResponse], error) { + err := s.DBHandler.UpdateTrain(req.Msg.Train) + if err != nil { + err = connect.NewError( + connect.CodeUnknown, + errors.New("db error"), + ) + slog.Default().Error("db connection error", err) + } + return connect.NewResponse(&statev1.UpdateTrainResponse{}), err } diff --git a/backend/state-manager/pkg/db/db.go b/backend/state-manager/pkg/db/db.go index 80903aab..16e2db5c 100644 --- a/backend/state-manager/pkg/db/db.go +++ b/backend/state-manager/pkg/db/db.go @@ -86,18 +86,18 @@ func (db *DBHandler) GetPoint(pointId string) (*statev1.PointAndState, error) { return result, nil } -func (db *DBHandler) GetPoints() []*statev1.PointAndState { +func (db *DBHandler) GetPoints() ([]*statev1.PointAndState, error) { collection := db.stateManagerDB.Collection("points") cursor, err := collection.Find(context.Background(), bson.M{}) if err != nil { slog.Default().Warn("Get Points failed", slog.Any("err", err)) - panic(err) + return nil, err } var result []*statev1.PointAndState if err = cursor.All(context.Background(), &result); err != nil { - panic(err) + return nil, err } - return result + return result, nil } /* @@ -138,17 +138,17 @@ func (db *DBHandler) GetStop(stopId string) (*statev1.StopAndState, error) { return result, nil } -func (db *DBHandler) GetStops() []*statev1.StopAndState { +func (db *DBHandler) GetStops() ([]*statev1.StopAndState, error) { collection := db.stateManagerDB.Collection("stops") cursor, err := collection.Find(context.Background(), bson.M{}) if err != nil { - panic(err) + return nil, err } var result []*statev1.StopAndState if err = cursor.All(context.Background(), &result); err != nil { - panic(err) + return nil, err } - return result + return result, nil } /* @@ -199,3 +199,52 @@ func (db *DBHandler) GetBlocks() ([]*statev1.BlockState, error) { } return result, nil } + +/* +Train +*/ + +func (db *DBHandler) AddTrain(train *statev1.Train) error { + collection := db.stateManagerDB.Collection("trains") + _, err := collection.InsertOne(context.Background(), train) + if err != nil { + return err + } + return nil +} + +func (db *DBHandler) UpdateTrain(train *statev1.Train) error { + collection := db.stateManagerDB.Collection("trains") + _, err := collection.UpdateOne( + context.Background(), + bson.M{"trainid": train.TrainId}, + bson.M{"$set": bson.M{"state": train}}, + ) + if err != nil { + return err + } + return nil +} + +func (db *DBHandler) GetTrain(trainId string) (*statev1.Train, error) { + collection := db.stateManagerDB.Collection("trains") + var result *statev1.Train + err := collection.FindOne(context.Background(), bson.M{"trainid": trainId}).Decode(&result) + if err != nil { + return nil, err + } + return result, nil +} + +func (db *DBHandler) GetTrains() ([]*statev1.Train, error) { + collection := db.stateManagerDB.Collection("trains") + cursor, err := collection.Find(context.Background(), bson.M{}) + if err != nil { + return nil, err + } + var result []*statev1.Train + if err = cursor.All(context.Background(), &result); err != nil { + return nil, err + } + return result, nil +}