add option to only compute embeddings of top-ranking sites.
this is not really ideal, but it turns out to be way too slow to compute
embeddings for all the sites in the index. this way, we at least get embeddings
for the sites that are most likely to appear in search results, while the
computation stays tractable.
mikkeldenker committed Mar 7, 2024
1 parent 66e3ee7 commit 40a9260
Showing 7 changed files with 158 additions and 65 deletions.
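
In short, this replaces the flat `dual_encoder_model_path: Option<String>` setting with a structured `dual_encoder` config that can carry a rank threshold. A sketch of how a caller might wire it up (values are illustrative, mirroring the example indexer below):

```rust
use stract::config::IndexingDualEncoderConfig;

// Illustrative only: enable the dual encoder and restrict embeddings
// to the ~1M most central pages. A threshold of `None` keeps the old
// behavior of embedding every page in the index.
let dual_encoder = Some(IndexingDualEncoderConfig {
    model_path: "data/summarizer/dual_encoder".to_string(),
    page_centrality_rank_threshold: Some(1_000_000),
});
```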
crates/core/examples/indexer.rs (6 additions & 1 deletion)
@@ -52,7 +52,12 @@ fn main() -> anyhow::Result<()> {
safety_classifier_path: None,
minimum_clean_words: None,
batch_size: 512,
-        dual_encoder_model_path: args.dual_encoder_path,
+        dual_encoder: args
+            .dual_encoder_path
+            .map(|p| stract::config::IndexingDualEncoderConfig {
+                model_path: p,
+                page_centrality_rank_threshold: Some(1_000_000),
+            }),
})?;

println!("Indexing took {:?}", start.elapsed());
crates/core/src/config/mod.rs (10 additions & 1 deletion)
@@ -41,7 +41,16 @@ pub struct IndexingLocalConfig {
#[serde(default = "defaults::Indexing::batch_size")]
pub batch_size: usize,

-    pub dual_encoder_model_path: Option<String>,
+    pub dual_encoder: Option<IndexingDualEncoderConfig>,
}

+#[derive(Debug, Deserialize, Clone)]
+pub struct IndexingDualEncoderConfig {
+    pub model_path: String,
+
+    /// Only compute embeddings for pages that have a
+    /// centrality rank at or below this threshold.
+    pub page_centrality_rank_threshold: Option<u64>,
+}

#[derive(Debug, Deserialize, Clone)]
crates/core/src/entrypoint/centrality.rs (6 additions & 4 deletions)
@@ -63,7 +63,8 @@ impl Centrality {
}
store.flush();

-        let rank_store = RocksDbStore::open(base_output.as_ref().join("harmonic_rank"));
+        let rank_store: RocksDbStore<crate::webgraph::NodeID, u64> =
+            RocksDbStore::open(base_output.as_ref().join("harmonic_rank"));
let mut top_harmonics = Vec::new();
for (rank, node, centrality) in ExternalSorter::new()
.with_chunk_size(100_000_000)
@@ -78,7 +79,7 @@
(rank, node_id, centrality)
})
{
-            rank_store.insert(node, rank as f64);
+            rank_store.insert(node, rank as u64);

if top_harmonics.len() < 1_000_000 {
top_harmonics.push((graph.id2node(&node).unwrap(), centrality));
@@ -110,7 +111,8 @@
let graph = WebgraphBuilder::new(webgraph_path).single_threaded().open();

let approx = ApproxHarmonic::build(&graph, base_output.as_ref().join("approx_harmonic"));
-        let approx_rank = RocksDbStore::open(base_output.as_ref().join("approx_harmonic_rank"));
+        let approx_rank: RocksDbStore<crate::webgraph::NodeID, u64> =
+            RocksDbStore::open(base_output.as_ref().join("approx_harmonic_rank"));

let mut top_nodes = Vec::new();

@@ -126,7 +128,7 @@
(rank, node_id, centrality)
})
{
-            approx_rank.insert(node, rank as f64);
+            approx_rank.insert(node, rank as u64);
if top_nodes.len() < 1_000_000 {
top_nodes.push((graph.id2node(&node).unwrap(), centrality));
}
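
A note on the centrality change above: both rank stores now persist integer ranks (`u64`) instead of `f64`, since a rank appears to be nothing more than a node's position when nodes are ordered by descending centrality. A toy sketch of that relationship (the real code streams through `ExternalSorter`; names here are illustrative):

```rust
// Toy rank assignment: sort (node_id, centrality) pairs by descending
// centrality; a node's index in that order is its rank, stored as u64.
fn assign_ranks(mut nodes: Vec<(u64, f64)>) -> Vec<(u64, u64)> {
    nodes.sort_by(|a, b| b.1.total_cmp(&a.1));
    nodes
        .into_iter()
        .enumerate()
        .map(|(rank, (node_id, _centrality))| (node_id, rank as u64))
        .collect()
}
```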
crates/core/src/entrypoint/configure.rs (7 additions & 2 deletions)
@@ -19,7 +19,9 @@ use tokio::io;
use tokio_stream::StreamExt;
use tracing::{debug, info};

-use crate::config::{defaults, IndexingLocalConfig, LocalConfig, WebSpellConfig};
+use crate::config::{
+    defaults, IndexingDualEncoderConfig, IndexingLocalConfig, LocalConfig, WebSpellConfig,
+};
use crate::entrypoint::indexer::JobSettings;
use crate::entrypoint::{dmoz_parser, indexer};
use crate::Result;
@@ -202,7 +204,10 @@ fn create_inverted_index() -> Result<()> {
safety_classifier_path: None,
minimum_clean_words: None,
batch_size: defaults::Indexing::batch_size(),
-        dual_encoder_model_path: Some(dual_encoder_path.to_str().unwrap().to_string()),
+        dual_encoder: Some(IndexingDualEncoderConfig {
+            model_path: dual_encoder_path.to_str().unwrap().to_string(),
+            page_centrality_rank_threshold: Some(100_000),
+        }),
});

let index = job.process(&worker);
crates/core/src/entrypoint/indexer/worker.rs (119 additions & 47 deletions)
@@ -22,8 +22,8 @@ use tracing::debug;

pub use super::indexable_webpage::IndexableWebpage;
pub use super::job::{Job, JobSettings};
-use crate::config::{IndexingLocalConfig, LiveIndexConfig};
-use crate::models::dual_encoder::DualEncoder;
+use crate::config::{IndexingDualEncoderConfig, IndexingLocalConfig, LiveIndexConfig};
+use crate::models::dual_encoder::DualEncoder as DualEncoderModel;
use crate::Result;

use crate::human_website_annotations;
@@ -41,7 +41,7 @@ pub struct Config {
pub page_webgraph_path: Option<String>,
pub topics_path: Option<String>,
pub safety_classifier_path: Option<String>,
-    pub dual_encoder_model_path: Option<String>,
+    pub dual_encoder: Option<IndexingDualEncoderConfig>,
}

impl From<IndexingLocalConfig> for Config {
@@ -52,7 +52,7 @@ impl From<IndexingLocalConfig> for Config {
page_webgraph_path: config.page_webgraph_path,
topics_path: config.topics_path,
safety_classifier_path: config.safety_classifier_path,
-            dual_encoder_model_path: config.dual_encoder_model_path,
+            dual_encoder: config.dual_encoder,
}
}
}
@@ -65,16 +65,21 @@ impl From<LiveIndexConfig> for Config {
page_webgraph_path: config.page_webgraph_path,
topics_path: None,
safety_classifier_path: config.safety_classifier_path,
-            dual_encoder_model_path: None,
+            dual_encoder: None,
}
}
}

+struct DualEncoder {
+    model: DualEncoderModel,
+    page_centrality_rank_threshold: Option<u64>,
+}

pub struct IndexingWorker {
host_centrality_store: RocksDbStore<NodeID, f64>,
-    host_centrality_rank_store: RocksDbStore<NodeID, f64>,
+    host_centrality_rank_store: RocksDbStore<NodeID, u64>,
page_centrality_store: Option<RocksDbStore<NodeID, f64>>,
-    page_centrality_rank_store: Option<RocksDbStore<NodeID, f64>>,
+    page_centrality_rank_store: Option<RocksDbStore<NodeID, u64>>,
page_webgraph: Option<Webgraph>,
topics: Option<human_website_annotations::Mapper>,
safety_classifier: Option<safety_classifier::Model>,
@@ -119,10 +124,16 @@ impl IndexingWorker {
.map(|path| safety_classifier::Model::open(path).unwrap()),
job_settings: None,
rake: RakeModel::default(),
-            dual_encoder: config.dual_encoder_model_path.as_ref().map(|path| {
-                DualEncoder::open(path).unwrap_or_else(|err| {
-                    panic!("failed to open dual encoder model: {}", err);
-                })
+            dual_encoder: config.dual_encoder.as_ref().map(|dual_encoder| {
+                let model =
+                    DualEncoderModel::open(&dual_encoder.model_path).unwrap_or_else(|err| {
+                        panic!("failed to open dual encoder model: {}", err);
+                    });
+
+                DualEncoder {
+                    model,
+                    page_centrality_rank_threshold: dual_encoder.page_centrality_rank_threshold,
+                }
}),
}
}
@@ -175,7 +186,7 @@ impl IndexingWorker {
let host_centrality_rank = self
.host_centrality_rank_store
.get(&host_node_id)
-            .unwrap_or(u64::MAX as f64);
+            .unwrap_or(u64::MAX);

if let Some(host_centrality_threshold) =
self.job_settings.and_then(|s| s.host_centrality_threshold)
@@ -194,10 +205,6 @@ impl IndexingWorker {
page.host_centrality = 0.0;
}

-        if !page.host_centrality_rank.is_finite() {
-            page.host_centrality_rank = u64::MAX as f64;
-        }
-
self.parse_text(page)?;

Ok(())
@@ -272,21 +279,17 @@ impl IndexingWorker {
page.page_centrality = store.get(&node_id).unwrap_or_default();
}

-        page.page_centrality_rank = u64::MAX as f64;
+        page.page_centrality_rank = u64::MAX;

if let Some(store) = self.page_centrality_rank_store.as_ref() {
let node_id = node.id();

-            page.page_centrality_rank = store.get(&node_id).unwrap_or(u64::MAX as f64);
+            page.page_centrality_rank = store.get(&node_id).unwrap_or(u64::MAX);
}

if !page.page_centrality.is_finite() {
page.page_centrality = 0.0;
}

-        if !page.page_centrality_rank.is_finite() {
-            page.page_centrality_rank = u64::MAX as f64;
-        }
}

fn set_dmoz_description(&self, page: &mut Webpage) {
@@ -310,43 +313,59 @@ impl IndexingWorker {
}

fn set_title_embeddings(&self, pages: &mut [Webpage]) {
-        if let Some(model) = self.dual_encoder.as_ref() {
-            let titles = pages
+        if let Some(dual_encoder) = self.dual_encoder.as_ref() {
+            let (page_indexes, titles): (Vec<_>, Vec<_>) = pages
.iter()
-                .map(|w| w.html.title().unwrap_or_default())
-                .collect::<Vec<_>>();
+                .enumerate()
+                .filter(|(_, w)| {
+                    dual_encoder
+                        .page_centrality_rank_threshold
+                        .map(|thresh| w.page_centrality_rank <= thresh)
+                        .unwrap_or(true)
+                })
+                .map(|(i, w)| (i, w.html.title().unwrap_or_default()))
+                .unzip();

-            let title_emb = model
+            let title_emb = dual_encoder
+                .model
.embed(&titles)
.ok()
.and_then(|t| t.to_dtype(candle_core::DType::BF16).ok());

if let Some(title_emb) = title_emb {
-                for (i, page) in pages.iter_mut().enumerate() {
+                for (i, page_index) in page_indexes.into_iter().enumerate() {
if let Ok(emb) = title_emb.get(i) {
-                        page.title_embedding = Some(emb);
+                        pages[page_index].title_embedding = Some(emb);
}
}
}
}
}

fn set_keyword_embeddings(&self, pages: &mut [Webpage]) {
-        if let Some(model) = self.dual_encoder.as_ref() {
-            let keywords = pages
+        if let Some(dual_encoder) = self.dual_encoder.as_ref() {
+            let (page_indexes, keywords): (Vec<_>, Vec<_>) = pages
.iter()
-                .map(|w| w.keywords.join("\n"))
-                .collect::<Vec<_>>();
+                .enumerate()
+                .filter(|(_, w)| {
+                    dual_encoder
+                        .page_centrality_rank_threshold
+                        .map(|thresh| w.page_centrality_rank <= thresh)
+                        .unwrap_or(true)
+                })
+                .map(|(i, w)| (i, w.keywords.join("\n")))
+                .unzip();

-            let keyword_emb = model
+            let keyword_emb = dual_encoder
+                .model
.embed(&keywords)
.ok()
.and_then(|t| t.to_dtype(candle_core::DType::BF16).ok());

if let Some(keyword_emb) = keyword_emb {
-                for (i, page) in pages.iter_mut().enumerate() {
+                for (i, page_index) in page_indexes.into_iter().enumerate() {
if let Ok(emb) = keyword_emb.get(i) {
-                        page.keyword_embedding = Some(emb);
+                        pages[page_index].keyword_embedding = Some(emb);
}
}
}
@@ -415,21 +434,17 @@ mod tests {

use super::*;

-    #[test]
-    fn title_embeddings() {
-        let data_path = Path::new("../../data/summarizer/dual_encoder");
-        if !data_path.exists() {
-            // Skip the test if the test data is not available
-            return;
-        }
-
-        let worker = IndexingWorker::new(IndexingLocalConfig {
+    fn setup_worker(data_path: &Path, threshold: Option<u64>) -> IndexingWorker {
+        IndexingWorker::new(IndexingLocalConfig {
host_centrality_store_path: crate::gen_temp_path().to_str().unwrap().to_string(),
page_centrality_store_path: None,
page_webgraph_path: None,
topics_path: None,
safety_classifier_path: None,
-            dual_encoder_model_path: Some(data_path.to_str().unwrap().to_string()),
+            dual_encoder: Some(IndexingDualEncoderConfig {
+                model_path: data_path.to_str().unwrap().to_string(),
+                page_centrality_rank_threshold: threshold,
+            }),
output_path: crate::gen_temp_path().to_str().unwrap().to_string(),
limit_warc_files: None,
skip_warc_files: None,
@@ -440,7 +455,17 @@ mod tests {
host_centrality_threshold: None,
minimum_clean_words: None,
batch_size: 10,
-        });
+        })
+    }
+
+    #[test]
+    fn title_embeddings() {
+        let data_path = Path::new("../../data/summarizer/dual_encoder");
+        if !data_path.exists() {
+            // Skip the test if the test data is not available
+            return;
+        }
+        let worker = setup_worker(data_path, None);

let webpages = vec![
IndexableWebpage {
@@ -480,6 +505,7 @@ mod tests {
.dual_encoder
.as_ref()
.unwrap()
+            .model
.embed(&[query.to_string()])
.unwrap()
.to_dtype(candle_core::DType::F16)
@@ -533,4 +559,50 @@

assert!(sim1 > sim2);
}

+    #[test]
+    fn title_embedding_ranks() {
+        let data_path = Path::new("../../data/summarizer/dual_encoder");
+        if !data_path.exists() {
+            // Skip the test if the test data is not available
+            return;
+        }
+        let worker = setup_worker(data_path, Some(100_000));
+
+        let mut a = Webpage::test_parse("<html><head><title>Homemade Heart Brownie Recipe</title></head><body>Example</body></html>", "https://a.com").unwrap();
+        a.page_centrality_rank = 1;
+
+        let mut b = Webpage::test_parse("<html><head><title>How To Use an iMac as a Monitor for a PC</title></head><body>Example</body></html>", "https://b.com").unwrap();
+        b.page_centrality_rank = 1_000_000;
+
+        let mut webpages = vec![a, b];
+        worker.set_title_embeddings(&mut webpages);
+
+        assert_eq!(webpages.len(), 2);
+        assert_eq!(
+            webpages[0].html.title(),
+            Some("Homemade Heart Brownie Recipe".to_string())
+        );
+
+        assert!(webpages[0].title_embedding.is_some());
+        assert!(webpages[1].title_embedding.is_none());
+
+        let mut a = Webpage::test_parse("<html><head><title>Homemade Heart Brownie Recipe</title></head><body>Example</body></html>", "https://a.com").unwrap();
+        a.page_centrality_rank = 1_000_000;
+
+        let mut b = Webpage::test_parse("<html><head><title>How To Use an iMac as a Monitor for a PC</title></head><body>Example</body></html>", "https://b.com").unwrap();
+        b.page_centrality_rank = 1;
+
+        let mut webpages = vec![a, b];
+        worker.set_title_embeddings(&mut webpages);
+
+        assert_eq!(webpages.len(), 2);
+        assert_eq!(
+            webpages[0].html.title(),
+            Some("Homemade Heart Brownie Recipe".to_string())
+        );
+
+        assert!(webpages[0].title_embedding.is_none());
+        assert!(webpages[1].title_embedding.is_some());
+    }
}
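
The heart of the change is the gate in `set_title_embeddings` and `set_keyword_embeddings`: a page gets an embedding only when no threshold is configured or its centrality rank clears it (lower rank means more central). A minimal standalone sketch of that predicate, with checks mirroring the tests above (illustrative, not the crate's API):

```rust
// `None` threshold: embed everything. `Some(t)`: embed only pages
// whose centrality rank is at or below t.
fn should_embed(page_centrality_rank: u64, threshold: Option<u64>) -> bool {
    threshold
        .map(|thresh| page_centrality_rank <= thresh)
        .unwrap_or(true)
}

fn main() {
    assert!(should_embed(1, Some(100_000))); // highly ranked: embedded
    assert!(!should_embed(1_000_000, Some(100_000))); // ranked too low: skipped
    assert!(should_embed(1_000_000, None)); // no threshold: always embedded
}
```

Pages that are filtered out simply keep `title_embedding`/`keyword_embedding` as `None`, which is exactly what the new `title_embedding_ranks` test asserts.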