Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Readback results from HVM into Rust through interop #389

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/ast.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use TSPL::{new_parser, Parser};
use highlight_error::highlight_error;
use crate::hvm;
use crate::{hvm, interop};
use std::fmt::{Debug, Display};
use std::collections::{BTreeMap, BTreeSet};

Expand Down Expand Up @@ -364,7 +364,7 @@ impl Book {
// --------

impl Tree {
pub fn readback(net: &hvm::GNet, port: hvm::Port, fids: &BTreeMap<hvm::Val, String>) -> Option<Tree> {
pub fn readback<N: interop::NetReadback>(net: &mut N, port: hvm::Port, fids: &BTreeMap<hvm::Val, String>) -> Option<Tree> {
//println!("reading {}", port.show());
match port.get_tag() {
hvm::VAR => {
Expand Down Expand Up @@ -416,7 +416,7 @@ impl Tree {
}

impl Net {
pub fn readback(net: &hvm::GNet, book: &hvm::Book) -> Option<Net> {
pub fn readback<N: interop::NetReadback>(net: &mut N, book: &hvm::Book) -> Option<Net> {
let mut fids = BTreeMap::new();
for (fid, def) in book.defs.iter().enumerate() {
fids.insert(fid as hvm::Val, def.name.clone());
Expand Down
59 changes: 45 additions & 14 deletions src/hvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -1755,10 +1755,24 @@ void pretty_print_port(Net* net, Book* book, Port port) {
void do_run_io(Net* net, Book* book, Port port);
#endif

// Main
// Output Net
// Used by other languages calling `hvm_c`
// ----
typedef struct OutputNet {
void *original;
APair *node_buf;
APort *vars_buf;
a64 itrs;
} OutputNet;

void free_output_net_c(OutputNet* net) {
free((Net*)net->original);
free(net);
}

void hvm_c(u32* book_buffer) {
// Main
// ----
OutputNet* hvm_c(u32* book_buffer, bool return_output) {
// Creates static TMs
alloc_static_tms();

Expand All @@ -1781,33 +1795,50 @@ void hvm_c(u32* book_buffer) {

#ifdef IO
do_run_io(net, book, ROOT);
// IO actions into `stdout` and `stderr` may appear
// after Rust `print`s if we don't flush
fflush(stdout);
fflush(stderr);
#else
normalize(net, book);
#endif

// Prints the result
printf("Result: ");
pretty_print_port(net, book, enter(net, ROOT));
printf("\n");

// Stops the timer
double duration = (time64() - start) / 1000000000.0; // seconds

// Prints interactions and time
u64 itrs = atomic_load(&net->itrs);
printf("- ITRS: %" PRIu64 "\n", itrs);
printf("- TIME: %.2fs\n", duration);
printf("- MIPS: %.2f\n", (double)itrs / duration / 1000000.0);
if (!return_output) {
// Prints the result
printf("Result: ");
pretty_print_port(net, book, enter(net, ROOT));
printf("\n");

// Prints interactions and time
u64 itrs = atomic_load(&net->itrs);
printf("- ITRS: %" PRIu64 "\n", itrs);
printf("- TIME: %.2fs\n", duration);
printf("- MIPS: %.2f\n", (double)itrs / duration / 1000000.0);
}

// Frees everything
free_static_tms();
free(net);
free(book);

if (return_output) {
OutputNet *output = malloc(sizeof(OutputNet));
output->original = (void*)net;
output->node_buf = &net->node_buf[0];
output->vars_buf = &net->vars_buf[0];
output->itrs = atomic_load(&net->itrs);
return output;
}

free(net);
return NULL;
}

#ifdef WITH_MAIN
int main() {
hvm_c((u32*)BOOK_BUF);
hvm_c((u32*)BOOK_BUF, NULL);
return 0;
}
#endif
50 changes: 47 additions & 3 deletions src/hvm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2279,14 +2279,45 @@ __global__ void print_result(GNet* gnet) {

//COMPILED_BOOK_BUF//

// Output Net
// Used by other languages calling `hvm_cu`
// ----
struct OutputNet {
void *original;
Pair *node_buf;
Port *vars_buf;
u64 itrs;
};

OutputNet* create_output_net(GNet* gnet) {
OutputNet* output = (OutputNet*)malloc(sizeof(OutputNet));

// Allocate host memory for the net
GNet* h_gnet = (GNet*)malloc(sizeof(GNet));

// Copy the net from device to host
cudaMemcpy(h_gnet, gnet, sizeof(GNet), cudaMemcpyDeviceToHost);

output->original = (void*)h_gnet;
output->node_buf = h_gnet->node_buf;
output->vars_buf = h_gnet->vars_buf;
output->itrs = h_gnet->itrs;
return output;
}

extern "C" void free_output_net_cuda(OutputNet* net) {
free((GNet*)net->original);
free(net);
}

// Main
// ----

#ifdef IO
void do_run_io(GNet* gnet, Book* book, Port port);
#endif

extern "C" void hvm_cu(u32* book_buffer) {
extern "C" OutputNet* hvm_cu(u32* book_buffer, bool return_output) {
// Start the timer
clock_t start = clock();

Expand All @@ -2308,6 +2339,10 @@ extern "C" void hvm_cu(u32* book_buffer) {

#ifdef IO
do_run_io(gnet, book, ROOT);
// IO actions into `stdout` and `stderr` may appear
// after Rust `print`s if we don't flush
fflush(stdout);
fflush(stderr);
#else
gnet_normalize(gnet);
#endif
Expand All @@ -2319,7 +2354,10 @@ extern "C" void hvm_cu(u32* book_buffer) {
double duration = ((double)(end - start)) / CLOCKS_PER_SEC;

// Prints the result
print_result<<<1,1>>>(gnet);
// If `output` is set, the Rust implementation will print the net
if (!return_output) {
print_result<<<1,1>>>(gnet);
}

// Reports errors
cudaError_t err = cudaGetLastError();
Expand Down Expand Up @@ -2355,15 +2393,21 @@ extern "C" void hvm_cu(u32* book_buffer) {
//cudaMemcpy(&itrs, &gnet->itrs, sizeof(u64), cudaMemcpyDeviceToHost);

// Prints interactions, time and MIPS
// If `output` is set, the Rust implementation will print the net
if (return_output) {
return create_output_net(gnet);
}

printf("- ITRS: %llu\n", gnet_get_itrs(gnet));
printf("- LEAK: %llu\n", gnet_get_leak(gnet));
printf("- TIME: %.2fs\n", duration);
printf("- MIPS: %.2f\n", (double)gnet_get_itrs(gnet) / duration / 1000000.0);
return NULL;
}

#ifdef WITH_MAIN
int main() {
hvm_cu((u32*)BOOK_BUF);
hvm_cu((u32*)BOOK_BUF, false);
return 0;
}
#endif
4 changes: 4 additions & 0 deletions src/hvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ pub type Val = u32; // Val ::= 29-bit (rounded up to u32)
pub type Rule = u8; // Rule ::= 8-bit (fits a u8)

// Port
#[repr(C)]
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Hash)]
pub struct Port(pub Val);

// Pair
#[repr(C)]
pub struct Pair(pub u64);

// Atomics
pub type AVal = AtomicU32;
#[repr(C)]
pub struct APort(pub AVal);
#[repr(C)]
pub struct APair(pub AtomicU64);

// Number
Expand Down
Loading
Loading