Skip to content

Commit

Permalink
fix(TextEncoder): domjit crash in encode (#13320)
Browse files Browse the repository at this point in the history
Co-authored-by: Jarred Sumner <[email protected]>
  • Loading branch information
dylan-conway and Jarred-Sumner authored Aug 15, 2024
1 parent b70458c commit 5bd3442
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/bun.js/bindings/Uint8Array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ extern "C" JSC::EncodedJSValue JSUint8Array__fromDefaultAllocator(JSC::JSGlobalO
mi_free(p);
}));

uint8Array = JSC::JSUint8Array::create(lexicalGlobalObject, lexicalGlobalObject->m_typedArrayUint8.get(lexicalGlobalObject), WTFMove(buffer), 0, length);
uint8Array = JSC::JSUint8Array::create(lexicalGlobalObject, lexicalGlobalObject->typedArrayStructure(JSC::TypeUint8, false), WTFMove(buffer), 0, length);
} else {
uint8Array = JSC::JSUint8Array::create(lexicalGlobalObject, lexicalGlobalObject->m_typedArrayUint8.get(lexicalGlobalObject), 0);
uint8Array = JSC::JSUint8Array::create(lexicalGlobalObject, lexicalGlobalObject->typedArrayStructure(JSC::TypeUint8, false), 0);
}

