Skip to content

Commit

Permalink
[Uplift 1.62.x] #21408 AI Chat retrieves page text whenever a new hum…
Browse files Browse the repository at this point in the history
…an message is submitted (#21612)

* AI Chat retrieves page text whenever a new human message is submitted (#21408)

* AI Chat retrieves page text whenever a new human message is submitted

In a page-connected conversation, every new message remains pending until the page content has been (re-)fetched.
Also fixes a bug for same-document navigations where class variables were being hidden by subclass variables left over from the iOS refactor.

* Generating page content becomes an inline operation with a callback, and the fetched content is cached

* Only fetch page content once at a time; other callers wait for the current operation to complete

* Update iOS subclasses to be compatible with ConversationDriver class changes

* AI Chat: fix an entry submitted from the location bar not being submitted if a request is in progress

* Needed GetVisibleConversationHistory from 1.63.x

* remove extra brace from merge

* format
  • Loading branch information
petemill authored Jan 22, 2024
1 parent 0d7c0ba commit 5343488
Show file tree
Hide file tree
Showing 12 changed files with 356 additions and 264 deletions.
30 changes: 7 additions & 23 deletions browser/ui/webui/ai_chat/ai_chat_ui_page_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void AIChatUIPageHandler::SetClientPage(
// ex. A user may ask a question from the location bar
if (active_chat_tab_helper_ &&
active_chat_tab_helper_->HasPendingConversationEntry()) {
OnConversationEntryPending();
OnHistoryUpdate();
}
}

Expand All @@ -124,10 +124,12 @@ void AIChatUIPageHandler::ChangeModel(const std::string& model_key) {

// Submits |input| to the active tab's chat helper as a visible human
// conversation turn.
//
// |input| is the raw text entered by the user.
void AIChatUIPageHandler::SubmitHumanConversationEntry(
    const std::string& input) {
  // The two adjacent literals are concatenated into one message; the trailing
  // space after "more" keeps the words from running together in the log.
  DCHECK(!active_chat_tab_helper_->IsRequestInProgress())
      << "Should not be able to submit more "
         "than a single human conversation turn at a time.";
  mojom::ConversationTurn turn = {CharacterType::HUMAN,
                                  ConversationTurnVisibility::VISIBLE, input};
  // |turn| is handed off exactly once. The superseded
  // MakeAPIRequestWithConversationHistoryUpdate() call is dropped so that a
  // moved-from turn is never submitted a second time.
  active_chat_tab_helper_->SubmitHumanConversationEntry(std::move(turn));
}

void AIChatUIPageHandler::SubmitSummarizationRequest() {
Expand All @@ -142,21 +144,9 @@ void AIChatUIPageHandler::GetConversationHistory(
std::move(callback).Run({});
return;
}
std::vector<ConversationTurn> history =
active_chat_tab_helper_->GetConversationHistory();

std::vector<ai_chat::mojom::ConversationTurnPtr> list;

// Remove conversations that are meant to be hidden from the user
auto new_end_it = std::remove_if(
history.begin(), history.end(), [](const ConversationTurn& turn) {
return turn.visibility == ConversationTurnVisibility::HIDDEN;
});

std::transform(history.begin(), new_end_it, std::back_inserter(list),
[](const ConversationTurn& turn) { return turn.Clone(); });

std::move(callback).Run(std::move(list));
std::move(callback).Run(
active_chat_tab_helper_->GetVisibleConversationHistory());
}

void AIChatUIPageHandler::GetSuggestedQuestions(
Expand Down Expand Up @@ -412,12 +402,6 @@ void AIChatUIPageHandler::OnPageHasContent(mojom::SiteInfoPtr site_info) {
}
}

// Relays the "conversation entry pending" notification to |page_|. Nothing is
// sent while |page_| is not bound.
void AIChatUIPageHandler::OnConversationEntryPending() {
  if (!page_.is_bound()) {
    return;
  }
  page_->OnConversationEntryPending();
}

void AIChatUIPageHandler::GetFaviconImageData(
GetFaviconImageDataCallback callback) {
if (!active_chat_tab_helper_) {
Expand Down
1 change: 0 additions & 1 deletion browser/ui/webui/ai_chat/ai_chat_ui_page_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ class AIChatUIPageHandler : public ai_chat::mojom::PageHandler,
mojom::SuggestionGenerationStatus suggestion_generation_status) override;
void OnFaviconImageDataChanged() override;
void OnPageHasContent(mojom::SiteInfoPtr site_info) override;
void OnConversationEntryPending() override;

void GetFaviconImageData(GetFaviconImageDataCallback callback) override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void ChromeAutocompleteProviderClient::OpenLeo(const std::u16string& query) {
ai_chat::mojom::CharacterType::HUMAN,
ai_chat::mojom::ConversationTurnVisibility::VISIBLE,
base::UTF16ToUTF8(query)};
chat_tab_helper->MakeAPIRequestWithConversationHistoryUpdate(std::move(turn));
chat_tab_helper->SubmitHumanConversationEntry(std::move(turn));
ai_chat::AIChatMetrics* metrics =
g_brave_browser_process->process_misc_metrics()->ai_chat_metrics();
CHECK(metrics);
Expand Down
56 changes: 19 additions & 37 deletions components/ai_chat/content/browser/ai_chat_tab_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,56 +58,45 @@ AIChatTabHelper::~AIChatTabHelper() = default;

// content::WebContentsObserver

// content::WebContentsObserver override. Attempts to generate the page text
// when this load-completed notification arrives.
void AIChatTabHelper::DocumentOnLoadCompletedInPrimaryMainFrame() {
// The page may have extractable content at this point, so check for it.
// TODO(petemill): If content availability should also be checked on other
// navigation events, start a queue and make sure multiple async distills
// are never running at the same time.
MaybeGeneratePageText();
}

// content::WebContentsObserver override. Runs CleanUp() and then unregisters
// this helper as an observer of the tab's ContentFaviconDriver.
void AIChatTabHelper::WebContentsDestroyed() {
  CleanUp();
  auto* favicon_driver =
      favicon::ContentFaviconDriver::FromWebContents(web_contents());
  favicon_driver->RemoveObserver(this);
}

// content::WebContentsObserver override. Ignores navigations that are not in
// the main frame; for main-frame navigations it records the navigation ID and
// whether the navigation was same-document.
void AIChatTabHelper::DidFinishNavigation(
content::NavigationHandle* navigation_handle) {
// Store current navigation ID of the main document
// so that we can ignore async responses against any navigated-away-from
// documents.
if (!navigation_handle->IsInMainFrame()) {
DVLOG(3) << "FinishNavigation NOT in main frame";
return;
}
DVLOG(2) << __func__ << navigation_handle->GetNavigationId()
<< " url: " << navigation_handle->GetURL().spec()
<< " same document? " << navigation_handle->IsSameDocument();
SetNavigationId(navigation_handle->GetNavigationId());

// Allow same-document navigation, as content often changes as a result
// of fragment / pushState / replaceState navigations.
// Content won't be retrieved immediately and we don't have a similar
// "DOM Content Loaded" event, so let's wait for something else such as
// page title changing, or a timer completing before calling
// |MaybeGeneratePageText|.
SetSameDocumentNavigation(navigation_handle->IsSameDocument());
// Experimentally only call |CleanUp| _if_ a same-page navigation
// results in a page title change (see |TitleWasSet|).
if (!IsSameDocumentNavigation()) {
CleanUp();
// page title changing before committing to starting a new conversation
// and treating it as a "fresh page".
is_same_document_navigation_ = navigation_handle->IsSameDocument();
pending_navigation_id_ = navigation_handle->GetNavigationId();
// Experimentally only call |OnNewPage| for same-page navigations _if_
// it results in a page title change (see |TitleWasSet|).
if (!is_same_document_navigation_) {
OnNewPage(pending_navigation_id_);
}
}

// Called when the title of |entry| is set. Only acts while
// |is_same_document_navigation_| is true, and clears that flag afterwards so
// later title changes are ignored.
void AIChatTabHelper::TitleWasSet(content::NavigationEntry* entry) {
DVLOG(3) << __func__ << entry->GetTitle();
if (is_same_document_navigation_) {
// Seems as good a time as any to check for content after a same-document
// navigation.
// We only perform CleanUp here in case it was a minor pushState / fragment
// navigation and didn't result in new content.
CleanUp();
MaybeGeneratePageText();
DVLOG(3) << "Same document navigation detected new \"page\" - calling "
"OnNewPage()";
// Page title modification after same-document navigation seems as good a
// time as any to assume meaningful changes occurred to the content.
OnNewPage(pending_navigation_id_);
// Don't respond to further TitleWasSet
is_same_document_navigation_ = false;
}
}

Expand All @@ -129,16 +118,9 @@ GURL AIChatTabHelper::GetPageURL() const {
}

void AIChatTabHelper::GetPageContent(
base::OnceCallback<void(std::string, bool is_video)> callback) const {
FetchPageContent(web_contents(), std::move(callback));
}

bool AIChatTabHelper::HasPrimaryMainFrame() const {
return web_contents()->GetPrimaryMainFrame() != nullptr;
}

bool AIChatTabHelper::IsDocumentOnLoadCompletedInPrimaryMainFrame() const {
return web_contents()->IsDocumentOnLoadCompletedInPrimaryMainFrame();
GetPageContentCallback callback,
std::string_view invalidation_token) const {
FetchPageContent(web_contents(), invalidation_token, std::move(callback));
}

std::u16string AIChatTabHelper::GetPageTitle() const {
Expand Down
11 changes: 3 additions & 8 deletions components/ai_chat/content/browser/ai_chat_tab_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class AIChatTabHelper : public content::WebContentsObserver,
PrefService* local_state_prefs);

// content::WebContentsObserver
void DocumentOnLoadCompletedInPrimaryMainFrame() override;
void WebContentsDestroyed() override;
void DidFinishNavigation(
content::NavigationHandle* navigation_handle) override;
Expand All @@ -63,18 +62,14 @@ class AIChatTabHelper : public content::WebContentsObserver,

// ai_chat::ConversationDriver
GURL GetPageURL() const override;
void GetPageContent(base::OnceCallback<void(std::string, bool is_video)>
callback) const override;
bool HasPrimaryMainFrame() const override;
bool IsDocumentOnLoadCompletedInPrimaryMainFrame() const override;
void GetPageContent(GetPageContentCallback callback,
std::string_view invalidation_token) const override;
std::u16string GetPageTitle() const override;

raw_ptr<AIChatMetrics> ai_chat_metrics_;

// Store the unique ID for each navigation so that
// we can ignore API responses for previous navigations.
int64_t current_navigation_id_;
bool is_same_document_navigation_ = false;
int64_t pending_navigation_id_;

base::WeakPtrFactory<AIChatTabHelper> weak_ptr_factory_{this};
WEB_CONTENTS_USER_DATA_KEY_DECL();
Expand Down
43 changes: 31 additions & 12 deletions components/ai_chat/content/browser/page_content_fetcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ net::NetworkTrafficAnnotationTag GetNetworkTrafficAnnotationTag() {
class PageContentFetcher {
public:
void Start(mojo::Remote<mojom::PageContentExtractor> content_extractor,
std::string_view invalidation_token,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
FetchPageContentCallback callback) {
url_loader_factory_ = url_loader_factory;
Expand All @@ -70,22 +71,24 @@ class PageContentFetcher {
// after it is destroyed.
content_extractor_.set_disconnect_handler(base::BindOnce(
&PageContentFetcher::DeleteSelf, base::Unretained(this)));
content_extractor_->ExtractPageContent(
base::BindOnce(&PageContentFetcher::OnTabContentResult,
base::Unretained(this), std::move(callback)));
content_extractor_->ExtractPageContent(base::BindOnce(
&PageContentFetcher::OnTabContentResult, base::Unretained(this),
std::move(callback), invalidation_token));
}

private:
// Destroys this fetcher. Bound as the content extractor's disconnect handler
// in Start(); no member may be touched after this returns.
void DeleteSelf() { delete this; }

// Delivers the fetch result to |callback| and then destroys this fetcher.
//
// |callback|            Receives (content, is_video, invalidation_token).
// |content|             Extracted page text or transcript; empty by default.
// |invalidation_token|  Token passed back to the caller; empty by default.
// |is_video|            Whether the result is for video content.
//
// No member may be accessed after this call, because the object deletes
// itself.
void SendResultAndDeleteSelf(FetchPageContentCallback callback,
                             std::string content = "",
                             std::string invalidation_token = "",
                             bool is_video = false) {
  // A once-callback can only be consumed a single time, so it is run exactly
  // once, with the three-argument form FetchPageContentCallback declares.
  // The strings are moved because they are not used afterwards.
  std::move(callback).Run(std::move(content), is_video,
                          std::move(invalidation_token));
  delete this;
}

void OnTabContentResult(FetchPageContentCallback callback,
std::string_view invalidation_token,
mojom::PageContentPtr data) {
if (!data) {
VLOG(1) << __func__ << " no data.";
Expand All @@ -101,7 +104,7 @@ class PageContentFetcher {
auto content = data->content->get_content();
DVLOG(1) << __func__ << ": Got content with char length of "
<< content.length();
SendResultAndDeleteSelf(std::move(callback), content, false);
SendResultAndDeleteSelf(std::move(callback), content, "", false);
return;
}
// If it's video, we expect content url
Expand All @@ -110,7 +113,16 @@ class PageContentFetcher {
if (content_url.is_empty() || !content_url.is_valid() ||
!content_url.SchemeIs(url::kHttpsScheme)) {
VLOG(1) << "Invalid content_url";
SendResultAndDeleteSelf(std::move(callback), "", true);
SendResultAndDeleteSelf(std::move(callback), "", "", true);
return;
}
// Subsequent calls do not need to re-fetch if the url stays the same
auto new_invalidation_token = content_url.spec();
if (new_invalidation_token == invalidation_token) {
VLOG(2) << "Not fetching content since invalidation token matches: "
<< invalidation_token;
SendResultAndDeleteSelf(std::move(callback), "", new_invalidation_token,
true);
return;
}
DVLOG(1) << "Making video transcript fetch to " << content_url.spec();
Expand All @@ -132,13 +144,14 @@ class PageContentFetcher {
auto on_response =
base::BindOnce(&PageContentFetcher::OnTranscriptFetchResponse,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
std::move(loader), is_youtube);
std::move(loader), is_youtube, new_invalidation_token);
loader_ptr->DownloadToString(url_loader_factory_.get(),
std::move(on_response), 2 * 1024 * 1024);
}

void OnYoutubeTranscriptXMLParsed(
FetchPageContentCallback callback,
std::string invalidation_token,
base::expected<base::Value, std::string> result) {
// Example Youtube transcript XML:
//
Expand Down Expand Up @@ -182,13 +195,15 @@ class PageContentFetcher {
transcript_text += text;
}

SendResultAndDeleteSelf(std::move(callback), transcript_text, true);
SendResultAndDeleteSelf(std::move(callback), transcript_text,
invalidation_token, true);
}

void OnTranscriptFetchResponse(
FetchPageContentCallback callback,
std::unique_ptr<network::SimpleURLLoader> loader,
bool is_youtube,
std::string invalidation_token,
std::unique_ptr<std::string> response_body) {
auto response_code = -1;
base::flat_map<std::string, std::string> headers;
Expand All @@ -215,11 +230,13 @@ class PageContentFetcher {
data_decoder::mojom::XmlParser::WhitespaceBehavior::
kPreserveSignificant,
base::BindOnce(&PageContentFetcher::OnYoutubeTranscriptXMLParsed,
weak_ptr_factory_.GetWeakPtr(), std::move(callback)));
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
invalidation_token));
return;
}

SendResultAndDeleteSelf(std::move(callback), transcript_content, true);
SendResultAndDeleteSelf(std::move(callback), transcript_content,
invalidation_token, true);
}

scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_;
Expand All @@ -230,6 +247,7 @@ class PageContentFetcher {
} // namespace

void FetchPageContent(content::WebContents* web_contents,
std::string_view invalidation_token,
FetchPageContentCallback callback) {
VLOG(2) << __func__ << " Extracting page content from renderer...";

Expand All @@ -240,7 +258,7 @@ void FetchPageContent(content::WebContents* web_contents,
LOG(ERROR)
<< "Content extraction request submitted for a WebContents without "
"a primary main frame";
std::move(callback).Run("", false);
std::move(callback).Run("", false, "");
return;
}

Expand All @@ -255,7 +273,8 @@ void FetchPageContent(content::WebContents* web_contents,
->GetDefaultStoragePartition()
->GetURLLoaderFactoryForBrowserProcess()
.get();
fetcher->Start(std::move(extractor), loader, std::move(callback));
fetcher->Start(std::move(extractor), invalidation_token, loader,
std::move(callback));
}

} // namespace ai_chat
6 changes: 4 additions & 2 deletions components/ai_chat/content/browser/page_content_fetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ class WebContents;
namespace ai_chat {

using FetchPageContentCallback =
base::OnceCallback<void(std::string, bool is_video)>;

base::OnceCallback<void(std::string page_content,
bool is_video,
std::string invalidation_token)>;
void FetchPageContent(content::WebContents* web_contents,
std::string_view invalidation_token,
FetchPageContentCallback callback);

} // namespace ai_chat
Expand Down
Loading

0 comments on commit 5343488

Please sign in to comment.