diff --git a/behavior_metrics/brains/CARLA/brain_carla_segmentation_based_imitation_learning.py b/behavior_metrics/brains/CARLA/brain_carla_segmentation_based_imitation_learning.py new file mode 100644 index 00000000..2f9e2b61 --- /dev/null +++ b/behavior_metrics/brains/CARLA/brain_carla_segmentation_based_imitation_learning.py @@ -0,0 +1,152 @@ +from torchvision import transforms +from PIL import Image +from brains.CARLA.utils.pilotnet_onehot import PilotNetOneHot +from brains.CARLA.utils.test_utils import traffic_light_to_int, model_control +from utils.constants import PRETRAINED_MODELS_DIR, ROOT_PATH +from os import path + +import numpy as np + +import torch +import torchvision +import cv2 +import time +import os +import math +import carla + +PRETRAINED_MODELS = ROOT_PATH + '/' + PRETRAINED_MODELS_DIR + 'CARLA/' + +class Brain: + + def __init__(self, sensors, actuators, model=None, handler=None, config=None): + self.motors = actuators.get_motor('motors_0') + self.camera_rgb = sensors.get_camera('camera_0') # rgb front view camera + self.camera_seg = sensors.get_camera('camera_2') # segmentation camera + self.handler = handler + self.inference_times = [] + self.gpu_inference = config['GPU'] + self.device = torch.device('cuda' if (torch.cuda.is_available() and self.gpu_inference) else 'cpu') + + client = carla.Client('localhost', 2000) + client.set_timeout(10.0) + world = client.get_world() + self.map = world.get_map() + + weather = carla.WeatherParameters.ClearNoon + world.set_weather(weather) + + self.vehicle = None + while self.vehicle is None: + for vehicle in world.get_actors().filter('vehicle.*'): + if vehicle.attributes.get('role_name') == 'ego_vehicle': + self.vehicle = vehicle + break + if self.vehicle is None: + print("Waiting for vehicle with role_name 'ego_vehicle'") + time.sleep(1) # sleep for 1 second before checking again + + if model: + if not path.exists(PRETRAINED_MODELS + model): + print("File " + model + " cannot be found in " + PRETRAINED_MODELS) + + if config['UseOptimized']: + self.net = torch.jit.load(PRETRAINED_MODELS + model).to(self.device) + else: + self.net = PilotNetOneHot((288, 200, 6), 3, 4, 4).to(self.device) + self.net.load_state_dict(torch.load(PRETRAINED_MODELS + model,map_location=self.device)) + self.net.eval() + + self.prev_hlc = 0 + + + def update_frame(self, frame_id, data): + """Update the information to be shown in one of the GUI's frames. + + Arguments: + frame_id {str} -- Id of the frame that will represent the data + data {*} -- Data to be shown in the frame. Depending on the type of frame (rgbimage, laser, pose3d, etc) + """ + if data.shape[0] != data.shape[1]: + if data.shape[0] > data.shape[1]: + difference = data.shape[0] - data.shape[1] + extra_left, extra_right = int(difference/2), int(difference/2) + extra_top, extra_bottom = 0, 0 + else: + difference = data.shape[1] - data.shape[0] + extra_left, extra_right = 0, 0 + extra_top, extra_bottom = int(difference/2), int(difference/2) + + + data = np.pad(data, ((extra_top, extra_bottom), (extra_left, extra_right), (0, 0)), mode='constant', constant_values=0) + + self.handler.update_frame(frame_id, data) + + def execute(self): + """Main loop of the brain. 
This will be called iteratively each TIME_CYCLE (see pilot.py)""" + + rgb_image = self.camera_rgb.getImage().data + seg_image = self.camera_seg.getImage().data + + self.update_frame('frame_0', rgb_image) + self.update_frame('frame_1', seg_image) + + try: + # calculate speed + speed_m_s = self.vehicle.get_velocity() + speed = 3.6 * math.sqrt(speed_m_s.x**2 + speed_m_s.y**2 + speed_m_s.z**2) + + # randomly choose high-level command if at junction + vehicle_location = self.vehicle.get_transform().location + vehicle_waypoint = self.map.get_waypoint(vehicle_location) + next_to_junction = False + for j in range(1, 11): + next_waypoint = vehicle_waypoint.next(j * 1.0)[0] + if next_waypoint.is_junction: + next_to_junction = True + next_waypoints = vehicle_waypoint.next(j * 1.0) + break + if vehicle_waypoint.is_junction or next_to_junction: + if self.prev_hlc == 0: + valid_turns = [] + for next_wp in next_waypoints: + yaw_diff = next_wp.transform.rotation.yaw - vehicle_waypoint.transform.rotation.yaw + yaw_diff = (yaw_diff + 180) % 360 - 180 + if -15 < yaw_diff < 15: + valid_turns.append(3) # Go Straight + elif 15 < yaw_diff < 165: + valid_turns.append(1) # Turn Left + elif -165 < yaw_diff < -15: + valid_turns.append(2) # Turn Right + hlc = np.random.choice(valid_turns) + else: + hlc = self.prev_hlc + else: + hlc = 0 + + # get traffic light status + light_status = -1 + if self.vehicle.is_at_traffic_light(): + traffic_light = self.vehicle.get_traffic_light() + light_status = traffic_light.get_state() + + print(f'hlc: {hlc}') + print(f'light: {light_status}') + frame_data = { + 'hlc': hlc, + 'measurements': speed, + 'rgb': np.copy(rgb_image), + 'segmentation': np.copy(seg_image), + 'light': np.array([traffic_light_to_int(light_status)]) + } + + throttle, steer, brake = model_control(self.net, + frame_data, + ignore_traffic_light=False, + device=self.device, + combined_control=False) + self.motors.sendThrottle(throttle) + self.motors.sendSteer(steer) + self.motors.sendBrake(brake) + except Exception as err: + print(err) \ No newline at end of file diff --git a/behavior_metrics/brains/CARLA/pytorch/utils/pilotnet_onehot.py b/behavior_metrics/brains/CARLA/pytorch/utils/pilotnet_onehot.py new file mode 100644 index 00000000..919fecc5 --- /dev/null +++ b/behavior_metrics/brains/CARLA/pytorch/utils/pilotnet_onehot.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import sys +import os + +from .convlstm import ConvLSTM + + +class PilotNetOneHot(nn.Module): + def __init__(self, image_shape, num_labels, num_hlc, num_light): + super(PilotNetOneHot, self).__init__() + self.num_channels = image_shape[2] + self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2) + self.relu_1 = nn.ReLU() + self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_2 = nn.ReLU() + self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_3 = nn.ReLU() + self.dropout_1 = nn.Dropout(0.2) + + self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False) + + self.fc_1 = nn.Linear(8*35*24+1+num_hlc+num_light, 50) # add embedding layer output size + self.relu_fc_1 = nn.ReLU() + self.fc_2 = nn.Linear(50, 10) + self.relu_fc_2 = nn.ReLU() + self.fc_3 = nn.Linear(10, num_labels) + + def forward(self, img, speed, hlc, light): + out = self.cn_1(img) + out = self.relu_1(out) + out = self.cn_2(out) + out = self.relu_2(out) + out = self.cn_3(out) + out = self.relu_3(out) + out = self.dropout_1(out) + out = out.unsqueeze(1) # add additional dimension at 1 + + _, last_states = 
self.clstm_n(out) + out = last_states[0][0] # 0 for layer index, 0 for h index + + # flatten & concatenate with speed + out = out.reshape(out.size(0), -1) + speed = speed.view(speed.size(0), -1) + hlc = hlc.view(hlc.size(0), -1) + light = light.view(light.size(0), -1) + out = torch.cat((out, speed, hlc, light), dim=1) # Concatenate the high-level commands to the rest of the inputs + + out = self.fc_1(out) + out = self.relu_fc_1(out) + out = self.fc_2(out) + out = self.relu_fc_2(out) + out = self.fc_3(out) + + out = torch.sigmoid(out) + + return out + +class PilotNetOneHotNoLight(nn.Module): + def __init__(self, image_shape, num_labels, num_hlc): + super(PilotNetOneHotNoLight, self).__init__() + self.num_channels = image_shape[2] + self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2) + self.relu_1 = nn.ReLU() + self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_2 = nn.ReLU() + self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_3 = nn.ReLU() + self.dropout_1 = nn.Dropout(0.2) + + self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False) + + self.fc_1 = nn.Linear(8*35*24+1+num_hlc, 50) # add embedding layer output size + self.relu_fc_1 = nn.ReLU() + self.fc_2 = nn.Linear(50, 10) + self.relu_fc_2 = nn.ReLU() + self.fc_3 = nn.Linear(10, num_labels) + + def forward(self, img, speed, hlc): + out = self.cn_1(img) + out = self.relu_1(out) + out = self.cn_2(out) + out = self.relu_2(out) + out = self.cn_3(out) + out = self.relu_3(out) + out = self.dropout_1(out) + out = out.unsqueeze(1) # add additional dimension at 1 + + _, last_states = self.clstm_n(out) + out = last_states[0][0] # 0 for layer index, 0 for h index + + # flatten & concatenate with speed + out = out.reshape(out.size(0), -1) + speed = speed.view(speed.size(0), -1) + hlc = hlc.view(hlc.size(0), -1) + out = torch.cat((out, speed, hlc), dim=1) # Concatenate the high-level commands to the rest of the inputs + out = self.fc_1(out) + out = self.relu_fc_1(out) + out = self.fc_2(out) + out = self.relu_fc_2(out) + out = self.fc_3(out) + + out = torch.sigmoid(out) + + return out + + +class PilotNetEmbeddingNoLight(nn.Module): + def __init__(self, image_shape, num_labels, num_hlc): + super(PilotNetEmbeddingNoLight, self).__init__() + self.num_channels = image_shape[2] + self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2) + self.relu_1 = nn.ReLU() + self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_2 = nn.ReLU() + self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_3 = nn.ReLU() + self.dropout_1 = nn.Dropout(0.2) + + self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False) + + # Add embedding layer for high-level commands + self.embedding = nn.Embedding(num_hlc, 5) + + self.fc_1 = nn.Linear(8*35*24+1+5, 50) # add embedding layer output size + self.relu_fc_1 = nn.ReLU() + self.fc_2 = nn.Linear(50, 10) + self.relu_fc_2 = nn.ReLU() + self.fc_3 = nn.Linear(10, num_labels) + + def forward(self, img, speed, hlc): + out = self.cn_1(img) + out = self.relu_1(out) + out = self.cn_2(out) + out = self.relu_2(out) + out = self.cn_3(out) + out = self.relu_3(out) + out = self.dropout_1(out) + out = out.unsqueeze(1) # add additional dimension at 1 + + _, last_states = self.clstm_n(out) + out = last_states[0][0] # 0 for layer index, 0 for h index + + # flatten & concatenate with speed + out = out.reshape(out.size(0), -1) + speed = speed.view(speed.size(0), -1) + hlc = 
self.embedding(hlc).view(hlc.size(0), -1) # convert hlc to dense vectors using the embedding layer + out = torch.cat((out, speed, hlc), dim=1) # Concatenate the high-level commands to the rest of the inputs + + out = self.fc_1(out) + out = self.relu_fc_1(out) + out = self.fc_2(out) + out = self.relu_fc_2(out) + out = self.fc_3(out) + + out = torch.sigmoid(out) + + return out \ No newline at end of file diff --git a/behavior_metrics/brains/CARLA/utils/convlstm.py b/behavior_metrics/brains/CARLA/utils/convlstm.py new file mode 100644 index 00000000..678caa74 --- /dev/null +++ b/behavior_metrics/brains/CARLA/utils/convlstm.py @@ -0,0 +1,195 @@ +""" +This implementation of Convolutional LSTM has been adapted from https://github.com/ndrplz/ConvLSTM_pytorch. +""" + +import torch.nn as nn +import torch + + +class ConvLSTMCell(nn.Module): + + def __init__(self, input_dim, hidden_dim, kernel_size, bias): + """ + Initialize ConvLSTM cell. + + Parameters + ---------- + input_dim: int + Number of channels of input tensor. + hidden_dim: int + Number of channels of hidden state. + kernel_size: (int, int) + Size of the convolutional kernel. + bias: bool + Whether or not to add the bias. + """ + + super(ConvLSTMCell, self).__init__() + + self.input_dim = input_dim + self.hidden_dim = hidden_dim + + self.kernel_size = kernel_size + self.padding = kernel_size[0] // 2, kernel_size[1] // 2 + self.bias = bias + + self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim, + out_channels=4 * self.hidden_dim, + kernel_size=self.kernel_size, + padding=self.padding, + bias=self.bias) + + def forward(self, input_tensor, cur_state): + h_cur, c_cur = cur_state + + combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis + + combined_conv = self.conv(combined) + cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) + i = torch.sigmoid(cc_i) + f = torch.sigmoid(cc_f) + o = torch.sigmoid(cc_o) + g = torch.tanh(cc_g) + + c_next = f * c_cur + i * g + h_next = o * torch.tanh(c_next) + + return h_next, c_next + + def init_hidden(self, batch_size, image_size): + height, width = image_size + return (torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), + torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device)) + + +class ConvLSTM(nn.Module): + + """ + + Parameters: + input_dim: Number of channels in input + hidden_dim: Number of hidden channels + kernel_size: Size of kernel in convolutions + num_layers: Number of LSTM layers stacked on each other + batch_first: Whether or not dimension 0 is the batch or not + bias: Bias or no bias in Convolution + return_all_layers: Return the list of computations for all layers + Note: Will do same padding. + + Input: + A tensor of size B, T, C, H, W or T, B, C, H, W + Output: + A tuple of two lists of length num_layers (or length 1 if return_all_layers is False). 
+            0 - layer_output_list is the list of lists of length T of each output
+            1 - last_state_list is the list of last states
+                each element of the list is a tuple (h, c) for hidden state and memory
+    Example:
+        >> x = torch.rand((32, 10, 64, 128, 128))
+        >> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
+        >> _, last_states = convlstm(x)
+        >> h = last_states[0][0]  # 0 for layer index, 0 for h index
+    """
+
+    def __init__(self, input_dim, hidden_dim, kernel_size, num_layers,
+                 batch_first=False, bias=True, return_all_layers=False):
+        super(ConvLSTM, self).__init__()
+
+        self._check_kernel_size_consistency(kernel_size)
+
+        # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
+        kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
+        hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
+        if not len(kernel_size) == len(hidden_dim) == num_layers:
+            raise ValueError('Inconsistent list length.')
+
+        self.input_dim = input_dim
+        self.hidden_dim = hidden_dim
+        self.kernel_size = kernel_size
+        self.num_layers = num_layers
+        self.batch_first = batch_first
+        self.bias = bias
+        self.return_all_layers = return_all_layers
+
+        cell_list = []
+        for i in range(0, self.num_layers):
+            cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
+
+            cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
+                                          hidden_dim=self.hidden_dim[i],
+                                          kernel_size=self.kernel_size[i],
+                                          bias=self.bias))
+
+        self.cell_list = nn.ModuleList(cell_list)
+
+    def forward(self, input_tensor, hidden_state=None):
+        """
+        Parameters
+        ----------
+        input_tensor:
+            5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
+        hidden_state:
+            Must be None; stateful operation is not implemented yet.
+
+        Returns
+        -------
+        layer_output_list, last_state_list
+        """
+        if not self.batch_first:
+            # (t, b, c, h, w) -> (b, t, c, h, w)
+            input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
+
+        b, _, _, h, w = input_tensor.size()
+
+        # Implement stateful ConvLSTM
+        if hidden_state is not None:
+            raise NotImplementedError()
+        else:
+            # Since the init is done in forward. 
Can send image size here + hidden_state = self._init_hidden(batch_size=b, + image_size=(h, w)) + + layer_output_list = [] + last_state_list = [] + + seq_len = input_tensor.size(1) + cur_layer_input = input_tensor + + for layer_idx in range(self.num_layers): + + h, c = hidden_state[layer_idx] + output_inner = [] + for t in range(seq_len): + h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t, :, :, :], + cur_state=[h, c]) + output_inner.append(h) + + layer_output = torch.stack(output_inner, dim=1) + cur_layer_input = layer_output + + layer_output_list.append(layer_output) + last_state_list.append([h, c]) + + if not self.return_all_layers: + layer_output_list = layer_output_list[-1:] + last_state_list = last_state_list[-1:] + + return layer_output_list, last_state_list + + def _init_hidden(self, batch_size, image_size): + init_states = [] + for i in range(self.num_layers): + init_states.append(self.cell_list[i].init_hidden(batch_size, image_size)) + return init_states + + @staticmethod + def _check_kernel_size_consistency(kernel_size): + if not (isinstance(kernel_size, tuple) or + (isinstance(kernel_size, list) and all([isinstance(elem, tuple) for elem in kernel_size]))): + raise ValueError('`kernel_size` must be tuple or list of tuples') + + @staticmethod + def _extend_for_multilayer(param, num_layers): + if not isinstance(param, list): + param = [param] * num_layers + return param \ No newline at end of file diff --git a/behavior_metrics/brains/CARLA/utils/pilotnet_onehot.py b/behavior_metrics/brains/CARLA/utils/pilotnet_onehot.py new file mode 100644 index 00000000..919fecc5 --- /dev/null +++ b/behavior_metrics/brains/CARLA/utils/pilotnet_onehot.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import sys +import os + +from .convlstm import ConvLSTM + + +class PilotNetOneHot(nn.Module): + def __init__(self, image_shape, num_labels, num_hlc, num_light): + super(PilotNetOneHot, self).__init__() + self.num_channels = image_shape[2] + self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2) + self.relu_1 = nn.ReLU() + self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_2 = nn.ReLU() + self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_3 = nn.ReLU() + self.dropout_1 = nn.Dropout(0.2) + + self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False) + + self.fc_1 = nn.Linear(8*35*24+1+num_hlc+num_light, 50) # add embedding layer output size + self.relu_fc_1 = nn.ReLU() + self.fc_2 = nn.Linear(50, 10) + self.relu_fc_2 = nn.ReLU() + self.fc_3 = nn.Linear(10, num_labels) + + def forward(self, img, speed, hlc, light): + out = self.cn_1(img) + out = self.relu_1(out) + out = self.cn_2(out) + out = self.relu_2(out) + out = self.cn_3(out) + out = self.relu_3(out) + out = self.dropout_1(out) + out = out.unsqueeze(1) # add additional dimension at 1 + + _, last_states = self.clstm_n(out) + out = last_states[0][0] # 0 for layer index, 0 for h index + + # flatten & concatenate with speed + out = out.reshape(out.size(0), -1) + speed = speed.view(speed.size(0), -1) + hlc = hlc.view(hlc.size(0), -1) + light = light.view(light.size(0), -1) + out = torch.cat((out, speed, hlc, light), dim=1) # Concatenate the high-level commands to the rest of the inputs + + out = self.fc_1(out) + out = self.relu_fc_1(out) + out = self.fc_2(out) + out = self.relu_fc_2(out) + out = self.fc_3(out) + + out = torch.sigmoid(out) + + return out + +class PilotNetOneHotNoLight(nn.Module): + def __init__(self, image_shape, num_labels, num_hlc): 
+ super(PilotNetOneHotNoLight, self).__init__() + self.num_channels = image_shape[2] + self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2) + self.relu_1 = nn.ReLU() + self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_2 = nn.ReLU() + self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_3 = nn.ReLU() + self.dropout_1 = nn.Dropout(0.2) + + self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False) + + self.fc_1 = nn.Linear(8*35*24+1+num_hlc, 50) # add embedding layer output size + self.relu_fc_1 = nn.ReLU() + self.fc_2 = nn.Linear(50, 10) + self.relu_fc_2 = nn.ReLU() + self.fc_3 = nn.Linear(10, num_labels) + + def forward(self, img, speed, hlc): + out = self.cn_1(img) + out = self.relu_1(out) + out = self.cn_2(out) + out = self.relu_2(out) + out = self.cn_3(out) + out = self.relu_3(out) + out = self.dropout_1(out) + out = out.unsqueeze(1) # add additional dimension at 1 + + _, last_states = self.clstm_n(out) + out = last_states[0][0] # 0 for layer index, 0 for h index + + # flatten & concatenate with speed + out = out.reshape(out.size(0), -1) + speed = speed.view(speed.size(0), -1) + hlc = hlc.view(hlc.size(0), -1) + out = torch.cat((out, speed, hlc), dim=1) # Concatenate the high-level commands to the rest of the inputs + out = self.fc_1(out) + out = self.relu_fc_1(out) + out = self.fc_2(out) + out = self.relu_fc_2(out) + out = self.fc_3(out) + + out = torch.sigmoid(out) + + return out + + +class PilotNetEmbeddingNoLight(nn.Module): + def __init__(self, image_shape, num_labels, num_hlc): + super(PilotNetEmbeddingNoLight, self).__init__() + self.num_channels = image_shape[2] + self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2) + self.relu_1 = nn.ReLU() + self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_2 = nn.ReLU() + self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2) + self.relu_3 = nn.ReLU() + self.dropout_1 = nn.Dropout(0.2) + + self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False) + + # Add embedding layer for high-level commands + self.embedding = nn.Embedding(num_hlc, 5) + + self.fc_1 = nn.Linear(8*35*24+1+5, 50) # add embedding layer output size + self.relu_fc_1 = nn.ReLU() + self.fc_2 = nn.Linear(50, 10) + self.relu_fc_2 = nn.ReLU() + self.fc_3 = nn.Linear(10, num_labels) + + def forward(self, img, speed, hlc): + out = self.cn_1(img) + out = self.relu_1(out) + out = self.cn_2(out) + out = self.relu_2(out) + out = self.cn_3(out) + out = self.relu_3(out) + out = self.dropout_1(out) + out = out.unsqueeze(1) # add additional dimension at 1 + + _, last_states = self.clstm_n(out) + out = last_states[0][0] # 0 for layer index, 0 for h index + + # flatten & concatenate with speed + out = out.reshape(out.size(0), -1) + speed = speed.view(speed.size(0), -1) + hlc = self.embedding(hlc).view(hlc.size(0), -1) # convert hlc to dense vectors using the embedding layer + out = torch.cat((out, speed, hlc), dim=1) # Concatenate the high-level commands to the rest of the inputs + + out = self.fc_1(out) + out = self.relu_fc_1(out) + out = self.fc_2(out) + out = self.relu_fc_2(out) + out = self.fc_3(out) + + out = torch.sigmoid(out) + + return out \ No newline at end of file diff --git a/behavior_metrics/brains/CARLA/utils/test_utils.py b/behavior_metrics/brains/CARLA/utils/test_utils.py new file mode 100644 index 00000000..1c096138 --- /dev/null +++ b/behavior_metrics/brains/CARLA/utils/test_utils.py @@ -0,0 +1,131 @@ +import torch +import numpy as np 
+import carla +import torch.nn.functional as F + +def model_control(model, frame_data, device='cpu', filter=True, one_hot=True, ignore_traffic_light=True, combined_control=False): + global counter + img, speed, hlc, light = preprocess_data(frame_data, filter=filter, one_hot=one_hot, ignore_traffic_light=ignore_traffic_light) + img = img.to(device) + speed = speed.to(device) + hlc = hlc.to(device) + light = light.to(device) + if ignore_traffic_light: + prediction = model(img, speed, hlc) + else: + prediction = model(img, speed, hlc, light) + prediction = prediction.detach().cpu().numpy().flatten() + #print(f"prediction: {prediction}") + + if not combined_control: + throttle, steer, brake = prediction + throttle = float(throttle) + brake = float(brake) + if brake < 0.05: brake = 0.0 + else: + combined, steer = prediction + combined = float(combined) + throttle, brake = 0.0, 0.0 + if combined >= 0.5: + throttle = (combined - 0.5) / 0.5 + else: + brake = (0.5 - combined) / 0.5 + + steer = (float(steer) * 2.0) - 1.0 + + return throttle, steer, brake + + +def preprocess_data(data, filter=True, one_hot=True, ignore_traffic_light=True): + rgb = data['rgb'].copy() + segmentation = data['segmentation'].copy() + + if filter: + rgb, segmentation = filter_classes(rgb, segmentation) + + rgb = torch.tensor(rgb, dtype=torch.float32).permute(2, 0, 1) + rgb /= 255.0 + + segmentation = torch.tensor(segmentation, dtype=torch.float32).permute(2, 0, 1) + segmentation /= 255.0 + + img = torch.cat((rgb, segmentation), dim=0) + img = img.unsqueeze(0) + + speed = torch.tensor(data['measurements'], dtype=torch.float32) + speed = torch.clamp(speed / 40, 0, 1) + speed = speed.unsqueeze(0) + + hlc = torch.tensor(data['hlc'], dtype=torch.long) + if one_hot: + hlc = F.one_hot(hlc.to(torch.int64), num_classes=4) + hlc = hlc.unsqueeze(0) + + if not ignore_traffic_light: + light = torch.tensor(data['light'], dtype=torch.long) + if one_hot: + light = F.one_hot(light.to(torch.int64), num_classes=4) + light = light.unsqueeze(0) + else: + light = None + + return img, speed, hlc, light + + +def filter_classes(rgb, seg, classes_to_keep = [1, 7, 12, 13, 14, 15, 16, 17, 18, 19, 24]): + classes = { + 0: [0, 0, 0], # Unlabeled + 1: [128, 64, 128], # Road *** + 2: [244, 35, 232], # Sidewalk + 3: [70, 70, 70], # Building + 4: [102, 102, 156], # Wall + 5: [190, 153, 153], # Fence + 6: [153, 153, 153], # Pole + 7: [250, 170, 30], # Traffic Light *** + 8: [220, 220, 0], # Traffic Sign + 9: [107, 142, 35], # Vegetation + 10: [152, 251, 152], # Terrain + 11: [70, 130, 180], # Sky + 12: [220, 20, 60], # Pedestrain *** + 13: [255, 0, 0], # Rider *** + 14: [0, 0, 142], # Car *** + 15: [0, 0, 70], # Truck *** + 16: [0, 60, 100], # Bus *** + 17: [0, 80, 100], # Train *** + 18: [0, 0, 230], # Motorcycle *** + 19: [119, 11, 32], # Bicycle *** + 20: [110, 190, 160], # Static + 21: [170, 120, 50], # Dynamic + 22: [55, 90, 80], # Other + 23: [45, 60, 150], # Water + 24: [157, 234, 50], # Road Line *** + 25: [81, 0, 81], # Ground + 26: [150, 100, 100], # Bridge + 27: [230, 150, 140], # Rail Track + 28: [180, 165, 180] # Guard Rail + } + + + classes_to_keep_rgb = np.array([classes[class_id] for class_id in classes_to_keep]) + + # Create a mask of pixels to keep + mask = np.isin(seg, classes_to_keep_rgb).all(axis=-1) + + # Initialize filtered images as black images + filtered_seg = np.zeros_like(seg) + filtered_rgb = np.zeros_like(rgb) + + # Use the mask to replace the corresponding pixels in the filtered images + filtered_seg[mask] = seg[mask] + 
filtered_rgb[mask] = rgb[mask] + + return filtered_rgb, filtered_seg + +def traffic_light_to_int(light_status): + light_dict = { + -1: 0, + carla.libcarla.TrafficLightState.Red: 1, + carla.libcarla.TrafficLightState.Green: 2, + carla.libcarla.TrafficLightState.Yellow: 3 + } + return light_dict[light_status] \ No newline at end of file diff --git a/behavior_metrics/configs/CARLA/CARLA_launch_files/CARLA_object_files/main_car_custom_camera.json b/behavior_metrics/configs/CARLA/CARLA_launch_files/CARLA_object_files/main_car_custom_camera.json new file mode 100644 index 00000000..5054f0fd --- /dev/null +++ b/behavior_metrics/configs/CARLA/CARLA_launch_files/CARLA_object_files/main_car_custom_camera.json @@ -0,0 +1,163 @@ +{ + "objects": + [ + { + "type": "sensor.pseudo.traffic_lights", + "id": "traffic_lights" + }, + { + "type": "sensor.pseudo.objects", + "id": "objects" + }, + { + "type": "sensor.pseudo.actor_list", + "id": "actor_list" + }, + { + "type": "sensor.pseudo.markers", + "id": "markers" + }, + { + "type": "sensor.pseudo.opendrive_map", + "id": "map" + }, + { + "type": "vehicle.tesla.model3", + "id": "ego_vehicle", + "sensors": + [ + { + "type": "sensor.camera.rgb", + "id": "rgb_front", + "spawn_point": {"x": 1.5, "y": 0.0, "z": 2.4, "roll": 0.0, "pitch": 15.0, "yaw": 0.0}, + "image_size_x": 288, + "image_size_y": 200, + "fov": 90.0 + }, + { + "type": "sensor.camera.rgb", + "id": "rgb_view", + "spawn_point": {"x": -4.5, "y": 0.0, "z": 2.8, "roll": 0.0, "pitch": 20.0, "yaw": 0.0}, + "image_size_x": 800, + "image_size_y": 600, + "fov": 90.0, + "attached_objects": + [ + { + "type": "actor.pseudo.control", + "id": "control" + } + ] + }, + { + "type": "sensor.lidar.ray_cast", + "id": "lidar", + "spawn_point": {"x": 0.0, "y": 0.0, "z": 2.4, "roll": 0.0, "pitch": 0.0, "yaw": 0.0}, + "range": 50, + "channels": 32, + "points_per_second": 320000, + "upper_fov": 2.0, + "lower_fov": -26.8, + "rotation_frequency": 20, + "noise_stddev": 0.0 + }, + { + "type": "sensor.lidar.ray_cast_semantic", + "id": "semantic_lidar", + "spawn_point": {"x": 0.0, "y": 0.0, "z": 2.4, "roll": 0.0, "pitch": 0.0, "yaw": 0.0}, + "range": 50, + "channels": 32, + "points_per_second": 320000, + "upper_fov": 2.0, + "lower_fov": -26.8, + "rotation_frequency": 20 + }, + { + "type": "sensor.other.radar", + "id": "radar_front", + "spawn_point": {"x": 2.0, "y": 0.0, "z": 2.0, "roll": 0.0, "pitch": 0.0, "yaw": 0.0}, + "horizontal_fov": 30.0, + "vertical_fov": 10.0, + "points_per_second": 1500, + "range": 100.0 + }, + { + "type": "sensor.camera.semantic_segmentation", + "id": "semantic_segmentation_front", + "spawn_point": {"x": 1.5, "y": 0.0, "z": 2.4, "roll": 0.0, "pitch": 15.0, "yaw": 0.0}, + "fov": 90.0, + "image_size_x": 288, + "image_size_y": 200 + }, + { + "type": "sensor.camera.depth", + "id": "depth_front", + "spawn_point": {"x": 2.0, "y": 0.0, "z": 2.0, "roll": 0.0, "pitch": 0.0, "yaw": 0.0}, + "fov": 90.0, + "image_size_x": 400, + "image_size_y": 70 + }, + { + "type": "sensor.camera.dvs", + "id": "dvs_front", + "spawn_point": {"x": 2.0, "y": 0.0, "z": 2.0, "roll": 0.0, "pitch": 0.0, "yaw": 0.0}, + "fov": 90.0, + "image_size_x": 400, + "image_size_y": 70, + "positive_threshold": 0.3, + "negative_threshold": 0.3, + "sigma_positive_threshold": 0.0, + "sigma_negative_threshold": 0.0, + "use_log": true, + "log_eps": 0.001 + }, + { + "type": "sensor.other.gnss", + "id": "gnss", + "spawn_point": {"x": 1.0, "y": 0.0, "z": 2.0}, + "noise_alt_stddev": 0.0, "noise_lat_stddev": 0.0, "noise_lon_stddev": 0.0, + "noise_alt_bias": 
0.0, "noise_lat_bias": 0.0, "noise_lon_bias": 0.0 + }, + { + "type": "sensor.other.imu", + "id": "imu", + "spawn_point": {"x": 2.0, "y": 0.0, "z": 2.0, "roll": 0.0, "pitch": 0.0, "yaw": 0.0}, + "noise_accel_stddev_x": 0.0, "noise_accel_stddev_y": 0.0, "noise_accel_stddev_z": 0.0, + "noise_gyro_stddev_x": 0.0, "noise_gyro_stddev_y": 0.0, "noise_gyro_stddev_z": 0.0, + "noise_gyro_bias_x": 0.0, "noise_gyro_bias_y": 0.0, "noise_gyro_bias_z": 0.0 + }, + { + "type": "sensor.other.collision", + "id": "collision", + "spawn_point": {"x": 2.5, "y": 0.0, "z": 0.7} + }, + { + "type": "sensor.other.lane_invasion", + "id": "lane_invasion", + "spawn_point": {"x": 0.0, "y": 0.0, "z": 0.0} + }, + { + "type": "sensor.pseudo.tf", + "id": "tf" + }, + { + "type": "sensor.pseudo.objects", + "id": "objects" + }, + { + "type": "sensor.pseudo.odom", + "id": "odometry" + }, + { + "type": "sensor.pseudo.speedometer", + "id": "speedometer" + }, + { + "type": "actor.pseudo.control", + "id": "control" + } + ] + } + ] +} + diff --git a/behavior_metrics/configs/CARLA/CARLA_launch_files/town_02_anticlockwise_imitation_learning.launch b/behavior_metrics/configs/CARLA/CARLA_launch_files/town_02_anticlockwise_imitation_learning.launch new file mode 100644 index 00000000..bbdac71d --- /dev/null +++ b/behavior_metrics/configs/CARLA/CARLA_launch_files/town_02_anticlockwise_imitation_learning.launch @@ -0,0 +1,54 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/behavior_metrics/configs/CARLA/default_carla_imitation_learning.yml b/behavior_metrics/configs/CARLA/default_carla_imitation_learning.yml new file mode 100644 index 00000000..612b4c10 --- /dev/null +++ b/behavior_metrics/configs/CARLA/default_carla_imitation_learning.yml @@ -0,0 +1,70 @@ +Behaviors: + Robot: + Sensors: + Cameras: + Camera_0: + Name: 'camera_0' + Topic: '/carla/ego_vehicle/rgb_front/image' + Camera_1: + Name: 'camera_1' + Topic: '/carla/ego_vehicle/rgb_view/image' + Camera_2: + Name: 'camera_2' + Topic: '/carla/ego_vehicle/semantic_segmentation_front/image' + Camera_3: + Name: 'camera_3' + Topic: '/carla/ego_vehicle/dvs_front/image' + Pose3D: + Pose3D_0: + Name: 'pose3d_0' + Topic: '/carla/ego_vehicle/odometry' + BirdEyeView: + BirdEyeView_0: + Name: 'bird_eye_view_0' + Topic: '' + Speedometer: + Speedometer_0: + Name: 'speedometer_0' + Topic: '/carla/ego_vehicle/speedometer' + Actuators: + CARLA_Motors: + Motors_0: + Name: 'motors_0' + Topic: '/carla/ego_vehicle/vehicle_control_cmd' + MaxV: 3 + MaxW: 0.3 + BrainPath: 'brains/CARLA/brain_carla_segmentation_based_imitation_learning.py' + PilotTimeCycle: 50 + AsyncMode: False + Parameters: + Model: 'pilotnet_v8.0.pth' + ImageCropped: False + ImageSize: [ 100,50 ] + ImageNormalized: True + PredictionsNormalized: True + GPU: True + UseOptimized: False + ImageTranform: '' + Type: 'CARLA' + Simulation: + World: configs/CARLA/CARLA_launch_files/town_02_anticlockwise_imitation_learning.launch + RandomSpawnPoint: False + NumberOfVehicle: 50 + NumberOfWalker: 0 + PercentagePedestriansRunning: 0.5 + PercentagePedestriansCrossing: 0.5 + Dataset: + In: '/tmp/my_bag.bag' + Out: '' + Stats: + Out: './' + PerfectLap: './perfect_bags/lap-simple-circuit.bag' + Layout: + Frame_0: + Name: frame_0 + Geometry: [0, 1, 1, 2] + Data: rgbimage + Frame_1: + Name: frame_1 + Geometry: [1, 1, 1, 2] + Data: rgbimage \ No newline at end of file diff --git a/behavior_metrics/models/CARLA/pilotnet_combined_control.pth 
b/behavior_metrics/models/CARLA/pilotnet_combined_control.pth
new file mode 100644
index 00000000..d3ff6bd6
Binary files /dev/null and b/behavior_metrics/models/CARLA/pilotnet_combined_control.pth differ
diff --git a/behavior_metrics/models/CARLA/pilotnet_v8.0.pth b/behavior_metrics/models/CARLA/pilotnet_v8.0.pth
new file mode 100644
index 00000000..9149d3ad
Binary files /dev/null and b/behavior_metrics/models/CARLA/pilotnet_v8.0.pth differ
diff --git a/behavior_metrics/utils/controller_carla.py b/behavior_metrics/utils/controller_carla.py
index 5290a7ff..408033e6 100644
--- a/behavior_metrics/utils/controller_carla.py
+++ b/behavior_metrics/utils/controller_carla.py
@@ -72,7 +72,6 @@ def __init__(self):
         client = carla.Client('localhost', 2000)
         client.set_timeout(100.0) # seconds
         self.world = client.get_world()
-        time.sleep(5)
         self.carla_map = self.world.get_map()
         while len(self.world.get_actors().filter('vehicle.*')) == 0:
             logger.info("Waiting for vehicles!")
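
Reviewer note: the control stack added above (`PilotNetOneHot`, `model_control`, and the `frame_data` dict assembled in `Brain.execute()`) can be smoke-tested offline, without a CARLA server. The sketch below is a minimal, hypothetical check under these assumptions: `behavior_metrics/` is on `PYTHONPATH`, the `carla` Python package is importable (`test_utils` imports it at module level), and the network is randomly initialised instead of loading `pilotnet_v8.0.pth`, so only tensor shapes and output ranges are verified, not driving quality.

```python
# Hypothetical offline smoke test for the new imitation-learning brain.
# Assumes behavior_metrics/ is on PYTHONPATH and the `carla` Python package
# is installed (brains.CARLA.utils.test_utils imports it). The weights are
# random, so we only check that shapes line up and outputs stay in range.
import numpy as np
import torch

from brains.CARLA.utils.pilotnet_onehot import PilotNetOneHot
from brains.CARLA.utils.test_utils import model_control, traffic_light_to_int

device = torch.device('cpu')
net = PilotNetOneHot((288, 200, 6), 3, 4, 4).to(device)  # same shape as in the brain
net.eval()

# Fake one frame the way Brain.execute() assembles it: 200x288 RGB and
# segmentation images, speed in km/h, HLC 0 (lane follow), no traffic light.
frame_data = {
    'hlc': 0,
    'measurements': 25.0,
    'rgb': np.zeros((200, 288, 3), dtype=np.uint8),
    'segmentation': np.zeros((200, 288, 3), dtype=np.uint8),
    'light': np.array([traffic_light_to_int(-1)]),
}

throttle, steer, brake = model_control(net, frame_data,
                                       ignore_traffic_light=False,
                                       device=device,
                                       combined_control=False)
assert 0.0 <= throttle <= 1.0 and -1.0 <= steer <= 1.0 and 0.0 <= brake <= 1.0
print(throttle, steer, brake)
```

The shapes matter: with the 288x200 cameras configured in `main_car_custom_camera.json`, the three stride-2 convolutions reduce the 6-channel RGB+segmentation stack to 8x24x35 feature maps, which is exactly the `8*35*24` term in `fc_1`, plus 1 for speed and 4 each for the one-hot HLC and traffic-light vectors.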
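
Relatedly, `default_carla_imitation_learning.yml` ships with `UseOptimized: False`; when it is set to `True`, the brain switches from `load_state_dict()` to `torch.jit.load()`. A TorchScript checkpoint for that path could be produced roughly as sketched below; the output filename and the dummy input shapes (chosen to match `preprocess_data`) are illustrative and not part of this PR.

```python
# Hypothetical export of a TorchScript model for the UseOptimized: True path.
# The trace unrolls the ConvLSTM over a single timestep, matching how the
# brain feeds one frame at a time; file paths here are illustrative.
import torch
from brains.CARLA.utils.pilotnet_onehot import PilotNetOneHot

net = PilotNetOneHot((288, 200, 6), 3, 4, 4)
net.load_state_dict(torch.load('models/CARLA/pilotnet_v8.0.pth', map_location='cpu'))
net.eval()

dummy_inputs = (
    torch.zeros(1, 6, 200, 288),  # RGB + segmentation stack
    torch.zeros(1, 1),            # normalised speed
    torch.zeros(1, 4),            # one-hot high-level command
    torch.zeros(1, 4),            # one-hot traffic-light state
)
traced = torch.jit.trace(net, dummy_inputs)
torch.jit.save(traced, 'models/CARLA/pilotnet_v8.0_torchscript.pt')  # point Model: at this file
```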