diff --git a/Cargo.lock b/Cargo.lock index 345d262..beb2d06 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -172,7 +172,7 @@ dependencies = [ ] [[package]] -name = "example-iface-method-returns-string" +name = "example-iface-method-returns" version = "0.0.2" dependencies = [ "wit-bindgen", diff --git a/cmd/gravity/src/codegen/exports.rs b/cmd/gravity/src/codegen/exports.rs index 9005aba..aa28749 100644 --- a/cmd/gravity/src/codegen/exports.rs +++ b/cmd/gravity/src/codegen/exports.rs @@ -1,7 +1,10 @@ use genco::prelude::*; use wit_bindgen_core::wit_parser::{Function, Resolve, SizeAlign, World, WorldItem}; -use crate::go::{GoIdentifier, GoResult, GoType, imports::CONTEXT_CONTEXT}; +use crate::{ + Func, + go::{GoIdentifier, GoResult, GoType, imports::CONTEXT_CONTEXT}, +}; pub struct ExportConfig<'a> { pub instance: &'a GoIdentifier, @@ -49,7 +52,7 @@ impl<'a> ExportGenerator<'a> { GoResult::Empty }; - let mut f = crate::Func::export(result, self.config.sizes); + let mut f = Func::new(result, self.config.sizes); wit_bindgen_core::abi::call( self.config.resolve, wit_bindgen_core::abi::AbiVariant::GuestExport, diff --git a/cmd/gravity/src/codegen/func.rs b/cmd/gravity/src/codegen/func.rs index a18c97b..f203765 100644 --- a/cmd/gravity/src/codegen/func.rs +++ b/cmd/gravity/src/codegen/func.rs @@ -15,27 +15,10 @@ use crate::{ WAZERO_API_ENCODE_F64, WAZERO_API_ENCODE_I32, WAZERO_API_ENCODE_U32, }, }, - resolve_type, resolve_wasm_type, + resolve_type, }; -/// The direction of a function. -/// -/// Functions in the Component Model can be imported into a world or -/// exported from a world. -enum Direction<'a> { - /// The function is imported into the world. - Import { - /// The name of the parameter representing the interface instance - /// in the generated host binding function. - param_name: &'a GoIdentifier, - }, - /// The function is exported from the world. - #[allow(dead_code, reason = "halfway through refactor of func bindings")] - Export, -} - pub struct Func<'a> { - direction: Direction<'a>, args: Vec, result: GoResult, tmp: usize, @@ -47,24 +30,8 @@ pub struct Func<'a> { impl<'a> Func<'a> { /// Create a new exported function. - #[allow(dead_code, reason = "halfway through refactor of func bindings")] - pub fn export(result: GoResult, sizes: &'a SizeAlign) -> Self { - Self { - direction: Direction::Export, - args: Vec::new(), - result, - tmp: 0, - body: Tokens::new(), - block_storage: Vec::new(), - blocks: Vec::new(), - sizes, - } - } - - /// Create a new exported function. - pub fn import(param_name: &'a GoIdentifier, result: GoResult, sizes: &'a SizeAlign) -> Self { + pub fn new(result: GoResult, sizes: &'a SizeAlign) -> Self { Self { - direction: Direction::Import { param_name }, args: Vec::new(), result, tmp: 0, @@ -138,46 +105,33 @@ impl Bindgen for Func<'_> { let memory = &format!("memory{tmp}"); let realloc = &format!("realloc{tmp}"); let operand = &operands[0]; - match self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - $memory := i.module.Memory() - $realloc := i.module.ExportedFunction($(quoted(*realloc_name))) - $ptr, $len, $err := writeString(ctx, $operand, $memory, $realloc) - $(match &self.result { - GoResult::Anon(GoType::ValueOrError(typ)) => { - if $err != nil { - var $default $(typ.as_ref()) - return $default, $err - } - } - GoResult::Anon(GoType::Error) => { - if $err != nil { - return $err - } - } - GoResult::Anon(_) | GoResult::Empty => { - $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) - if $err != nil { - panic($err) - } - } - }) + + quote_in! { self.body => + $['\r'] + $memory := i.module.Memory() + $realloc := i.module.ExportedFunction($(quoted(*realloc_name))) + $ptr, $len, $err := writeString(ctx, $operand, $memory, $realloc) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if $err != nil { + var $default $(typ.as_ref()) + return $default, $err + } } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - $memory := mod.Memory() - $realloc := mod.ExportedFunction($(quoted(*realloc_name))) - $ptr, $len, $err := writeString(ctx, $operand, $memory, $realloc) + GoResult::Anon(GoType::Error) => { + if $err != nil { + return $err + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) if $err != nil { panic($err) } - }; - } + } + }) } + results.push(Operand::SingleValue(ptr.into())); results.push(Operand::SingleValue(len.into())); } @@ -290,7 +244,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - var $(&value) uint32 + var $(&value) uint64 if $operand { $(&value) = 1 } else { @@ -400,6 +354,7 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let tmp = self.tmp(); let value = &format!("value{tmp}"); + let encoded = &format!("encoded{tmp}"); let ok = &format!("ok{tmp}"); let default = &format!("default{tmp}"); let operand = &operands[0]; @@ -425,8 +380,9 @@ impl Bindgen for Func<'_> { } } }) + $encoded := $WAZERO_API_ENCODE_U32($value) }; - results.push(Operand::SingleValue(value.into())); + results.push(Operand::SingleValue(encoded.into())); } Instruction::StringLift => { let tmp = self.tmp(); @@ -436,44 +392,32 @@ impl Bindgen for Func<'_> { let str = &format!("str{tmp}"); let ptr = &operands[0]; let len = &operands[1]; - match self.direction { - Direction::Export { .. } => { - quote_in! { self.body => - $['\r'] - $buf, $ok := i.module.Memory().Read($ptr, $len) - $(match &self.result { - GoResult::Anon(GoType::ValueOrError(typ)) => { - if !$ok { - var $default $(typ.as_ref()) - return $default, $ERRORS_NEW("failed to read bytes from memory") - } - } - GoResult::Anon(GoType::Error) => { - if !$ok { - return $ERRORS_NEW("failed to read bytes from memory") - } - } - GoResult::Anon(_) | GoResult::Empty => { - $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) - if !$ok { - panic($ERRORS_NEW("failed to read bytes from memory")) - } - } - }) - $str := string($buf) - }; - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - $buf, $ok := mod.Memory().Read($ptr, $len) + + quote_in! { self.body => + $['\r'] + $buf, $ok := i.module.Memory().Read($ptr, $len) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read bytes from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read bytes from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) if !$ok { panic($ERRORS_NEW("failed to read bytes from memory")) } - $str := string($buf) - }; - } - } + } + }) + $str := string($buf) + }; + results.push(Operand::SingleValue(str.into())); } Instruction::ResultLift { @@ -559,57 +503,8 @@ impl Bindgen for Func<'_> { }; } } - Instruction::CallInterface { func, .. } => { - let ident = GoIdentifier::public(&func.name); - let tmp = self.tmp(); - let args = quote!($(for op in operands.iter() join (, ) => $op)); - let returns = match &func.result { - None => GoType::Nothing, - Some(typ) => resolve_type(typ, resolve), - }; - let value = &format!("value{tmp}"); - let err = &format!("err{tmp}"); - let ok = &format!("ok{tmp}"); - match self.direction { - Direction::Export { .. } => todo!("TODO(#10): handle export direction"), - Direction::Import { param_name, .. } => { - quote_in! { self.body => - $['\r'] - $(match returns { - GoType::Nothing => $param_name.$ident(ctx, $args), - GoType::Bool | GoType::Uint32 | GoType::Interface | GoType::String | GoType::UserDefined(_) => $value := $param_name.$ident(ctx, $args), - GoType::Error => $err := $param_name.$ident(ctx, $args), - GoType::ValueOrError(_) => { - $value, $err := $param_name.$ident(ctx, $args) - } - GoType::ValueOrOk(_) => { - $value, $ok := $param_name.$ident(ctx, $args) - } - _ => $(comment(&["TODO(#9): handle return type"])) - }) - } - } - } - match returns { - GoType::Nothing => (), - GoType::Bool - | GoType::Uint32 - | GoType::Interface - | GoType::UserDefined(_) - | GoType::String => { - results.push(Operand::SingleValue(value.into())); - } - GoType::Error => { - results.push(Operand::SingleValue(err.into())); - } - GoType::ValueOrError(_) => { - results.push(Operand::MultiValue((value.into(), err.into()))); - } - GoType::ValueOrOk(_) => { - results.push(Operand::MultiValue((value.into(), ok.into()))) - } - _ => todo!("TODO(#9): handle return type - {returns:?}"), - } + Instruction::CallInterface { .. } => { + todo!("TODO(#10): handle exported CallInterface") } Instruction::VariantPayloadName => { results.push(Operand::SingleValue("variantPayload".into())); @@ -621,55 +516,26 @@ impl Bindgen for Func<'_> { let tag = &operands[0]; let ptr = &operands[1]; if let Operand::Literal(byte) = tag { - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteByte($ptr+$offset, $byte) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteByte($ptr+$offset, $byte) - } - } + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteByte($ptr+$offset, $byte) } } else { let tmp = self.tmp(); let byte = format!("byte{tmp}"); - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - var $(&byte) uint8 - switch $tag { - case 0: - $(&byte) = 0 - case 1: - $(&byte) = 1 - default: - $(comment(["TODO(#8): Return an error if the return type allows it"])) - panic($ERRORS_NEW("invalid int8 value encountered")) - } - i.module.Memory().WriteByte($ptr+$offset, $byte) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - var $(&byte) uint8 - switch $tag { - case 0: - $(&byte) = 0 - case 1: - $(&byte) = 1 - default: - panic($ERRORS_NEW("invalid int8 value encountered")) - } - mod.Memory().WriteByte($ptr+$offset, $byte) - } + quote_in! { self.body => + $['\r'] + var $(&byte) uint8 + switch $tag { + case 0: + $(&byte) = 0 + case 1: + $(&byte) = 1 + default: + $(comment(["TODO(#8): Return an error if the return type allows it"])) + panic($ERRORS_NEW("invalid int8 value encountered")) } + i.module.Memory().WriteByte($ptr+$offset, $byte) } } } @@ -678,19 +544,10 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let tag = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, $tag) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, $tag) - } - } + + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint32Le($ptr+$offset, $tag) } } Instruction::LengthStore { offset } => { @@ -698,19 +555,10 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let len = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, uint32($len)) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, uint32($len)) - } - } + + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint32Le($ptr+$offset, uint32($len)) } } Instruction::PointerStore { offset } => { @@ -718,19 +566,10 @@ impl Bindgen for Func<'_> { let offset = offset.size_wasm32(); let value = &operands[0]; let ptr = &operands[1]; - match &self.direction { - Direction::Export => { - quote_in! { self.body => - $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, uint32($value)) - } - } - Direction::Import { .. } => { - quote_in! { self.body => - $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, uint32($value)) - } - } + + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint32Le($ptr+$offset, uint32($value)) } } Instruction::ResultLower { @@ -829,12 +668,11 @@ impl Bindgen for Func<'_> { let mut vars: Tokens = Tokens::new(); for i in 0..result_types.len() { let variant = &format!("variant{tmp}_{i}"); - let typ = resolve_wasm_type(&result_types[i]); results.push(Operand::SingleValue(variant.into())); quote_in! { vars => $['\r'] - var $variant $typ + var $variant uint64 } let some_result = &some_results[i]; @@ -1005,12 +843,11 @@ impl Bindgen for Func<'_> { let value = &operands[0]; let default = &format!("default{tmp}"); - for (i, typ) in result_types.iter().enumerate() { + for i in 0..result_types.len() { let variant_item = &format!("variant{tmp}_{i}"); - let typ = resolve_wasm_type(typ); quote_in! { self.body => $['\r'] - var $variant_item $typ + var $variant_item uint64 } results.push(Operand::SingleValue(variant_item.into())); } @@ -1073,7 +910,7 @@ impl Bindgen for Func<'_> { quote_in! { self.body => $['\r'] - var $enum_tmp uint32 + var $enum_tmp uint64 switch $value { $cases default: diff --git a/cmd/gravity/src/codegen/imported_func.rs b/cmd/gravity/src/codegen/imported_func.rs new file mode 100644 index 0000000..160891f --- /dev/null +++ b/cmd/gravity/src/codegen/imported_func.rs @@ -0,0 +1,983 @@ +use std::mem; + +use genco::prelude::*; +use wit_bindgen_core::{ + abi::{Bindgen, Instruction}, + wit_parser::{Alignment, ArchitectureSize, Resolve, Result_, SizeAlign, Type}, +}; + +use crate::{ + go::{ + GoIdentifier, GoType, Operand, comment, + imports::{ + ERRORS_NEW, REFLECT_VALUE_OF, WAZERO_API_DECODE_F32, WAZERO_API_DECODE_F64, + WAZERO_API_DECODE_I32, WAZERO_API_DECODE_U32, WAZERO_API_ENCODE_F32, + WAZERO_API_ENCODE_F64, WAZERO_API_ENCODE_I32, WAZERO_API_ENCODE_U32, + }, + }, + resolve_type, +}; + +pub struct ImportedFunc<'a> { + param_name: &'a GoIdentifier, + tmp: usize, + body: Tokens, + block_storage: Vec>, + blocks: Vec<(Tokens, Vec)>, + sizes: &'a SizeAlign, +} + +impl<'a> ImportedFunc<'a> { + /// Create a new imported function. + pub fn new(param_name: &'a GoIdentifier, sizes: &'a SizeAlign) -> Self { + Self { + param_name, + tmp: 0, + body: Tokens::new(), + block_storage: Vec::new(), + blocks: Vec::new(), + sizes, + } + } + + fn tmp(&mut self) -> usize { + let ret = self.tmp; + self.tmp += 1; + ret + } + + pub fn body(&self) -> &Tokens { + &self.body + } + + fn pop_block(&mut self) -> (Tokens, Vec) { + self.blocks.pop().expect("should have block to pop") + } +} + +impl Bindgen for ImportedFunc<'_> { + type Operand = Operand; + + fn emit( + &mut self, + resolve: &Resolve, + inst: &Instruction<'_>, + operands: &mut Vec, + results: &mut Vec, + ) { + let iter_element = "e"; + let iter_base = "base"; + + match inst { + Instruction::GetArg { nth } => { + let arg = &format!("arg{nth}"); + let idx = *nth; + quote_in! { self.body => + $['\r'] + $arg := stack[$idx] + }; + results.push(Operand::SingleValue(arg.into())); + } + Instruction::ConstZero { tys } => { + for _ in tys.iter() { + results.push(Operand::Literal("0".into())) + } + } + Instruction::StringLower { realloc: None } => todo!("implement instruction: {inst:?}"), + Instruction::StringLower { + realloc: Some(realloc_name), + } => { + let tmp = self.tmp(); + let ptr = &format!("ptr{tmp}"); + let len = &format!("len{tmp}"); + let err = &format!("err{tmp}"); + let memory = &format!("memory{tmp}"); + let realloc = &format!("realloc{tmp}"); + let operand = &operands[0]; + + quote_in! { self.body => + $['\r'] + $memory := mod.Memory() + $realloc := mod.ExportedFunction($(quoted(*realloc_name))) + $ptr, $len, $err := writeString(ctx, $operand, $memory, $realloc) + if $err != nil { + panic($err) + } + }; + results.push(Operand::SingleValue(ptr.into())); + results.push(Operand::SingleValue(len.into())); + } + Instruction::CallWasm { name, .. } => { + let tmp = self.tmp(); + let err = &format!("err{tmp}"); + // TODO(#17): Wrapping every argument in `uint64` is bad and we should instead be looking + // at the types and converting with proper guards in place + quote_in! { self.body => + $['\r'] + _, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if $err != nil { + panic($err) + } + }; + } + Instruction::I32Load8U { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadByte(uint32($operand + $offset)) + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read byte from memory")) + } + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::I32FromBool => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + var $(&value) uint64 + if $operand { + $(&value) = 1 + } else { + $(&value) = 0 + } + } + results.push(Operand::SingleValue(value)) + } + Instruction::BoolFromI32 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := $operand != 0 + } + results.push(Operand::SingleValue(value)) + } + Instruction::I32FromU32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_ENCODE_U32($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::U32FromI32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_DECODE_U32($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::PointerLoad { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let ptr = &format!("ptr{tmp}"); + let ok = &format!("ok{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $ptr, $ok := i.module.Memory().ReadUint32Le(uint32($operand + $offset)) + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read pointer from memory")) + } + }; + results.push(Operand::SingleValue(ptr.into())); + } + Instruction::LengthLoad { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let len = &format!("len{tmp}"); + let ok = &format!("ok{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $len, $ok := i.module.Memory().ReadUint32Le(uint32($operand + $offset)) + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read length from memory")) + } + }; + results.push(Operand::SingleValue(len.into())); + } + Instruction::I32Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint32Le(uint32($operand + $offset)) + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read i32 from memory")) + } + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::StringLift => { + let tmp = self.tmp(); + let buf = &format!("buf{tmp}"); + let ok = &format!("ok{tmp}"); + let str = &format!("str{tmp}"); + let ptr = &format!("ptr{tmp}"); + let len = &format!("len{tmp}"); + let ptr_op = &operands[0]; + let len_op = &operands[1]; + + quote_in! { self.body => + $['\r'] + $ptr := $WAZERO_API_DECODE_U32($ptr_op) + $len := $WAZERO_API_DECODE_U32($len_op) + $buf, $ok := mod.Memory().Read($ptr, $len) + if !$ok { + panic($ERRORS_NEW("failed to read bytes from memory")) + } + $str := string($buf) + }; + results.push(Operand::SingleValue(str.into())); + } + Instruction::ResultLift { + result: + Result_ { + ok: Some(typ), + err: Some(Type::String), + }, + .. + } => { + let (err_block, err_results) = self.pop_block(); + assert_eq!(err_results.len(), 1); + let err_op = &err_results[0]; + + let (ok_block, ok_results) = self.pop_block(); + assert_eq!(ok_results.len(), 1); + let ok_op = &ok_results[0]; + + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let err = &format!("err{tmp}"); + let typ = resolve_type(typ, resolve); + let tag = &operands[0]; + quote_in! { self.body => + $['\r'] + var $value $typ + var $err error + switch $tag { + case 0: + $ok_block + $value = $ok_op + case 1: + $err_block + $err = $ERRORS_NEW($err_op) + default: + $err = $ERRORS_NEW("invalid variant discriminant for expected") + } + }; + + results.push(Operand::MultiValue((value.into(), err.into()))); + } + Instruction::ResultLift { + result: + Result_ { + ok: None, + err: Some(Type::String), + }, + .. + } => { + let (err_block, err_results) = self.pop_block(); + assert_eq!(err_results.len(), 1); + let err_op = &err_results[0]; + + let (ok_block, ok_results) = self.pop_block(); + assert_eq!(ok_results.len(), 0); + + let tmp = self.tmp(); + let err = &format!("err{tmp}"); + let tag = &operands[0]; + quote_in! { self.body => + $['\r'] + var $err error + switch $tag { + case 0: + $ok_block + case 1: + $err_block + $err = $ERRORS_NEW($err_op) + default: + $err = $ERRORS_NEW("invalid variant discriminant for expected") + } + }; + + results.push(Operand::SingleValue(err.into())); + } + Instruction::ResultLift { .. } => todo!("implement instruction: {inst:?}"), + Instruction::Return { amt, .. } => { + for idx in 0..*amt { + let operand = &operands[idx]; + quote_in! { self.body => + $['\r'] + stack[$idx] = $operand + }; + } + } + Instruction::CallInterface { func, .. } => { + let ident = GoIdentifier::public(&func.name); + let tmp = self.tmp(); + let args = quote!($(for op in operands.iter() join (, ) => $op)); + let returns = match &func.result { + None => GoType::Nothing, + Some(typ) => resolve_type(typ, resolve), + }; + let value = &format!("value{tmp}"); + let err = &format!("err{tmp}"); + let ok = &format!("ok{tmp}"); + let param_name = self.param_name; + + quote_in! { self.body => + $['\r'] + $(match returns { + GoType::Nothing => $param_name.$ident(ctx, $args), + GoType::Bool | GoType::Uint32 | GoType::Interface | GoType::String | GoType::UserDefined(_) => $value := $param_name.$ident(ctx, $args), + GoType::Error => $err := $param_name.$ident(ctx, $args), + GoType::ValueOrError(_) => { + $value, $err := $param_name.$ident(ctx, $args) + } + GoType::ValueOrOk(_) => { + $value, $ok := $param_name.$ident(ctx, $args) + } + _ => $(comment(&["TODO(#9): handle return type"])) + }) + } + match returns { + GoType::Nothing => (), + GoType::Bool + | GoType::Uint32 + | GoType::Interface + | GoType::UserDefined(_) + | GoType::String => { + results.push(Operand::SingleValue(value.into())); + } + GoType::Error => { + results.push(Operand::SingleValue(err.into())); + } + GoType::ValueOrError(_) => { + results.push(Operand::MultiValue((value.into(), err.into()))); + } + GoType::ValueOrOk(_) => { + results.push(Operand::MultiValue((value.into(), ok.into()))) + } + _ => todo!("TODO(#9): handle return type - {returns:?}"), + } + } + Instruction::VariantPayloadName => { + results.push(Operand::SingleValue("variantPayload".into())); + } + Instruction::I32Const { val } => results.push(Operand::Literal(val.to_string())), + Instruction::I32Store8 { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let tag = &operands[0]; + let ptr = &format!("ptr{tmp}"); + let ptr_op = &operands[1]; + if let Operand::Literal(byte) = tag { + quote_in! { self.body => + $['\r'] + $ptr := $WAZERO_API_DECODE_U32($ptr_op) + mod.Memory().WriteByte($ptr+$offset, $byte) + } + } else { + let tmp = self.tmp(); + let byte = &format!("byte{tmp}"); + + quote_in! { self.body => + $['\r'] + var $byte uint8 + switch $tag { + case 0: + $byte = 0 + case 1: + $byte = 1 + default: + panic($ERRORS_NEW("invalid int8 value encountered")) + } + $ptr := $WAZERO_API_DECODE_U32($ptr_op) + mod.Memory().WriteByte($ptr+$offset, $byte) + } + } + } + Instruction::I32Store { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let tag = &operands[0]; + let ptr = &format!("ptr{tmp}"); + let ptr_op = &operands[1]; + + quote_in! { self.body => + $['\r'] + $ptr := $WAZERO_API_DECODE_U32($ptr_op) + mod.Memory().WriteUint32Le($ptr+$offset, uint32($tag)) + } + } + Instruction::LengthStore { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let len = &operands[0]; + let ptr = &format!("ptr{tmp}"); + let ptr_op = &operands[1]; + quote_in! { self.body => + $['\r'] + $ptr := $WAZERO_API_DECODE_U32($ptr_op) + mod.Memory().WriteUint32Le($ptr+$offset, uint32($len)) + } + } + Instruction::PointerStore { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &operands[0]; + let ptr = &format!("ptr{tmp}"); + let ptr_op = &operands[1]; + quote_in! { self.body => + $['\r'] + $ptr := $WAZERO_API_DECODE_U32($ptr_op) + mod.Memory().WriteUint32Le($ptr+$offset, uint32($value)) + } + } + Instruction::ResultLower { + result: + Result_ { + ok: Some(_), + err: Some(Type::String), + }, + .. + } => { + let (err_block, _) = self.pop_block(); + let (ok_block, _) = self.pop_block(); + let operand = &operands[0]; + let (ok, err) = match operand { + Operand::Literal(_) => { + panic!("impossible: expected Operand::MultiValue but got Operand::Literal") + } + Operand::SingleValue(_) => panic!( + "impossible: expected Operand::MultiValue but got Operand::SingleValue" + ), + Operand::MultiValue(bindings) => bindings, + }; + quote_in! { self.body => + $['\r'] + if $err != nil { + variantPayload := $err.Error() + $err_block + } else { + variantPayload := $ok + $ok_block + } + }; + } + Instruction::ResultLower { + result: + Result_ { + ok: None, + err: Some(Type::String), + }, + .. + } => { + let (err, _) = self.pop_block(); + let (ok, _) = self.pop_block(); + let err_result = &operands[0]; + quote_in! { self.body => + $['\r'] + if $err_result != nil { + variantPayload := $err_result.Error() + $err + } else { + $ok + } + }; + } + Instruction::ResultLower { .. } => todo!("implement instruction: {inst:?}"), + Instruction::OptionLift { payload, .. } => { + let (some, some_results) = self.blocks.pop().unwrap(); + let (none, _) = self.blocks.pop().unwrap(); + let some_result = &some_results[0]; + + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let ok = &format!("ok{tmp}"); + let typ = resolve_type(payload, resolve); + let op = &operands[0]; + + quote_in! { self.body => + $['\r'] + var $result $typ + var $ok bool + if $op == 0 { + $none + $ok = false + } else { + $some + $ok = true + $result = $some_result + } + }; + + results.push(Operand::MultiValue((result.into(), ok.into()))); + } + Instruction::OptionLower { + results: result_types, + .. + } => { + let (mut some_block, some_results) = self.pop_block(); + let (mut none_block, none_results) = self.pop_block(); + + let tmp = self.tmp(); + + // If there are no result_types, then the payload will be a pointer, + // because that's how we represent optionals in Go. + let is_pointer = result_types.is_empty(); + + let mut vars: Tokens = Tokens::new(); + for i in 0..result_types.len() { + let variant = &format!("variant{tmp}_{i}"); + results.push(Operand::SingleValue(variant.into())); + + quote_in! { vars => + $['\r'] + var $variant uint64 + } + + let some_result = &some_results[i]; + let none_result = &none_results[i]; + quote_in! { some_block => + $['\r'] + $variant = $some_result + }; + quote_in! { none_block => + $['\r'] + $variant = $none_result + }; + } + + let operand = &operands[0]; + match operand { + Operand::Literal(_) => { + panic!("impossible: expected Operand::MultiValue but got Operand::Literal") + } + Operand::SingleValue(value) => { + quote_in! { self.body => + $['\r'] + $vars + if $REFLECT_VALUE_OF($value).IsZero() { + $none_block + } else { + variantPayload := $(if is_pointer => *)$value + $some_block + } + }; + } + Operand::MultiValue((value, ok)) => { + quote_in! { self.body => + $['\r'] + if $ok { + variantPayload := $value + $some_block + } else { + $none_block + } + }; + } + }; + } + Instruction::RecordLower { record, .. } => { + let tmp = self.tmp(); + let operand = &operands[0]; + for field in record.fields.iter() { + let struct_field = GoIdentifier::public(&field.name); + let var = &GoIdentifier::local(format!("{}{tmp}", &field.name)); + quote_in! { self.body => + $['\r'] + $var := $operand.$struct_field + } + results.push(Operand::SingleValue(var.into())) + } + } + Instruction::RecordLift { record, name, .. } => { + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let fields = record + .fields + .iter() + .zip(operands) + .map(|(field, op)| (GoIdentifier::public(&field.name), op)); + + quote_in! {self.body => + $['\r'] + $value := $(GoIdentifier::public(*name)){ + $(for (name, op) in fields join ($['\r']) => $name: $op,) + } + }; + results.push(Operand::SingleValue(value.into())) + } + Instruction::IterElem { .. } => results.push(Operand::SingleValue(iter_element.into())), + Instruction::IterBasePointer => results.push(Operand::SingleValue(iter_base.into())), + Instruction::ListLower { realloc: None, .. } => { + todo!("implement instruction: {inst:?}") + } + Instruction::ListLower { + element, + realloc: Some(realloc_name), + } => { + let (body, _) = self.pop_block(); + let tmp = self.tmp(); + let vec = &format!("vec{tmp}"); + let result = &format!("result{tmp}"); + let err = &format!("err{tmp}"); + let ptr = &format!("ptr{tmp}"); + let len = &format!("len{tmp}"); + let operand = &operands[0]; + let size = self.sizes.size(element).size_wasm32(); + let align = self.sizes.align(element).align_wasm32(); + + quote_in! { self.body => + $['\r'] + $vec := $operand + $len := uint64(len($vec)) + $result, $err := i.module.ExportedFunction($(quoted(*realloc_name))).Call(ctx, 0, 0, $align, $len * $size) + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if $err != nil { + panic($err) + } + $ptr := $result[0] + for idx := uint64(0); idx < $len; idx++ { + $iter_element := $vec[idx] + $iter_base := uint32($ptr + uint64(idx) * uint64($size)) + $body + } + }; + results.push(Operand::SingleValue(ptr.into())); + results.push(Operand::SingleValue(len.into())); + } + Instruction::ListLift { element, .. } => { + let (body, body_results) = self.pop_block(); + let tmp = self.tmp(); + let size = self.sizes.size(element).size_wasm32(); + let len = &format!("len{tmp}"); + let base = &format!("base{tmp}"); + let result = &format!("result{tmp}"); + let idx = &format!("idx{tmp}"); + + let base_operand = &operands[0]; + let len_operand = &operands[1]; + let body_result = &body_results[0]; + + let typ = resolve_type(element, resolve); + + quote_in! { self.body => + $['\r'] + $base := $base_operand + $len := $len_operand + $result := make([]$typ, $len) + for $idx := uint32(0); $idx < $len; $idx++ { + base := $base + $idx * $size + $body + $result[$idx] = $body_result + } + } + results.push(Operand::SingleValue(result.into())); + } + Instruction::VariantLower { + variant, + results: result_types, + .. + } => { + let blocks = self + .blocks + .drain(self.blocks.len() - variant.cases.len()..) + .collect::>(); + let tmp = self.tmp(); + let value = &operands[0]; + + for i in 0..result_types.len() { + let variant_item = &format!("variant{tmp}_{i}"); + quote_in! { self.body => + $['\r'] + var $variant_item uint64 + } + results.push(Operand::SingleValue(variant_item.into())); + } + + let mut cases: Tokens = Tokens::new(); + for (case, (block, block_results)) in variant.cases.iter().zip(blocks) { + let mut assignments: Tokens = Tokens::new(); + for (i, result) in block_results.iter().enumerate() { + let variant_item = &format!("variant{tmp}_{i}"); + quote_in! { assignments => + $['\r'] + $variant_item = $result + }; + } + + let name = GoIdentifier::public(case.name.clone()); + quote_in! { cases => + $['\r'] + case $name: + $block + $assignments + } + } + + quote_in! { self.body => + $['\r'] + switch variantPayload := $value.(type) { + $cases + default: + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + panic($ERRORS_NEW("invalid variant type provided")) + } + } + } + Instruction::EnumLower { enum_, .. } => { + let value = &operands[0]; + let tmp = self.tmp(); + let enum_tmp = &format!("enum{tmp}"); + + let mut cases: Tokens = Tokens::new(); + for (i, case) in enum_.cases.iter().enumerate() { + let case_name = GoIdentifier::public(case.name.clone()); + quote_in! { cases => + $['\r'] + case $case_name: + $enum_tmp = $i + }; + } + + quote_in! { self.body => + $['\r'] + var $enum_tmp uint64 + switch $value { + $cases + default: + panic($ERRORS_NEW("invalid enum type provided")) + } + }; + + results.push(Operand::SingleValue(enum_tmp.to_string())); + } + Instruction::Bitcasts { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I32Load8S { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I32Load16U { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I32Load16S { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I64Load { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F32Load { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F64Load { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I32Store16 { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I64Store { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F32Store { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F64Store { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I32FromChar => todo!("implement instruction: {inst:?}"), + Instruction::I64FromU64 => todo!("implement instruction: {inst:?}"), + Instruction::I64FromS64 => todo!("implement instruction: {inst:?}"), + Instruction::I32FromS32 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := $WAZERO_API_ENCODE_I32($operand) + } + results.push(Operand::SingleValue(value)) + } + // All of these values should fit in Go's `int32` type which allows a safe cast + Instruction::I32FromU16 + | Instruction::I32FromS16 + | Instruction::I32FromU8 + | Instruction::I32FromS8 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := $WAZERO_API_ENCODE_I32(int32($operand)) + } + results.push(Operand::SingleValue(value)) + } + Instruction::CoreF32FromF32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_ENCODE_F32($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::CoreF64FromF64 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_ENCODE_F64($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + // TODO: Validate the Go cast truncates the upper bits in the I32 + Instruction::S8FromI32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := int8($WAZERO_API_DECODE_I32($operand)) + }; + results.push(Operand::SingleValue(result.into())); + } + // TODO: Validate the Go cast truncates the upper bits in the I32 + Instruction::U8FromI32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := uint8($WAZERO_API_DECODE_U32($operand)) + }; + results.push(Operand::SingleValue(result.into())); + } + // TODO: Validate the Go cast truncates the upper bits in the I32 + Instruction::S16FromI32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := int16($WAZERO_API_DECODE_I32($operand)) + }; + results.push(Operand::SingleValue(result.into())); + } + // TODO: Validate the Go cast truncates the upper bits in the I32 + Instruction::U16FromI32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := uint16($WAZERO_API_DECODE_U32($operand)) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::S32FromI32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_DECODE_I32($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::S64FromI64 => todo!("implement instruction: {inst:?}"), + Instruction::U64FromI64 => todo!("implement instruction: {inst:?}"), + Instruction::CharFromI32 => todo!("implement instruction: {inst:?}"), + Instruction::F32FromCoreF32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_DECODE_F32($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::F64FromCoreF64 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_DECODE_F64($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::TupleLower { .. } => todo!("implement instruction: {inst:?}"), + Instruction::TupleLift { .. } => todo!("implement instruction: {inst:?}"), + Instruction::FlagsLower { .. } => todo!("implement instruction: {inst:?}"), + Instruction::FlagsLift { .. } => todo!("implement instruction: {inst:?}"), + Instruction::VariantLift { .. } => { + todo!("implement instruction: {inst:?}") + } + Instruction::EnumLift { .. } => todo!("implement instruction: {inst:?}"), + Instruction::Malloc { .. } => todo!("implement instruction: {inst:?}"), + Instruction::HandleLower { .. } | Instruction::HandleLift { .. } => { + todo!("implement resources: {inst:?}") + } + Instruction::ListCanonLower { .. } | Instruction::ListCanonLift { .. } => { + unimplemented!("gravity doesn't represent lists as Canonical") + } + Instruction::GuestDeallocateString + | Instruction::GuestDeallocate { .. } + | Instruction::GuestDeallocateList { .. } + | Instruction::GuestDeallocateVariant { .. } => { + unimplemented!("gravity doesn't generate the Guest code") + } + Instruction::FutureLower { .. } => todo!("implement instruction: {inst:?}"), + Instruction::FutureLift { .. } => todo!("implement instruction: {inst:?}"), + Instruction::StreamLower { .. } => todo!("implement instruction: {inst:?}"), + Instruction::StreamLift { .. } => todo!("implement instruction: {inst:?}"), + Instruction::ErrorContextLower => todo!("implement instruction: {inst:?}"), + Instruction::ErrorContextLift => todo!("implement instruction: {inst:?}"), + Instruction::AsyncTaskReturn { .. } => todo!("implement instruction: {inst:?}"), + Instruction::DropHandle { .. } => todo!("implement instruction: {inst:?}"), + Instruction::Flush { amt } => { + for op in operands.iter().take(*amt) { + results.push(op.clone()); + } + } + } + } + + fn return_pointer(&mut self, _size: ArchitectureSize, _align: Alignment) -> Self::Operand { + unimplemented!("return_pointer") + } + + fn push_block(&mut self) { + let prev = mem::replace(&mut self.body, Tokens::new()); + self.block_storage.push(prev); + } + + fn finish_block(&mut self, operands: &mut Vec) { + let to_restore = self.block_storage.pop().expect("should have body"); + let src = mem::replace(&mut self.body, to_restore); + self.blocks.push((src, mem::take(operands))); + } + + fn sizes(&self) -> &SizeAlign { + self.sizes + } + + fn is_list_canonical(&self, _resolve: &Resolve, _element: &Type) -> bool { + // Go slices are never directly in the Wasm Memory, so they are never "canonical" + false + } +} diff --git a/cmd/gravity/src/codegen/imports.rs b/cmd/gravity/src/codegen/imports.rs index 8f828c1..14e913d 100644 --- a/cmd/gravity/src/codegen/imports.rs +++ b/cmd/gravity/src/codegen/imports.rs @@ -10,7 +10,7 @@ use wit_bindgen_core::{ use crate::{ codegen::{ - func::Func, + ImportedFunc, ir::{ AnalyzedFunction, AnalyzedImports, AnalyzedInterface, AnalyzedType, InterfaceMethod, Parameter, TypeDefinition, WitReturn, @@ -18,7 +18,10 @@ use crate::{ }, go::{ GoIdentifier, GoResult, GoType, - imports::{CONTEXT_CONTEXT, WAZERO_API_MODULE}, + imports::{ + CONTEXT_CONTEXT, GoImport, WAZERO_API_GO_MODULE_FUNC, WAZERO_API_MODULE, + WAZERO_API_VALUE_TYPE, + }, }, resolve_type, }; @@ -414,21 +417,10 @@ impl<'a> ImportCodeGenerator<'a> { ) -> Tokens { let func_name = &method.name; - // Generate Wasm function parameters based on WIT types. - let wasm_params = vec![ - quote! { ctx $CONTEXT_CONTEXT }, - quote! { mod $WAZERO_API_MODULE }, - ]; - let wasm_sig = self .resolve .wasm_signature(AbiVariant::GuestImport, &method.wit_function); - let result = if wasm_sig.results.is_empty() { - GoResult::Empty - } else { - todo!("implement handling of wasm signatures with results"); - }; - let mut f = Func::import(param_name, result, self.sizes); + let mut f = ImportedFunc::new(param_name, self.sizes); // Magic wit_bindgen_core::abi::call( @@ -441,14 +433,26 @@ impl<'a> ImportCodeGenerator<'a> { false, ); + let wasm_params = wasm_sig.params.iter().map(GoImport::from); + let wasm_results = wasm_sig.results.iter().map(GoImport::from); + quote! { NewFunctionBuilder(). - WithFunc(func( - $(for param in wasm_params join (,$['\r']) => $param), - $(for param in f.args() join (,$['\r']) => $param uint32), - ) $(f.result()){ - $(f.body()) - }). + WithGoModuleFunction( + $WAZERO_API_GO_MODULE_FUNC(func( + ctx $CONTEXT_CONTEXT, + mod $WAZERO_API_MODULE, + stack []uint64, + ) { + $(f.body()) + }), + []$WAZERO_API_VALUE_TYPE{ + $(for typ in wasm_params join ($['\r']) => $typ,) + }, + []$WAZERO_API_VALUE_TYPE{ + $(for typ in wasm_results join ($['\r']) => $typ,) + }, + ). Export($(quoted(func_name))). } } @@ -564,7 +568,7 @@ mod tests { // Should have only one uint32 parameter (plus ctx and mod) let code_str = result.to_string().unwrap(); - assert!(code_str.contains("arg0 uint32")); + assert!(code_str.contains("arg0 := stack[0]")); assert!(!code_str.contains("arg1 uint32")); assert!(!code_str.contains("mod.Memory().Read")); // No string reading diff --git a/cmd/gravity/src/codegen/mod.rs b/cmd/gravity/src/codegen/mod.rs index 3bb7433..04ae959 100644 --- a/cmd/gravity/src/codegen/mod.rs +++ b/cmd/gravity/src/codegen/mod.rs @@ -2,6 +2,7 @@ mod bindings; mod exports; mod factory; mod func; +mod imported_func; mod imports; mod ir; mod wasm; @@ -10,4 +11,5 @@ pub use bindings::*; pub use exports::ExportGenerator; pub use factory::FactoryGenerator; pub use func::Func; +pub use imported_func::ImportedFunc; pub use wasm::WasmData; diff --git a/cmd/gravity/src/go/imports.rs b/cmd/gravity/src/go/imports.rs index 53ea3f2..41e3f1e 100644 --- a/cmd/gravity/src/go/imports.rs +++ b/cmd/gravity/src/go/imports.rs @@ -1,4 +1,5 @@ use genco::{Tokens, lang::Go, tokens::FormatInto}; +use wit_bindgen_core::abi::WasmType; #[derive(Debug, Clone, Copy)] pub struct GoImport(&'static str, &'static str); @@ -9,6 +10,21 @@ impl FormatInto for GoImport { } } +impl From<&WasmType> for GoImport { + fn from(typ: &WasmType) -> Self { + match typ { + WasmType::I32 => WAZERO_API_VALUE_TYPE_I32, + WasmType::I64 => WAZERO_API_VALUE_TYPE_I64, + WasmType::F32 => WAZERO_API_VALUE_TYPE_F32, + WasmType::F64 => WAZERO_API_VALUE_TYPE_F64, + // TODO: Verify that Gravity/Wazero "doesn't do anything special" and can treat these as such + WasmType::Pointer => WAZERO_API_VALUE_TYPE_I32, + WasmType::PointerOrI64 => WAZERO_API_VALUE_TYPE_I64, + WasmType::Length => WAZERO_API_VALUE_TYPE_I32, + } + } +} + pub static CONTEXT_CONTEXT: GoImport = GoImport("context", "Context"); pub static ERRORS_NEW: GoImport = GoImport("errors", "New"); pub static FMT_PRINTF: GoImport = GoImport("fmt", "Printf"); @@ -36,4 +52,16 @@ pub static WAZERO_API_ENCODE_F64: GoImport = GoImport("github.com/tetratelabs/wazero/api", "EncodeF64"); pub static WAZERO_API_DECODE_F64: GoImport = GoImport("github.com/tetratelabs/wazero/api", "DecodeF64"); +pub static WAZERO_API_VALUE_TYPE: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "ValueType"); +pub static WAZERO_API_VALUE_TYPE_I32: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "ValueTypeI32"); +pub static WAZERO_API_VALUE_TYPE_I64: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "ValueTypeI64"); +pub static WAZERO_API_VALUE_TYPE_F32: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "ValueTypeF32"); +pub static WAZERO_API_VALUE_TYPE_F64: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "ValueTypeF64"); +pub static WAZERO_API_GO_MODULE_FUNC: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "GoModuleFunc"); pub static REFLECT_VALUE_OF: GoImport = GoImport("reflect", "ValueOf"); diff --git a/cmd/gravity/src/lib.rs b/cmd/gravity/src/lib.rs index 0b6cc2b..0866592 100644 --- a/cmd/gravity/src/lib.rs +++ b/cmd/gravity/src/lib.rs @@ -2,27 +2,11 @@ pub mod codegen; pub mod go; use crate::go::GoType; -use wit_bindgen_core::{ - abi::WasmType, - wit_parser::{Resolve, Result_, Type, TypeDef, TypeDefKind}, -}; +use wit_bindgen_core::wit_parser::{Resolve, Result_, Type, TypeDef, TypeDefKind}; // Temporary re-export while we migrate. pub use codegen::Func; -/// Resolves a Wasm type to a Go type. -pub fn resolve_wasm_type(typ: &WasmType) -> GoType { - match typ { - WasmType::I32 => GoType::Uint32, - WasmType::I64 => GoType::Uint64, - WasmType::F32 => GoType::Float32, - WasmType::F64 => GoType::Float64, - WasmType::Pointer => GoType::Uint64, - WasmType::PointerOrI64 => GoType::Uint64, - WasmType::Length => GoType::Uint64, - } -} - /// Resolves a WIT type to a Go type. /// /// # Panics diff --git a/cmd/gravity/tests/cmd/basic.stdout b/cmd/gravity/tests/cmd/basic.stdout index 5fd7413..01f3c3a 100644 --- a/cmd/gravity/tests/cmd/basic.stdout +++ b/cmd/gravity/tests/cmd/basic.stdout @@ -58,64 +58,104 @@ func NewBasicFactory( _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:basic/logger"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - arg1 uint32, - ) { - buf0, ok0 := mod.Memory().Read(arg0, arg1) - if !ok0 { - panic(errors.New("failed to read bytes from memory")) - } - str0 := string(buf0) - logger.Debug(ctx, str0) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + arg0 := stack[0] + arg1 := stack[1] + ptr0 := api.DecodeU32(arg0) + len0 := api.DecodeU32(arg1) + buf0, ok0 := mod.Memory().Read(ptr0, len0) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + logger.Debug(ctx, str0) + }), + []api.ValueType{ + api.ValueTypeI32, + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("debug"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - arg1 uint32, - ) { - buf0, ok0 := mod.Memory().Read(arg0, arg1) - if !ok0 { - panic(errors.New("failed to read bytes from memory")) - } - str0 := string(buf0) - logger.Info(ctx, str0) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + arg0 := stack[0] + arg1 := stack[1] + ptr0 := api.DecodeU32(arg0) + len0 := api.DecodeU32(arg1) + buf0, ok0 := mod.Memory().Read(ptr0, len0) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + logger.Info(ctx, str0) + }), + []api.ValueType{ + api.ValueTypeI32, + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("info"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - arg1 uint32, - ) { - buf0, ok0 := mod.Memory().Read(arg0, arg1) - if !ok0 { - panic(errors.New("failed to read bytes from memory")) - } - str0 := string(buf0) - logger.Warn(ctx, str0) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + arg0 := stack[0] + arg1 := stack[1] + ptr0 := api.DecodeU32(arg0) + len0 := api.DecodeU32(arg1) + buf0, ok0 := mod.Memory().Read(ptr0, len0) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + logger.Warn(ctx, str0) + }), + []api.ValueType{ + api.ValueTypeI32, + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("warn"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - arg1 uint32, - ) { - buf0, ok0 := mod.Memory().Read(arg0, arg1) - if !ok0 { - panic(errors.New("failed to read bytes from memory")) - } - str0 := string(buf0) - logger.Error(ctx, str0) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + arg0 := stack[0] + arg1 := stack[1] + ptr0 := api.DecodeU32(arg0) + len0 := api.DecodeU32(arg1) + buf0, ok0 := mod.Memory().Read(ptr0, len0) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + logger.Error(ctx, str0) + }), + []api.ValueType{ + api.ValueTypeI32, + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("error"). Instantiate(ctx) if err0 != nil { @@ -123,28 +163,41 @@ func NewBasicFactory( } _, err1 := wazeroRuntime.NewHostModuleBuilder("arcjet:basic/utils"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - arg1 uint32, - arg2 uint32, - ) { - buf0, ok0 := mod.Memory().Read(arg0, arg1) - if !ok0 { - panic(errors.New("failed to read bytes from memory")) - } - str0 := string(buf0) - value1 := utils.Uppercase(ctx, str0) - memory2 := mod.Memory() - realloc2 := mod.ExportedFunction("cabi_realloc") - ptr2, len2, err2 := writeString(ctx, value1, memory2, realloc2) - if err2 != nil { - panic(err2) - } - mod.Memory().WriteUint32Le(arg2+4, uint32(len2)) - mod.Memory().WriteUint32Le(arg2+0, uint32(ptr2)) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + arg0 := stack[0] + arg1 := stack[1] + ptr0 := api.DecodeU32(arg0) + len0 := api.DecodeU32(arg1) + buf0, ok0 := mod.Memory().Read(ptr0, len0) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + value1 := utils.Uppercase(ctx, str0) + arg2 := stack[2] + memory2 := mod.Memory() + realloc2 := mod.ExportedFunction("cabi_realloc") + ptr2, len2, err2 := writeString(ctx, value1, memory2, realloc2) + if err2 != nil { + panic(err2) + } + ptr3 := api.DecodeU32(arg2) + mod.Memory().WriteUint32Le(ptr3+4, uint32(len2)) + ptr4 := api.DecodeU32(arg2) + mod.Memory().WriteUint32Le(ptr4+0, uint32(ptr2)) + }), + []api.ValueType{ + api.ValueTypeI32, + api.ValueTypeI32, + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("uppercase"). Instantiate(ctx) if err1 != nil { @@ -304,14 +357,14 @@ func (i *BasicInstance) OptionalPrimitive( b bool, ) (bool, bool) { arg0 := b - var variant1_0 uint32 - var variant1_1 uint32 + var variant1_0 uint64 + var variant1_1 uint64 if reflect.ValueOf(arg0).IsZero() { variant1_0 = 0 variant1_1 = 0 } else { variantPayload := arg0 - var value0 uint32 + var value0 uint64 if variantPayload { value0 = 1 } else { @@ -418,7 +471,7 @@ func (i *BasicInstance) OptionalString( s string, ) (string, bool) { arg0 := s - var variant1_0 uint32 + var variant1_0 uint64 var variant1_1 uint64 var variant1_2 uint64 if reflect.ValueOf(arg0).IsZero() { diff --git a/cmd/gravity/tests/cmd/iface-method-returns-string.stderr b/cmd/gravity/tests/cmd/iface-method-returns.stderr similarity index 100% rename from cmd/gravity/tests/cmd/iface-method-returns-string.stderr rename to cmd/gravity/tests/cmd/iface-method-returns.stderr diff --git a/cmd/gravity/tests/cmd/iface-method-returns-string.stdout b/cmd/gravity/tests/cmd/iface-method-returns.stdout similarity index 65% rename from cmd/gravity/tests/cmd/iface-method-returns-string.stdout rename to cmd/gravity/tests/cmd/iface-method-returns.stdout index f0c8a04..e72dd61 100644 --- a/cmd/gravity/tests/cmd/iface-method-returns-string.stdout +++ b/cmd/gravity/tests/cmd/iface-method-returns.stdout @@ -19,6 +19,9 @@ type IExampleRuntime interface { Arch( ctx context.Context, ) string + GetU32( + ctx context.Context, + ) uint32 Puts( ctx context.Context, msg string, @@ -38,53 +41,98 @@ func NewExampleFactory( _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:example/runtime"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - ) { - value0 := runtime.Os(ctx, ) - memory1 := mod.Memory() - realloc1 := mod.ExportedFunction("cabi_realloc") - ptr1, len1, err1 := writeString(ctx, value0, memory1, realloc1) - if err1 != nil { - panic(err1) - } - mod.Memory().WriteUint32Le(arg0+4, uint32(len1)) - mod.Memory().WriteUint32Le(arg0+0, uint32(ptr1)) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + value0 := runtime.Os(ctx, ) + arg0 := stack[0] + memory1 := mod.Memory() + realloc1 := mod.ExportedFunction("cabi_realloc") + ptr1, len1, err1 := writeString(ctx, value0, memory1, realloc1) + if err1 != nil { + panic(err1) + } + ptr2 := api.DecodeU32(arg0) + mod.Memory().WriteUint32Le(ptr2+4, uint32(len1)) + ptr3 := api.DecodeU32(arg0) + mod.Memory().WriteUint32Le(ptr3+0, uint32(ptr1)) + }), + []api.ValueType{ + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("os"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - ) { - value0 := runtime.Arch(ctx, ) - memory1 := mod.Memory() - realloc1 := mod.ExportedFunction("cabi_realloc") - ptr1, len1, err1 := writeString(ctx, value0, memory1, realloc1) - if err1 != nil { - panic(err1) - } - mod.Memory().WriteUint32Le(arg0+4, uint32(len1)) - mod.Memory().WriteUint32Le(arg0+0, uint32(ptr1)) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + value0 := runtime.Arch(ctx, ) + arg0 := stack[0] + memory1 := mod.Memory() + realloc1 := mod.ExportedFunction("cabi_realloc") + ptr1, len1, err1 := writeString(ctx, value0, memory1, realloc1) + if err1 != nil { + panic(err1) + } + ptr2 := api.DecodeU32(arg0) + mod.Memory().WriteUint32Le(ptr2+4, uint32(len1)) + ptr3 := api.DecodeU32(arg0) + mod.Memory().WriteUint32Le(ptr3+0, uint32(ptr1)) + }), + []api.ValueType{ + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("arch"). NewFunctionBuilder(). - WithFunc(func( - ctx context.Context, - mod api.Module, - arg0 uint32, - arg1 uint32, - ) { - buf0, ok0 := mod.Memory().Read(arg0, arg1) - if !ok0 { - panic(errors.New("failed to read bytes from memory")) - } - str0 := string(buf0) - runtime.Puts(ctx, str0) - }). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + value0 := runtime.GetU32(ctx, ) + result1 := api.EncodeU32(value0) + stack[0] = result1 + }), + []api.ValueType{}, + []api.ValueType{ + api.ValueTypeI32, + }, + ). + Export("get-u32"). + NewFunctionBuilder(). + WithGoModuleFunction( + api.GoModuleFunc(func( + ctx context.Context, + mod api.Module, + stack []uint64, + ) { + arg0 := stack[0] + arg1 := stack[1] + ptr0 := api.DecodeU32(arg0) + len0 := api.DecodeU32(arg1) + buf0, ok0 := mod.Memory().Read(ptr0, len0) + if !ok0 { + panic(errors.New("failed to read bytes from memory")) + } + str0 := string(buf0) + runtime.Puts(ctx, str0) + }), + []api.ValueType{ + api.ValueTypeI32, + api.ValueTypeI32, + }, + []api.ValueType{}, + ). Export("puts"). Instantiate(ctx) if err0 != nil { @@ -225,3 +273,17 @@ func (i *ExampleInstance) Hello( return value8, err8 } +func (i *ExampleInstance) CallGetU32( + ctx context.Context, +) uint32 { + raw0, err0 := i.module.ExportedFunction("call-get-u32").Call(ctx, ) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + + results0 := raw0[0] + result1 := api.DecodeU32(results0) + return result1 +} + diff --git a/cmd/gravity/tests/cmd/iface-method-returns-string.toml b/cmd/gravity/tests/cmd/iface-method-returns.toml similarity index 64% rename from cmd/gravity/tests/cmd/iface-method-returns-string.toml rename to cmd/gravity/tests/cmd/iface-method-returns.toml index 440e6fc..498ac76 100644 --- a/cmd/gravity/tests/cmd/iface-method-returns-string.toml +++ b/cmd/gravity/tests/cmd/iface-method-returns.toml @@ -1,2 +1,2 @@ bin.name = "gravity" -args = "--world example ../../target/wasm32-unknown-unknown/release/example_iface_method_returns_string.wasm" +args = "--world example ../../target/wasm32-unknown-unknown/release/example_iface_method_returns.wasm" diff --git a/cmd/gravity/tests/cmd/instructions.stdout b/cmd/gravity/tests/cmd/instructions.stdout index 9dbf974..a004079 100644 --- a/cmd/gravity/tests/cmd/instructions.stdout +++ b/cmd/gravity/tests/cmd/instructions.stdout @@ -237,7 +237,7 @@ func (i *InstructionsInstance) EnumInput( val EnumValues, ) { arg0 := val - var enum0 uint32 + var enum0 uint64 switch arg0 { case One: enum0 = 0 diff --git a/examples/generate.go b/examples/generate.go index 47058c8..d249d09 100644 --- a/examples/generate.go +++ b/examples/generate.go @@ -1,9 +1,9 @@ package examples //go:generate cargo build -p example-basic --target wasm32-unknown-unknown --release -//go:generate cargo build -p example-iface-method-returns-string --target wasm32-unknown-unknown --release +//go:generate cargo build -p example-iface-method-returns --target wasm32-unknown-unknown --release //go:generate cargo build -p example-instructions --target wasm32-unknown-unknown --release //go:generate cargo run --bin gravity -- --world basic --output ./basic/basic.go ../target/wasm32-unknown-unknown/release/example_basic.wasm -//go:generate cargo run --bin gravity -- --world example --output ./iface-method-returns-string/example.go ../target/wasm32-unknown-unknown/release/example_iface_method_returns_string.wasm +//go:generate cargo run --bin gravity -- --world example --output ./iface-method-returns/example.go ../target/wasm32-unknown-unknown/release/example_iface_method_returns.wasm //go:generate cargo run --bin gravity -- --world instructions --output ./instructions/bindings.go ../target/wasm32-unknown-unknown/release/example_instructions.wasm diff --git a/examples/iface-method-returns-string/Cargo.toml b/examples/iface-method-returns/Cargo.toml similarity index 75% rename from examples/iface-method-returns-string/Cargo.toml rename to examples/iface-method-returns/Cargo.toml index c53cf72..2c64e31 100644 --- a/examples/iface-method-returns-string/Cargo.toml +++ b/examples/iface-method-returns/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "example-iface-method-returns-string" +name = "example-iface-method-returns" version = "0.0.2" edition = "2024" diff --git a/examples/iface-method-returns-string/example_test.go b/examples/iface-method-returns/example_test.go similarity index 64% rename from examples/iface-method-returns-string/example_test.go rename to examples/iface-method-returns/example_test.go index 6e6a337..cc82b57 100644 --- a/examples/iface-method-returns-string/example_test.go +++ b/examples/iface-method-returns/example_test.go @@ -13,9 +13,10 @@ type Runtime struct { func (Runtime) Os(context.Context) string { return runtime.GOOS } func (Runtime) Arch(context.Context) string { return runtime.GOARCH } +func (Runtime) GetU32(context.Context) uint32 { return 42 } func (r *Runtime) Puts(_ context.Context, msg string) { r.msg = msg } -func TestBasic(t *testing.T) { +func TestHello(t *testing.T) { r := &Runtime{} fac, err := NewExampleFactory(t.Context(), r) if err != nil { @@ -44,3 +45,25 @@ func TestBasic(t *testing.T) { t.Errorf("wanted: %s, but got: %s", wantPutsMsg, r.msg) } } + +func TestCallGetU32(t *testing.T) { + r := &Runtime{} + fac, err := NewExampleFactory(t.Context(), r) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + value := ins.CallGetU32(t.Context()) + + var want uint32 = 42 + if value != want { + t.Errorf("wanted: %d, but got: %d", want, value) + } +} diff --git a/examples/iface-method-returns-string/src/lib.rs b/examples/iface-method-returns/src/lib.rs similarity index 83% rename from examples/iface-method-returns-string/src/lib.rs rename to examples/iface-method-returns/src/lib.rs index b50bf52..c7c5cb1 100644 --- a/examples/iface-method-returns-string/src/lib.rs +++ b/examples/iface-method-returns/src/lib.rs @@ -14,4 +14,8 @@ impl Guest for ExampleWorld { Ok("Hello, world!".into()) } + + fn call_get_u32() -> u32 { + runtime::get_u32() + } } diff --git a/examples/iface-method-returns-string/wit/example.wit b/examples/iface-method-returns/wit/example.wit similarity index 76% rename from examples/iface-method-returns-string/wit/example.wit rename to examples/iface-method-returns/wit/example.wit index 5e82f65..7b26735 100644 --- a/examples/iface-method-returns-string/wit/example.wit +++ b/examples/iface-method-returns/wit/example.wit @@ -4,6 +4,8 @@ interface runtime { os: func() -> string; arch: func() -> string; + get-u32: func() -> u32; + puts: func(msg: string); } @@ -11,4 +13,5 @@ world example { import runtime; export hello: func() -> result; + export call-get-u32: func() -> u32; }