Merge pull request #632 from Qi-Zha0/noetic-devel

Added segmentation based imitation learning model as a new brain.

Showing 11 changed files with 1,082 additions and 1 deletion.
behavior_metrics/brains/CARLA/brain_carla_segmentation_based_imitation_learning.py (152 additions, 0 deletions)
from torchvision import transforms
from PIL import Image
from brains.CARLA.utils.pilotnet_onehot import PilotNetOneHot
from brains.CARLA.utils.test_utils import traffic_light_to_int, model_control
from utils.constants import PRETRAINED_MODELS_DIR, ROOT_PATH
from os import path

import numpy as np

import torch
import torchvision
import cv2
import time
import os
import math
import carla

PRETRAINED_MODELS = ROOT_PATH + '/' + PRETRAINED_MODELS_DIR + 'CARLA/'


class Brain:

    def __init__(self, sensors, actuators, model=None, handler=None, config=None):
        self.motors = actuators.get_motor('motors_0')
        self.camera_rgb = sensors.get_camera('camera_0')  # rgb front view camera
        self.camera_seg = sensors.get_camera('camera_2')  # segmentation camera
        self.handler = handler
        self.inference_times = []
        self.gpu_inference = config['GPU']
        self.device = torch.device('cuda' if (torch.cuda.is_available() and self.gpu_inference) else 'cpu')

        client = carla.Client('localhost', 2000)
        client.set_timeout(10.0)
        world = client.get_world()
        self.map = world.get_map()

        weather = carla.WeatherParameters.ClearNoon
        world.set_weather(weather)

        # wait until the launcher has spawned the ego vehicle
        self.vehicle = None
        while self.vehicle is None:
            for vehicle in world.get_actors().filter('vehicle.*'):
                if vehicle.attributes.get('role_name') == 'ego_vehicle':
                    self.vehicle = vehicle
                    break
            if self.vehicle is None:
                print("Waiting for vehicle with role_name 'ego_vehicle'")
                time.sleep(1)  # sleep for 1 second before checking again

        if model:
            if not path.exists(PRETRAINED_MODELS + model):
                print("File " + model + " cannot be found in " + PRETRAINED_MODELS)

            if config['UseOptimized']:
                self.net = torch.jit.load(PRETRAINED_MODELS + model).to(self.device)
            else:
                self.net = PilotNetOneHot((288, 200, 6), 3, 4, 4).to(self.device)
                self.net.load_state_dict(torch.load(PRETRAINED_MODELS + model, map_location=self.device))
                self.net.eval()

        self.prev_hlc = 0

    def update_frame(self, frame_id, data):
        """Update the information to be shown in one of the GUI's frames.

        Arguments:
            frame_id {str} -- Id of the frame that will represent the data
            data {*} -- Data to be shown in the frame. Depending on the type of frame (rgbimage, laser, pose3d, etc)
        """
        if data.shape[0] != data.shape[1]:
            # pad the image into a square so the GUI frame does not distort it
            if data.shape[0] > data.shape[1]:
                difference = data.shape[0] - data.shape[1]
                extra_left, extra_right = int(difference/2), int(difference/2)
                extra_top, extra_bottom = 0, 0
            else:
                difference = data.shape[1] - data.shape[0]
                extra_left, extra_right = 0, 0
                extra_top, extra_bottom = int(difference/2), int(difference/2)

            data = np.pad(data, ((extra_top, extra_bottom), (extra_left, extra_right), (0, 0)), mode='constant', constant_values=0)

        self.handler.update_frame(frame_id, data)

    def execute(self):
        """Main loop of the brain. This will be called iteratively each TIME_CYCLE (see pilot.py)"""

        rgb_image = self.camera_rgb.getImage().data
        seg_image = self.camera_seg.getImage().data

        self.update_frame('frame_0', rgb_image)
        self.update_frame('frame_1', seg_image)

        try:
            # calculate speed in km/h from the velocity vector (m/s)
            speed_m_s = self.vehicle.get_velocity()
            speed = 3.6 * math.sqrt(speed_m_s.x**2 + speed_m_s.y**2 + speed_m_s.z**2)

            # randomly choose a high-level command when at or approaching a junction
            vehicle_location = self.vehicle.get_transform().location
            vehicle_waypoint = self.map.get_waypoint(vehicle_location)
            next_to_junction = False
            for j in range(1, 11):
                next_waypoint = vehicle_waypoint.next(j * 1.0)[0]
                if next_waypoint.is_junction:
                    next_to_junction = True
                    next_waypoints = vehicle_waypoint.next(j * 1.0)
                    break
            if vehicle_waypoint.is_junction or next_to_junction:
                if self.prev_hlc == 0:
                    valid_turns = []
                    for next_wp in next_waypoints:
                        yaw_diff = next_wp.transform.rotation.yaw - vehicle_waypoint.transform.rotation.yaw
                        yaw_diff = (yaw_diff + 180) % 360 - 180  # wrap into [-180, 180)
                        if -15 < yaw_diff < 15:
                            valid_turns.append(3)  # Go Straight
                        elif 15 < yaw_diff < 165:
                            valid_turns.append(1)  # Turn Left
                        elif -165 < yaw_diff < -15:
                            valid_turns.append(2)  # Turn Right
                    hlc = np.random.choice(valid_turns)
                else:
                    hlc = self.prev_hlc
            else:
                hlc = 0

            # get traffic light status
            light_status = -1
            if self.vehicle.is_at_traffic_light():
                traffic_light = self.vehicle.get_traffic_light()
                light_status = traffic_light.get_state()

            print(f'hlc: {hlc}')
            print(f'light: {light_status}')
            frame_data = {
                'hlc': hlc,
                'measurements': speed,
                'rgb': np.copy(rgb_image),
                'segmentation': np.copy(seg_image),
                'light': np.array([traffic_light_to_int(light_status)])
            }

            throttle, steer, brake = model_control(self.net,
                                                   frame_data,
                                                   ignore_traffic_light=False,
                                                   device=self.device,
                                                   combined_control=False)
            self.motors.sendThrottle(throttle)
            self.motors.sendSteer(steer)
            self.motors.sendBrake(brake)
        except Exception as err:
            print(err)
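The junction logic above is the only stochastic step in the brain: when the vehicle is at a junction, or within ten one-metre waypoint steps of one, and no command is pending (prev_hlc == 0), it wraps the yaw difference between each candidate outgoing waypoint and the current waypoint into [-180, 180) and draws a random high-level command among the geometrically valid branches. A minimal standalone sketch of that selection step, reusing the thresholds and command codes from the code above (the yaw values in the usage example are made up, and the empty-list fallback to lane-follow is a safety guard not present in the original):

import numpy as np

def pick_hlc(current_yaw, candidate_yaws):
    """Pick a random geometrically valid high-level command.
    Codes follow the brain above: 0 = lane follow, 1 = Turn Left, 2 = Turn Right, 3 = Go Straight."""
    valid_turns = []
    for yaw in candidate_yaws:
        yaw_diff = (yaw - current_yaw + 180) % 360 - 180  # wrap into [-180, 180)
        if -15 < yaw_diff < 15:
            valid_turns.append(3)
        elif 15 < yaw_diff < 165:
            valid_turns.append(1)
        elif -165 < yaw_diff < -15:
            valid_turns.append(2)
    return int(np.random.choice(valid_turns)) if valid_turns else 0

# Facing 350 deg with branches at 352 deg and 80 deg: the wrapped differences
# are +2 (straight) and +90 (left), so the command is drawn from {3, 1}.
print(pick_hlc(350.0, [352.0, 80.0]))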
behavior_metrics/brains/CARLA/pytorch/utils/pilotnet_onehot.py (158 additions, 0 deletions)
import torch
import torch.nn as nn
import sys
import os

from .convlstm import ConvLSTM


class PilotNetOneHot(nn.Module):
    def __init__(self, image_shape, num_labels, num_hlc, num_light):
        super(PilotNetOneHot, self).__init__()
        self.num_channels = image_shape[2]
        self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2)
        self.relu_1 = nn.ReLU()
        self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_2 = nn.ReLU()
        self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_3 = nn.ReLU()
        self.dropout_1 = nn.Dropout(0.2)

        self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False)

        # flattened ConvLSTM features + speed scalar + one-hot hlc + one-hot light
        self.fc_1 = nn.Linear(8*35*24 + 1 + num_hlc + num_light, 50)
        self.relu_fc_1 = nn.ReLU()
        self.fc_2 = nn.Linear(50, 10)
        self.relu_fc_2 = nn.ReLU()
        self.fc_3 = nn.Linear(10, num_labels)

    def forward(self, img, speed, hlc, light):
        out = self.cn_1(img)
        out = self.relu_1(out)
        out = self.cn_2(out)
        out = self.relu_2(out)
        out = self.cn_3(out)
        out = self.relu_3(out)
        out = self.dropout_1(out)
        out = out.unsqueeze(1)  # add a time dimension at index 1 for the ConvLSTM

        _, last_states = self.clstm_n(out)
        out = last_states[0][0]  # 0 for layer index, 0 for h index

        # flatten & concatenate with speed, high-level command and traffic-light state
        out = out.reshape(out.size(0), -1)
        speed = speed.view(speed.size(0), -1)
        hlc = hlc.view(hlc.size(0), -1)
        light = light.view(light.size(0), -1)
        out = torch.cat((out, speed, hlc, light), dim=1)

        out = self.fc_1(out)
        out = self.relu_fc_1(out)
        out = self.fc_2(out)
        out = self.relu_fc_2(out)
        out = self.fc_3(out)

        out = torch.sigmoid(out)

        return out


class PilotNetOneHotNoLight(nn.Module):
    def __init__(self, image_shape, num_labels, num_hlc):
        super(PilotNetOneHotNoLight, self).__init__()
        self.num_channels = image_shape[2]
        self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2)
        self.relu_1 = nn.ReLU()
        self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_2 = nn.ReLU()
        self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_3 = nn.ReLU()
        self.dropout_1 = nn.Dropout(0.2)

        self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False)

        # flattened ConvLSTM features + speed scalar + one-hot hlc (no traffic-light input)
        self.fc_1 = nn.Linear(8*35*24 + 1 + num_hlc, 50)
        self.relu_fc_1 = nn.ReLU()
        self.fc_2 = nn.Linear(50, 10)
        self.relu_fc_2 = nn.ReLU()
        self.fc_3 = nn.Linear(10, num_labels)

    def forward(self, img, speed, hlc):
        out = self.cn_1(img)
        out = self.relu_1(out)
        out = self.cn_2(out)
        out = self.relu_2(out)
        out = self.cn_3(out)
        out = self.relu_3(out)
        out = self.dropout_1(out)
        out = out.unsqueeze(1)  # add a time dimension at index 1 for the ConvLSTM

        _, last_states = self.clstm_n(out)
        out = last_states[0][0]  # 0 for layer index, 0 for h index

        # flatten & concatenate with speed and high-level command
        out = out.reshape(out.size(0), -1)
        speed = speed.view(speed.size(0), -1)
        hlc = hlc.view(hlc.size(0), -1)
        out = torch.cat((out, speed, hlc), dim=1)

        out = self.fc_1(out)
        out = self.relu_fc_1(out)
        out = self.fc_2(out)
        out = self.relu_fc_2(out)
        out = self.fc_3(out)

        out = torch.sigmoid(out)

        return out


class PilotNetEmbeddingNoLight(nn.Module):
    def __init__(self, image_shape, num_labels, num_hlc):
        super(PilotNetEmbeddingNoLight, self).__init__()
        self.num_channels = image_shape[2]
        self.cn_1 = nn.Conv2d(self.num_channels, 8, kernel_size=3, stride=2)
        self.relu_1 = nn.ReLU()
        self.cn_2 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_2 = nn.ReLU()
        self.cn_3 = nn.Conv2d(8, 8, kernel_size=3, stride=2)
        self.relu_3 = nn.ReLU()
        self.dropout_1 = nn.Dropout(0.2)

        self.clstm_n = ConvLSTM(8, 8, (5, 5), 3, batch_first=True, bias=True, return_all_layers=False)

        # embedding layer that maps the integer high-level command to a dense vector
        self.embedding = nn.Embedding(num_hlc, 5)

        # flattened ConvLSTM features + speed scalar + 5-dim hlc embedding
        self.fc_1 = nn.Linear(8*35*24 + 1 + 5, 50)
        self.relu_fc_1 = nn.ReLU()
        self.fc_2 = nn.Linear(50, 10)
        self.relu_fc_2 = nn.ReLU()
        self.fc_3 = nn.Linear(10, num_labels)

    def forward(self, img, speed, hlc):
        out = self.cn_1(img)
        out = self.relu_1(out)
        out = self.cn_2(out)
        out = self.relu_2(out)
        out = self.cn_3(out)
        out = self.relu_3(out)
        out = self.dropout_1(out)
        out = out.unsqueeze(1)  # add a time dimension at index 1 for the ConvLSTM

        _, last_states = self.clstm_n(out)
        out = last_states[0][0]  # 0 for layer index, 0 for h index

        # flatten & concatenate with speed and the embedded high-level command
        out = out.reshape(out.size(0), -1)
        speed = speed.view(speed.size(0), -1)
        hlc = self.embedding(hlc).view(hlc.size(0), -1)  # convert hlc to dense vectors using the embedding layer
        out = torch.cat((out, speed, hlc), dim=1)

        out = self.fc_1(out)
        out = self.relu_fc_1(out)
        out = self.fc_2(out)
        out = self.relu_fc_2(out)
        out = self.fc_3(out)

        out = torch.sigmoid(out)

        return out
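All three variants share the convolutional front end and the three-layer ConvLSTM; they differ only in how the high-level command (and, for PilotNetOneHot, the traffic-light state) is injected before the fully connected head. The 8*35*24 input size of fc_1 matches the 35x24 feature map left after three stride-2 convolutions over a 288x200 input. A quick shape check with the constructor arguments the brain above uses, (288, 200, 6), 3, 4, 4; the 6-channel image is assumed to be the RGB and segmentation views stacked along the channel axis, and the one-hot index assignments are illustrative:

import torch
from brains.CARLA.utils.pilotnet_onehot import PilotNetOneHot  # same path the brain above imports from

net = PilotNetOneHot((288, 200, 6), 3, 4, 4)  # (H, W, C), 3 control outputs, 4 hlc classes, 4 light classes
net.eval()

img = torch.randn(1, 6, 288, 200)   # assumed layout: RGB + segmentation, channels first
speed = torch.tensor([[25.0]])      # one speed measurement per sample
hlc = torch.zeros(1, 4)
hlc[0, 3] = 1.0                     # one-hot high-level command (index choice is illustrative)
light = torch.zeros(1, 4)
light[0, 0] = 1.0                   # one-hot traffic-light state (index choice is illustrative)

with torch.no_grad():
    out = net(img, speed, hlc, light)

print(out.shape)  # torch.Size([1, 3]); the sigmoid keeps each output in (0, 1)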