return JSC::JSValue::encode(uint8Array);
Expand Down
6 changes: 4 additions & 2 deletions src/bun.js/bindings/ZigGlobalObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1719,10 +1719,12 @@ extern "C" JSC__JSValue Bun__allocUint8ArrayForCopy(JSC::JSGlobalObject* globalO

extern "C" JSC__JSValue Bun__createUint8ArrayForCopy(JSC::JSGlobalObject* globalObject, const void* ptr, size_t len, bool isBuffer)
{
auto scope = DECLARE_THROW_SCOPE(globalObject->vm());
VM& vm = globalObject->vm();
auto scope = DECLARE_THROW_SCOPE(vm);

JSC::JSUint8Array* array = JSC::JSUint8Array::createUninitialized(
globalObject,
isBuffer ? reinterpret_cast<Zig::GlobalObject*>(globalObject)->JSBufferSubclassStructure() : globalObject->m_typedArrayUint8.get(globalObject),
isBuffer ? reinterpret_cast<Zig::GlobalObject*>(globalObject)->JSBufferSubclassStructure() : globalObject->typedArrayStructure(TypeUint8, false),
len);

if (UNLIKELY(!array)) {
Expand Down
2 changes: 1 addition & 1 deletion src/bun.js/bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2986,7 +2986,7 @@ JSC__JSValue ZigString__toExternalValue(const ZigString* arg0, JSC__JSGlobalObje

VirtualMachine* JSC__JSGlobalObject__bunVM(JSC__JSGlobalObject* arg0)
{
return reinterpret_cast<VirtualMachine*>(reinterpret_cast<Zig::GlobalObject*>(arg0)->bunVM());
return reinterpret_cast<VirtualMachine*>(WebCore::clientData(arg0->vm())->bunVM);
}

JSC__JSValue ZigString__toValueGC(const ZigString* arg0, JSC__JSGlobalObject* arg1)
Expand Down
9 changes: 9 additions & 0 deletions src/bun.js/bindings/bindings.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1711,6 +1711,15 @@ pub const JSUint8Array = opaque {
/// Creates a JavaScript `Uint8Array` over `bytes` by forwarding the slice's
/// pointer and length to the native `JSUint8Array__fromDefaultAllocator` binding.
/// NOTE(review): the native side appears to release the memory with `mi_free`,
/// so `bytes` presumably must come from the default allocator and must not be
/// freed by the caller afterwards — confirm against the C++ implementation.
pub fn fromBytes(globalThis: *JSGlobalObject, bytes: []u8) JSC.JSValue {
    const data = bytes.ptr;
    const byte_count = bytes.len;
    return JSUint8Array__fromDefaultAllocator(globalThis, data, byte_count);
}

/// Native constructor shared by `Uint8Array` and `Buffer` creation; the final
/// flag selects the `Buffer` subclass structure when true.
extern fn Bun__createUint8ArrayForCopy(*JSC.JSGlobalObject, ptr: ?*const anyopaque, len: usize, buffer: bool) JSValue;

/// Builds a plain `Uint8Array` (not a `Buffer`) from `bytes` via
/// `Bun__createUint8ArrayForCopy`.
/// NOTE(review): the binding's name indicates the bytes are copied into
/// JS-owned storage, leaving `bytes` with the caller — confirm in the C++ body.
pub fn fromBytesCopy(globalThis: *JSGlobalObject, bytes: []const u8) JSValue {
    const source: ?*const anyopaque = bytes.ptr;
    const byte_count: usize = bytes.len;
    const as_node_buffer = false;
    return Bun__createUint8ArrayForCopy(globalThis, source, byte_count, as_node_buffer);
}

/// Returns a zero-length plain `Uint8Array` by asking
/// `Bun__createUint8ArrayForCopy` for a copy of zero bytes from a null source.
pub fn createEmpty(globalThis: *JSGlobalObject) JSValue {
    const no_source: ?*const anyopaque = null;
    const zero_length: usize = 0;
    const as_node_buffer = false;
    return Bun__createUint8ArrayForCopy(globalThis, no_source, zero_length, as_node_buffer);
}
};

pub const JSCell = extern struct {
Expand Down
5 changes: 5 additions & 0 deletions src/bun.js/webcore/encoding.classes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ export default [
flush: {
fn: "flush",
length: 0,

DOMJIT: {
returns: "JSUint8Array",
args: [],
},
},
},
}),
Expand Down
13 changes: 7 additions & 6 deletions src/bun.js/webcore/encoding.zig
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const strings = bun.strings;
const string = bun.string;
const FeatureFlags = bun.FeatureFlags;
const ArrayBuffer = @import("../base.zig").ArrayBuffer;
const JSUint8Array = JSC.JSUint8Array;
const Properties = @import("../base.zig").Properties;

const castObj = @import("../base.zig").castObj;
Expand Down Expand Up @@ -450,7 +451,7 @@ pub const TextEncoderStreamEncoder = struct {
fn encodeLatin1(this: *TextEncoderStreamEncoder, globalObject: *JSGlobalObject, input: []const u8) JSValue {
log("encodeLatin1: \"{s}\"", .{input});

if (input.len == 0) return .undefined;
if (input.len == 0) return JSUint8Array.createEmpty(globalObject);

const prepend_replacement_len: usize = prepend_replacement: {
if (this.pending_lead_surrogate != null) {
Expand Down Expand Up @@ -509,7 +510,7 @@ pub const TextEncoderStreamEncoder = struct {
fn encodeUTF16(this: *TextEncoderStreamEncoder, globalObject: *JSGlobalObject, input: []const u16) JSValue {
log("encodeUTF16: \"{}\"", .{bun.fmt.utf16(input)});

if (input.len == 0) return .undefined;
if (input.len == 0) return JSUint8Array.createEmpty(globalObject);

const Prepend = struct {
bytes: [4]u8,
Expand Down Expand Up @@ -538,7 +539,7 @@ pub const TextEncoderStreamEncoder = struct {

remain = remain[1..];
if (remain.len == 0) {
return ArrayBuffer.createBuffer(
return JSUint8Array.fromBytesCopy(
globalObject,
sequence[0..converted.utf8Width()],
);
Expand Down Expand Up @@ -579,7 +580,7 @@ pub const TextEncoderStreamEncoder = struct {

if (lead_surrogate) |pending_lead| {
this.pending_lead_surrogate = pending_lead;
if (buf.items.len == 0) return .undefined;
if (buf.items.len == 0) return JSUint8Array.createEmpty(globalObject);
}

return JSC.JSUint8Array.fromBytes(globalObject, buf.items);
Expand All @@ -601,9 +602,9 @@ pub const TextEncoderStreamEncoder = struct {

fn flushBody(this: *TextEncoderStreamEncoder, globalObject: *JSGlobalObject) JSValue {
return if (this.pending_lead_surrogate == null)
.undefined
JSUint8Array.createEmpty(globalObject)
else
JSC.ArrayBuffer.createBuffer(globalObject, &.{ 0xef, 0xbf, 0xbd });
JSUint8Array.fromBytesCopy(globalObject, &.{ 0xef, 0xbf, 0xbd });
}
};

Expand Down
62 changes: 60 additions & 2 deletions src/codegen/generate-classes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ static const JSC::DOMJIT::Signature DOMJITSignatureFor${fnName}(${DOMJITName(fnN
);
}

function DOMJITFunctionDefinition(jsClassName, fnName, symName, { args }) {
function DOMJITFunctionDefinition(jsClassName, fnName, symName, { args }, fn) {
const argNames = args.map((arg, i) => `${argTypeName(arg)} arg${i}`);
const formattedArgs = argNames.length > 0 ? `, ${argNames.join(", ")}` : "";
const retArgs = argNames.length > 0 ? `, ${args.map((b, i) => "arg" + i).join(", ")}` : "";
Expand All @@ -147,6 +147,24 @@ JSC_DEFINE_JIT_OPERATION(${DOMJITName(
CallFrame* callFrame = DECLARE_CALL_FRAME(vm);
IGNORE_WARNINGS_END
JSC::JITOperationPrologueCallFrameTracer tracer(vm, callFrame);
#ifdef BUN_DEBUG
${jsClassName}* wrapper = reinterpret_cast<${jsClassName}*>(thisValue);
JSC::EncodedJSValue result = ${DOMJITName(symName)}(wrapper->wrapped(), lexicalGlobalObject${retArgs});
JSValue decoded = JSValue::decode(result);
if (wrapper->m_${fn}_expectedResultType) {
if (decoded.isCell() && !decoded.isEmpty()) {
ASSERT_WITH_MESSAGE(wrapper->m_${fn}_expectedResultType.value().has_value(), "DOMJIT function return type changed!");
ASSERT_WITH_MESSAGE(wrapper->m_${fn}_expectedResultType.value().value() == decoded.asCell()->type(), "DOMJIT function return type changed!");
} else {
ASSERT_WITH_MESSAGE(!wrapper->m_${fn}_expectedResultType.value().has_value(), "DOMJIT function return type changed!");
}
} else if (!decoded.isEmpty()) {
wrapper->m_${fn}_expectedResultType = decoded.isCell()
? std::optional<JSC::JSType>(decoded.asCell()->type())
: std::optional<JSC::JSType>(std::nullopt);
}
return { result };
#endif
return {${DOMJITName(symName)}(reinterpret_cast<${jsClassName}*>(thisValue)->wrapped(), lexicalGlobalObject${retArgs})};
}
`.trim();
Expand Down Expand Up @@ -853,6 +871,7 @@ function renderDecls(symbolName, typeName, proto, supportsObjectCreate = false)
symbolName(typeName, name),
symbolName(typeName, proto[name].fn),
proto[name].DOMJIT,
proto[name].fn,
),
);
}
Expand Down Expand Up @@ -1089,6 +1108,7 @@ JSC_DEFINE_CUSTOM_SETTER(${symbolName(typeName, name)}SetterWrap, (JSGlobalObjec
}

if ("fn" in proto[name]) {
const fn = proto[name].fn;
rows.push(`
JSC_DEFINE_HOST_FUNCTION(${symbolName(typeName, name)}Callback, (JSGlobalObject * lexicalGlobalObject, CallFrame* callFrame))
{
Expand All @@ -1114,10 +1134,29 @@ JSC_DEFINE_HOST_FUNCTION(${symbolName(typeName, name)}Callback, (JSGlobalObject
lastFileName = fileName;
}
JSC::EncodedJSValue result = ${symbolName(typeName, proto[name].fn)}(thisObject->wrapped(), lexicalGlobalObject, callFrame${proto[name].passThis ? ", JSValue::encode(thisObject)" : ""});
JSC::EncodedJSValue result = ${symbolName(typeName, fn)}(thisObject->wrapped(), lexicalGlobalObject, callFrame, callFrame${proto[name].passThis ? ", JSValue::encode(thisObject)" : ""});
ASSERT_WITH_MESSAGE(!JSValue::decode(result).isEmpty() or DECLARE_CATCH_SCOPE(vm).exception() != 0, \"${typeName}.${proto[name].fn} returned an empty value without an exception\");
${
!proto[name].DOMJIT
? ""
: `
JSValue decoded = JSValue::decode(result);
if (thisObject->m_${fn}_expectedResultType) {
if (decoded.isCell() && !decoded.isEmpty()) {
ASSERT_WITH_MESSAGE(thisObject->m_${fn}_expectedResultType.value().has_value(), "DOMJIT function return type changed!");
ASSERT_WITH_MESSAGE(thisObject->m_${fn}_expectedResultType.value().value() == decoded.asCell()->type(), "DOMJIT function return type changed!");
} else {
ASSERT_WITH_MESSAGE(!thisObject->m_${fn}_expectedResultType.value().has_value(), "DOMJIT function return type changed!");
}
} else if (!decoded.isEmpty()) {
thisObject->m_${fn}_expectedResultType = decoded.isCell()
? std::optional<JSC::JSType>(decoded.asCell()->type())
: std::optional<JSC::JSType>(std::nullopt);
}`
}
return result;
#endif
Expand Down Expand Up @@ -1265,6 +1304,8 @@ function generateClassHeader(typeName, obj: ClassDefinition) {
})
.join("\n")}
${domJITTypeCheckFields(proto, klass)}
${weakOwner}
${DECLARE_VISIT_CHILDREN}
Expand All @@ -1276,6 +1317,23 @@ function generateClassHeader(typeName, obj: ClassDefinition) {
`.trim();
}

/**
 * Emits the debug-only C++ member declarations used to check that a DOMJIT
 * function keeps returning the same JS type across calls.
 *
 * One `m_<fn>_expectedResultType` member is produced for every entry of
 * `proto` and then `klass` that has a truthy `DOMJIT` descriptor, wrapped in
 * `#ifdef BUN_DEBUG` / `#endif`.
 *
 * @param proto Prototype (instance) property definitions, keyed by property name.
 * @param klass Static (class-level) property definitions, keyed by property name.
 * @returns The C++ source snippet, always including the `#ifdef`/`#endif` guard
 *          even when no DOMJIT entries exist.
 */
function domJITTypeCheckFields(proto, klass) {
  // Several property names may be backed by the same native `fn`. Declaring the
  // member more than once would be a C++ redeclaration error, so emit each once.
  const declared = new Set();
  const fields = [];

  // Prototype entries first, then static ones, matching the original emit order.
  for (const table of [proto, klass]) {
    for (const name in table) {
      const { DOMJIT, fn } = table[name];
      if (!DOMJIT || declared.has(fn)) continue;
      declared.add(fn);
      fields.push(`std::optional<std::optional<JSC::JSType>> m_${fn}_expectedResultType = std::nullopt;\n`);
    }
  }

  return "#ifdef BUN_DEBUG\n" + fields.join("") + "#endif\n";
}

function generateClassImpl(typeName, obj: ClassDefinition) {
const {
klass: fields,
Expand Down
7 changes: 3 additions & 4 deletions src/js/builtins/TextEncoderStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,12 @@ export function initializeTextEncoderStream() {
};
const transformAlgorithm = chunk => {
const encoder = $getByIdDirectPrivate(this, "textEncoderStreamEncoder");
let buffer;
try {
buffer = encoder.encode(chunk);
var buffer = encoder.encode(chunk);
} catch (e) {
return Promise.$reject(e);
}
if (buffer) {
if (buffer.length) {
const transformStream = $getByIdDirectPrivate(this, "textEncoderStreamTransform");
const controller = $getByIdDirectPrivate(transformStream, "controller");
$transformStreamDefaultControllerEnqueue(controller, buffer);
Expand All @@ -45,7 +44,7 @@ export function initializeTextEncoderStream() {
const flushAlgorithm = () => {
const encoder = $getByIdDirectPrivate(this, "textEncoderStreamEncoder");
const buffer = encoder.flush();
if (buffer) {
if (buffer.length) {
const transformStream = $getByIdDirectPrivate(this, "textEncoderStreamTransform");
const controller = $getByIdDirectPrivate(transformStream, "controller");
$transformStreamDefaultControllerEnqueue(controller, buffer);
Expand Down
5 changes: 5 additions & 0 deletions test/js/bun/jsc/domjit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,18 @@ describe("DOMJIT", () => {
describe("in NodeVM", () => {
const code = `
const buf = new Uint8Array(4);
const encoder = new TextEncoder();
for (let iter of [100000]) {
for (let i = 0; i < iter; i++) {
performance.now();
}
for (let i = 0; i < iter; i++) {
new TextEncoder().encode("test");
}
const str = "a".repeat(1030);
for (let i = 0; i < 1000000; i++) {
const result = encoder.encode(str);
}
for (let i = 0; i < iter; i++) {
new TextEncoder().encodeInto("test", buf);
}
Expand Down
1 change: 1 addition & 0 deletions test/js/web/encoding/text-encoder-stream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ for (const { input, output, description } of testCases) {
const chunkArray = await Bun.readableStreamToArray(outputStream);
expect(chunkArray.length, "number of chunks should match").toBe(output.length);
for (let i = 0; i < output.length; ++i) {
expect(chunkArray[i].constructor).toBe(Uint8Array);
expect(chunkArray[i].length).toBe(output[i].length);
for (let j = 0; j < output[i].length; ++j) {
expect(chunkArray[i][j]).toBe(output[i][j]);
Expand Down

0 comments on commit 5bd3442

Please sign in to comment.