example.js
import pg from 'pg';
import pgvector from 'pgvector/pg';
import { from as copyFrom } from 'pg-copy-streams';

// generate random data
const rows = 100000;
const dimensions = 128;
const embeddings = Array.from({length: rows}, () => Array.from({length: dimensions}, () => Math.random()));
const categories = Array.from({length: rows}, () => Math.floor(Math.random() * 100));
const queries = Array.from({length: 10}, () => Array.from({length: dimensions}, () => Math.random()));

// enable extensions
let client = new pg.Client({database: 'pgvector_citus'});
await client.connect();
await client.query('CREATE EXTENSION IF NOT EXISTS citus');
await client.query('CREATE EXTENSION IF NOT EXISTS vector');

// GUC variables set on the session do not propagate to Citus workers
// https://github.com/citusdata/citus/issues/462
// you can either:
// 1. set them on the system, user, or database and reconnect
// 2. set them for a transaction with SET LOCAL
await client.query("ALTER DATABASE pgvector_citus SET maintenance_work_mem = '512MB'");
await client.query('ALTER DATABASE pgvector_citus SET hnsw.ef_search = 20');
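// for reference, option 2 (per-transaction with SET LOCAL) would look roughly
// like this; a sketch only, not used in this example:
//   await client.query('BEGIN');
//   await client.query('SET LOCAL hnsw.ef_search = 20');
//   // ... run the queries that need the setting ...
//   await client.query('COMMIT');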
await client.end();

// reconnect for updated GUC variables to take effect
client = new pg.Client({database: 'pgvector_citus'});
await client.connect();
await pgvector.registerTypes(client);

console.log('Creating distributed table');
await client.query('DROP TABLE IF EXISTS items');
await client.query(`CREATE TABLE items (id bigserial, embedding vector(${dimensions}), category_id bigint, PRIMARY KEY (id, category_id))`);
await client.query('SET citus.shard_count = 4');
await client.query("SELECT create_distributed_table('items', 'category_id')");
console.log('Loading data in parallel');
const stream = client.query(copyFrom('COPY items (embedding, category_id) FROM STDIN'));
for (const [i, embedding] of embeddings.entries()) {
  const line = `${pgvector.toSql(embedding)}\t${categories[i]}\n`;
  // write the row, waiting for the stream to drain if its buffer fills up
  if (!stream.write(line)) {
    await new Promise((resolve) => stream.once('drain', resolve));
  }
}

// continue once the COPY has completed
stream.on('finish', async function () {
  console.log('Creating index in parallel');
  await client.query('CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)');

  console.log('Running distributed queries');
  for (const query of queries) {
    const { rows } = await client.query('SELECT id FROM items ORDER BY embedding <-> $1 LIMIT 5', [pgvector.toSql(query)]);
    console.log(rows.map((r) => r.id));
  }

  client.end();
});

// signal that all rows have been written
stream.end();
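// to try this locally (assuming a Postgres server with Citus and pgvector
// installed, plus the pg, pgvector, and pg-copy-streams npm packages):
//   createdb pgvector_citus && node example.js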