diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 75a603594..316b26e6f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -43,6 +43,7 @@ jobs: with: files: | target/release/cairo-native-test + target/release/scarb-native-test target/release/cairo-native-compile target/release/cairo-native-dump target/release/cairo-native-run diff --git a/Cargo.lock b/Cargo.lock index 62abb7359..bfd612747 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -921,6 +921,8 @@ dependencies = [ "pretty_assertions_sorted", "proptest", "rstest", + "scarb-metadata", + "scarb-ui", "sec1", "serde", "serde_json", @@ -977,6 +979,15 @@ dependencies = [ "thiserror-no-std", ] +[[package]] +name = "camino" +version = "1.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0ec6b951b160caa93cc0c7b209e5a3bff7aae9062213451ac99493cd844c239" +dependencies = [ + "serde", +] + [[package]] name = "cast" version = "0.3.0" @@ -1120,6 +1131,19 @@ dependencies = [ "xdg", ] +[[package]] +name = "console" +version = "0.15.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.52.0", +] + [[package]] name = "const-fnv1a-hash" version = "1.1.0" @@ -1502,6 +1526,12 @@ dependencies = [ "log", ] +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + [[package]] name = "entities" version = "1.0.1" @@ -1916,6 +1946,19 @@ dependencies = [ "serde", ] +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + [[package]] name = "indoc" version = "2.0.5" @@ -2393,6 +2436,12 @@ dependencies = [ "libm", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "once_cell" version = "1.19.0" @@ -2644,6 +2693,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + [[package]] name = "powerfmt" version = "0.2.0" @@ -3048,6 +3103,33 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scarb-metadata" +version = "1.11.1" +source = "git+https://github.com/software-mansion/scarb.git?rev=v2.6.3#e6f921dfd238e1d96c9087eabe5161e446753907" +dependencies = [ + "camino", + "semver", + "serde", + "serde_json", + "thiserror", +] + +[[package]] +name = "scarb-ui" +version = "0.1.3" +source = "git+https://github.com/software-mansion/scarb.git?rev=v2.6.3#e6f921dfd238e1d96c9087eabe5161e446753907" +dependencies = [ + "anyhow", + "camino", + "clap", + "console", + "indicatif", + "scarb-metadata", + "serde", + "serde_json", +] + [[package]] name = "schemars" version = "0.8.19" @@ -3098,6 +3180,9 @@ name = "semver" version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61697e0a1c7e512e84a621326239844a24d8207b4669b41bc18b32ea5cbf988b" +dependencies = [ + "serde", +] [[package]] name = "serde" @@ -3739,6 +3824,12 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" +[[package]] +name = "unicode-width" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f5e5f3158ecfd4b8ff6fe086db7c8467a2dfdac97fe420f2b7c4aa97af66d6" + [[package]] name = "unicode-xid" version = "0.2.4" diff --git a/Cargo.toml b/Cargo.toml index 35ca3a0f2..211932cab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,11 +24,16 @@ required-features = ["build-cli"] name = "cairo-native-test" required-features = ["build-cli"] +[[bin]] +name = "scarb-native-test" +required-features = ["scarb"] + [features] default = ["build-cli", "with-runtime"] build-cli = ["dep:clap", "dep:tracing-subscriber", "dep:anyhow", "dep:cairo-lang-test-plugin", "dep:cairo-lang-runner", "dep:colored", "dep:cairo-felt", "dep:keccak", "dep:k256", "dep:p256", "dep:sec1"] +scarb = ["build-cli", "dep:scarb-ui", "dep:scarb-metadata", "dep:serde_json"] with-debug-utils = [] with-runtime = ["dep:cairo-native-runtime"] with-serde = ["dep:serde"] @@ -83,7 +88,10 @@ cairo-felt = { version = "0.9.1", optional = true } keccak = { version = "0.1.3", optional = true } k256 = { version = "0.13.3", optional = true } p256 = { version = "0.13.2", optional = true } +scarb-metadata = { git = "https://github.com/software-mansion/scarb.git", rev = "v2.6.3", optional = true } +scarb-ui = { git = "https://github.com/software-mansion/scarb.git", rev = "v2.6.3", optional = true } sec1 = { version = "0.7.3", optional = true } +serde_json = { version = "1.0.117", optional = true } [dev-dependencies] cairo-felt = "0.9.1" diff --git a/README.md b/README.md index 451632536..549e449c0 100644 --- a/README.md +++ b/README.md @@ -802,6 +802,31 @@ cairo-native-test ./cairo-tests/ This will run all the tests (functions marked with the `#[test]` attribute). +# scarb-native-test cli tool + +This tool mimics the `scarb test` [command](https://github.com/software-mansion/scarb/tree/main/extensions/scarb-cairo-test). +You can download it on our [releases](https://github.com/lambdaclass/cairo_native/releases) page. + +```bash +$ scarb-native-test --help +Compiles all packages from a Scarb project matching `packages_filter` and +runs all functions marked with `#[test]`. Exits with 1 if the compilation +or run fails, otherwise 0. + +Usage: scarb-native-test [OPTIONS] + +Options: + -p, --package Packages to run this command on, can be a concrete package name (`foobar`) or a prefix glob (`foo*`) [env: SCARB_PACKAGES_FILTER=] [default: *] + -w, --workspace Run for all packages in the workspace + -f, --filter Run only tests whose name contain FILTER [default: ] + --include-ignored Run ignored and not ignored tests + --ignored Run only ignored tests + --run-mode Run with JIT or AOT (compiled) [default: jit] [possible values: aot, jit] + -O, --opt-level Optimization level, Valid: 0, 1, 2, 3. Values higher than 3 are considered as 3 [default: 0] + -h, --help Print help + -V, --version Print version +``` + ## Debugging Tips ### Useful environment variables diff --git a/src/bin/cairo-native-compile.rs b/src/bin/cairo-native-compile.rs index 68c7457e5..8a107fde3 100644 --- a/src/bin/cairo-native-compile.rs +++ b/src/bin/cairo-native-compile.rs @@ -1,11 +1,11 @@ use anyhow::Context; +use cairo_lang_compiler::project::check_compiler_path; use cairo_native::{ context::NativeContext, module_to_object, object_to_shared_lib, utils::cairo_to_sierra_with_debug_info, }; use clap::{Parser, ValueEnum}; - -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use tracing_subscriber::{EnvFilter, FmtSubscriber}; #[derive(Clone, Debug, ValueEnum)] @@ -80,18 +80,3 @@ fn main() -> anyhow::Result<()> { Ok(()) } - -pub fn check_compiler_path(single_file: bool, path: &Path) -> anyhow::Result<()> { - if path.is_file() { - if !single_file { - anyhow::bail!("The given path is a file, but --single-file was not supplied."); - } - } else if path.is_dir() { - if single_file { - anyhow::bail!("The given path is a directory, but --single-file was supplied."); - } - } else { - anyhow::bail!("The given path does not exist."); - } - Ok(()) -} diff --git a/src/bin/cairo-native-run.rs b/src/bin/cairo-native-run.rs index 4db7c8e67..e307d5c74 100644 --- a/src/bin/cairo-native-run.rs +++ b/src/bin/cairo-native-run.rs @@ -1,10 +1,13 @@ -use anyhow::{bail, Context}; +mod utils; + +use anyhow::Context; use cairo_lang_compiler::{ - db::RootDatabase, diagnostics::DiagnosticsReporter, project::setup_project, + db::RootDatabase, + diagnostics::DiagnosticsReporter, + project::{check_compiler_path, setup_project}, }; use cairo_lang_diagnostics::ToOption; -use cairo_lang_runner::{short_string::as_cairo_short_string, RunResultValue}; -use cairo_lang_sierra::program::{Function, Program}; +use cairo_lang_runner::short_string::as_cairo_short_string; use cairo_lang_sierra_generator::{ db::SierraGenGroup, replace_ids::{DebugReplacer, SierraIdReplacer}, @@ -13,16 +16,13 @@ use cairo_lang_starknet::contract::get_contracts_info; use cairo_native::{ context::NativeContext, debug_info::{DebugInfo, DebugLocations}, - execution_result::ExecutionResult, executor::{AotNativeExecutor, JitNativeExecutor, NativeExecutor}, metadata::gas::{GasMetadata, MetadataComputationConfig}, - values::JitValue, }; use clap::{Parser, ValueEnum}; -use itertools::Itertools; -use starknet_types_core::felt::Felt; use std::path::{Path, PathBuf}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; +use utils::{find_function, result_to_runresult}; #[derive(Clone, Debug, ValueEnum)] enum RunMode { @@ -157,466 +157,3 @@ fn main() -> anyhow::Result<()> { Ok(()) } - -pub fn check_compiler_path(single_file: bool, path: &Path) -> anyhow::Result<()> { - if path.is_file() { - if !single_file { - anyhow::bail!("The given path is a file, but --single-file was not supplied."); - } - } else if path.is_dir() { - if single_file { - anyhow::bail!("The given path is a directory, but --single-file was supplied."); - } - } else { - anyhow::bail!("The given path does not exist."); - } - Ok(()) -} - -pub fn find_function<'a>( - sierra_program: &'a Program, - name_suffix: &str, -) -> anyhow::Result<&'a Function> { - if let Some(x) = sierra_program.funcs.iter().find(|f| { - if let Some(name) = &f.id.debug_name { - name.ends_with(name_suffix) - } else { - false - } - }) { - Ok(x) - } else { - bail!("function {name_suffix} not found") - } -} - -fn result_to_runresult(result: &ExecutionResult) -> anyhow::Result { - let is_success; - let mut felts: Vec = Vec::new(); - - match &result.return_value { - outer_value @ JitValue::Enum { - tag, - value, - debug_name, - } => { - if debug_name - .as_ref() - .expect("missing debug name") - .starts_with("core::panics::PanicResult::") - { - is_success = *tag == 0; - - if !is_success { - match &**value { - JitValue::Struct { fields, .. } => { - for field in fields { - let felt = jitvalue_to_felt(field); - felts.extend(felt); - } - } - _ => bail!("unsuported return value in cairo-native"), - } - } else { - felts.extend(jitvalue_to_felt(value)); - } - } else { - is_success = true; - felts.extend(jitvalue_to_felt(outer_value)); - } - } - x => { - is_success = true; - felts.extend(jitvalue_to_felt(x)); - } - } - - let return_values = felts - .into_iter() - .map(|x| x.to_bigint().into()) - .collect_vec(); - - Ok(match is_success { - true => RunResultValue::Success(return_values), - false => RunResultValue::Panic(return_values), - }) -} - -fn jitvalue_to_felt(value: &JitValue) -> Vec { - match value { - JitValue::Felt252(felt) => vec![*felt], - JitValue::BoundedInt { value, .. } => vec![*value], - JitValue::Bytes31(bytes) => vec![Felt::from_bytes_le_slice(bytes)], - JitValue::Array(fields) | JitValue::Struct { fields, .. } => { - fields.iter().flat_map(jitvalue_to_felt).collect() - } - JitValue::Enum { - value, - tag, - debug_name, - } => { - if let Some(debug_name) = debug_name { - if debug_name == "core::bool" { - vec![(*tag == 1).into()] - } else { - let mut felts = vec![(*tag).into()]; - felts.extend(jitvalue_to_felt(value)); - felts - } - } else { - todo!() - } - } - JitValue::Uint8(x) => vec![(*x).into()], - JitValue::Uint16(x) => vec![(*x).into()], - JitValue::Uint32(x) => vec![(*x).into()], - JitValue::Uint64(x) => vec![(*x).into()], - JitValue::Uint128(x) => vec![(*x).into()], - JitValue::Sint8(x) => vec![(*x).into()], - JitValue::Sint16(x) => vec![(*x).into()], - JitValue::Sint32(x) => vec![(*x).into()], - JitValue::Sint64(x) => vec![(*x).into()], - JitValue::Sint128(x) => vec![(*x).into()], - JitValue::Null => vec![0.into()], - JitValue::EcPoint(_, _) - | JitValue::EcState(_, _, _, _) - | JitValue::Secp256K1Point { .. } - | JitValue::Secp256R1Point { .. } - | JitValue::Felt252Dict { .. } => todo!(), - } -} - -#[cfg(test)] -mod tests { - use super::*; - use cairo_felt::Felt252; - use cairo_lang_sierra::ProgramParser; - - #[test] - fn test_check_compiler_path() { - // Define file, folder, and invalid paths for testing - let file_path = Path::new("src/bin/cairo-native-run.rs"); - let folder_path = Path::new("src/bin"); - let invalid_path = Path::new("src/non-existing-file.rs"); - - // Test when single_file is true and the path is a file - assert!(check_compiler_path(true, file_path).is_ok()); - - // Test when single_file is false and the path is a file - assert!(check_compiler_path(false, file_path).is_err()); - - // Test when single_file is true and the path is a folder - assert!(check_compiler_path(true, folder_path).is_err()); - - // Test when single_file is false and the path is a folder - assert!(check_compiler_path(false, folder_path).is_ok()); - - // Test when single_file is true and the path does not exist - assert!(check_compiler_path(true, invalid_path).is_err()); - - // Test when single_file is false and the path does not exist - assert!(check_compiler_path(false, invalid_path).is_err()); - } - - #[test] - fn test_find_function() { - // Parse a simple program containing a function named "Func2" - let program = ProgramParser::new().parse("Func2@6() -> ();").unwrap(); - - // Assert that the function "Func2" is found and returned correctly - assert_eq!( - find_function(&program, "Func2").unwrap(), - program.funcs.first().unwrap() - ); - - // Assert that an error is returned when trying to find a non-existing function "Func3" - assert!(find_function(&program, "Func3").is_err()); - - // Assert that an error is returned when trying to find a function in an empty program - assert!(find_function(&ProgramParser::new().parse("").unwrap(), "Func2").is_err()); - } - - #[test] - fn test_result_to_runresult_enum_nonpanic() { - // Tests the conversion of a non-panic enum result to a `RunResultValue::Success`. - assert_eq!( - result_to_runresult(&ExecutionResult { - remaining_gas: None, - return_value: JitValue::Enum { - tag: 34, - value: JitValue::Array(vec![ - JitValue::Felt252(42.into()), - JitValue::Uint8(100), - JitValue::Uint128(1000), - ]) - .into(), - debug_name: Some("debug_name".into()), - }, - builtin_stats: Default::default(), - }) - .unwrap(), - RunResultValue::Success(vec![ - Felt252::from(34), - Felt252::from(42), - Felt252::from(100), - Felt252::from(1000) - ]) - ); - } - - #[test] - fn test_result_to_runresult_success() { - // Tests the conversion of a success enum result to a `RunResultValue::Success`. - assert_eq!( - result_to_runresult(&ExecutionResult { - remaining_gas: None, - return_value: JitValue::Enum { - tag: 0, - value: JitValue::Uint64(24).into(), - debug_name: Some("core::panics::PanicResult::Test".into()), - }, - builtin_stats: Default::default(), - }) - .unwrap(), - RunResultValue::Success(vec![Felt252::from(24)]) - ); - } - - #[test] - #[should_panic(expected = "unsuported return value in cairo-native")] - fn test_result_to_runresult_panic() { - // Tests the conversion with unsuported return value. - let _ = result_to_runresult(&ExecutionResult { - remaining_gas: None, - return_value: JitValue::Enum { - tag: 10, - value: JitValue::Uint64(24).into(), - debug_name: Some("core::panics::PanicResult::Test".into()), - }, - builtin_stats: Default::default(), - }) - .unwrap(); - } - - #[test] - #[should_panic(expected = "missing debug name")] - fn test_result_to_runresult_missing_debug_name() { - // Tests the conversion with no debug name. - let _ = result_to_runresult(&ExecutionResult { - remaining_gas: None, - return_value: JitValue::Enum { - tag: 10, - value: JitValue::Uint64(24).into(), - debug_name: None, - }, - builtin_stats: Default::default(), - }) - .unwrap(); - } - - #[test] - fn test_result_to_runresult_return() { - // Tests the conversion of a panic enum result with non-zero tag to a `RunResultValue::Panic`. - assert_eq!( - result_to_runresult(&ExecutionResult { - remaining_gas: None, - return_value: JitValue::Enum { - tag: 10, - value: JitValue::Struct { - fields: vec![ - JitValue::Felt252(42.into()), - JitValue::Uint8(100), - JitValue::Uint128(1000), - ], - debug_name: Some("debug_name".into()), - } - .into(), - debug_name: Some("core::panics::PanicResult::Test".into()), - }, - builtin_stats: Default::default(), - }) - .unwrap(), - RunResultValue::Panic(vec![ - Felt252::from(42), - Felt252::from(100), - Felt252::from(1000) - ]) - ); - } - - #[test] - fn test_result_to_runresult_non_enum() { - // Tests the conversion of a non-enum result to a `RunResultValue::Success`. - assert_eq!( - result_to_runresult(&ExecutionResult { - remaining_gas: None, - return_value: JitValue::Uint8(10), - builtin_stats: Default::default(), - }) - .unwrap(), - RunResultValue::Success(vec![Felt252::from(10)]) - ); - } - - #[test] - fn test_jitvalue_to_felt_felt252() { - let felt_value: Felt = 42.into(); - - assert_eq!( - jitvalue_to_felt(&JitValue::Felt252(felt_value)), - vec![felt_value] - ); - } - - #[test] - fn test_jitvalue_to_felt_array() { - assert_eq!( - jitvalue_to_felt(&JitValue::Array(vec![ - JitValue::Felt252(42.into()), - JitValue::Uint8(100), - JitValue::Uint128(1000), - ])), - vec![Felt::from(42), Felt::from(100), Felt::from(1000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_struct() { - assert_eq!( - jitvalue_to_felt(&JitValue::Struct { - fields: vec![ - JitValue::Felt252(42.into()), - JitValue::Uint8(100), - JitValue::Uint128(1000) - ], - debug_name: Some("debug_name".into()) - }), - vec![Felt::from(42), Felt::from(100), Felt::from(1000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_enum() { - // With debug name - assert_eq!( - jitvalue_to_felt(&JitValue::Enum { - tag: 34, - value: JitValue::Array(vec![ - JitValue::Felt252(42.into()), - JitValue::Uint8(100), - JitValue::Uint128(1000), - ]) - .into(), - debug_name: Some("debug_name".into()) - }), - vec![ - Felt::from(34), - Felt::from(42), - Felt::from(100), - Felt::from(1000) - ] - ); - - // With core::bool debug name and tag 1 - assert_eq!( - jitvalue_to_felt(&JitValue::Enum { - tag: 1, - value: JitValue::Uint128(1000).into(), - debug_name: Some("core::bool".into()) - }), - vec![Felt::ONE] - ); - - // With core::bool debug name and tag not 1 - assert_eq!( - jitvalue_to_felt(&JitValue::Enum { - tag: 10, - value: JitValue::Uint128(1000).into(), - debug_name: Some("core::bool".into()) - }), - vec![Felt::ZERO] - ); - } - - #[test] - fn test_jitvalue_to_felt_u8() { - assert_eq!(jitvalue_to_felt(&JitValue::Uint8(10)), vec![Felt::from(10)]); - } - - #[test] - fn test_jitvalue_to_felt_u16() { - assert_eq!( - jitvalue_to_felt(&JitValue::Uint16(100)), - vec![Felt::from(100)] - ); - } - - #[test] - fn test_jitvalue_to_felt_u32() { - assert_eq!( - jitvalue_to_felt(&JitValue::Uint32(1000)), - vec![Felt::from(1000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_u64() { - assert_eq!( - jitvalue_to_felt(&JitValue::Uint64(10000)), - vec![Felt::from(10000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_u128() { - assert_eq!( - jitvalue_to_felt(&JitValue::Uint128(100000)), - vec![Felt::from(100000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_sint8() { - assert_eq!( - jitvalue_to_felt(&JitValue::Sint8(-10)), - vec![Felt::from(-10)] - ); - } - - #[test] - fn test_jitvalue_to_felt_sint16() { - assert_eq!( - jitvalue_to_felt(&JitValue::Sint16(-100)), - vec![Felt::from(-100)] - ); - } - - #[test] - fn test_jitvalue_to_felt_sint32() { - assert_eq!( - jitvalue_to_felt(&JitValue::Sint32(-1000)), - vec![Felt::from(-1000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_sint64() { - assert_eq!( - jitvalue_to_felt(&JitValue::Sint64(-10000)), - vec![Felt::from(-10000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_sint128() { - assert_eq!( - jitvalue_to_felt(&JitValue::Sint128(-100000)), - vec![Felt::from(-100000)] - ); - } - - #[test] - fn test_jitvalue_to_felt_null() { - assert_eq!(jitvalue_to_felt(&JitValue::Null), vec![Felt::ZERO]); - } -} diff --git a/src/bin/cairo-native-test.rs b/src/bin/cairo-native-test.rs index 3635e5263..ba42532f8 100644 --- a/src/bin/cairo-native-test.rs +++ b/src/bin/cairo-native-test.rs @@ -1,50 +1,22 @@ -use anyhow::{bail, Context}; -use cairo_felt::Felt252; +mod utils; + +use anyhow::bail; use cairo_lang_compiler::{ - db::RootDatabase, diagnostics::DiagnosticsReporter, project::setup_project, + db::RootDatabase, + diagnostics::DiagnosticsReporter, + project::{check_compiler_path, setup_project}, }; use cairo_lang_filesystem::cfg::{Cfg, CfgSet}; -use cairo_lang_runner::{casm_run::format_next_item, RunResultValue}; -use cairo_lang_sierra::{ - extensions::gas::CostTokenType, - ids::FunctionId, - program::{Function, Program}, -}; -use cairo_lang_starknet::{contract::ContractInfo, starknet_plugin_suite}; -use cairo_lang_test_plugin::{ - compile_test_prepared_db, - test_config::{PanicExpectation, TestExpectation}, - test_plugin_suite, TestCompilation, TestConfig, -}; -use cairo_lang_utils::{casts::IntoOrPanic, ordered_hash_map::OrderedHashMap}; -use cairo_native::{ - context::NativeContext, - execution_result::ExecutionResult, - executor::{AotNativeExecutor, JitNativeExecutor, NativeExecutor}, - metadata::gas::{GasMetadata, MetadataComputationConfig}, - starknet::{Secp256k1Point, Secp256r1Point, StarknetSyscallHandler, SyscallResult, U256}, - values::JitValue, -}; -use clap::{Parser, ValueEnum}; +use cairo_lang_starknet::starknet_plugin_suite; +use cairo_lang_test_plugin::{compile_test_prepared_db, test_plugin_suite}; +use clap::Parser; use colored::Colorize; -use itertools::Itertools; -use k256::elliptic_curve::sec1::ToEncodedPoint; -use k256::elliptic_curve::{generic_array::GenericArray, sec1::FromEncodedPoint}; -use num_traits::ToPrimitive; -use sec1::point::Coordinates; -use starknet_types_core::felt::Felt; -use std::{ - iter::once, - path::{Path, PathBuf}, - vec::IntoIter, -}; +use std::path::{Path, PathBuf}; use tracing_subscriber::{EnvFilter, FmtSubscriber}; - -#[derive(Clone, Debug, ValueEnum)] -enum RunMode { - Aot, - Jit, -} +use utils::{ + test::{display_tests_summary, filter_test_cases, run_tests}, + RunArgs, RunMode, +}; /// Compiles a Cairo project and runs all the functions marked as `#[test]`. /// Exits with 1 if the compilation or run fails, otherwise 0. @@ -130,1362 +102,26 @@ fn main() -> anyhow::Result<()> { args.filter.clone(), ); - let TestsSummary { - passed, - failed, - ignored, - failed_run_results, - } = run_tests( + let summary = run_tests( compiled.named_tests, compiled.sierra_program, compiled.function_set_costs, - compiled.contracts_info, - &args, + RunArgs { + run_mode: args.run_mode.clone(), + opt_level: args.opt_level, + }, )?; - if failed.is_empty() { - println!( - "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;", - "ok".bright_green(), - passed.len(), - failed.len(), - ignored.len() - ); - } else { - println!("failures:"); - for (failure, run_result) in failed.iter().zip_eq(failed_run_results) { - print!(" {failure} - "); - match run_result { - RunResultValue::Success(_) => { - println!("expected panic but finished successfully."); - } - RunResultValue::Panic(values) => { - println!("{}", format_for_panic(values.into_iter())); - } - } - } - println!(); + display_tests_summary(&summary, filtered_out); + if !summary.failed.is_empty() { bail!( "test result: {}. {} passed; {} failed; {} ignored", "FAILED".bright_red(), - passed.len(), - failed.len(), - ignored.len() + summary.passed.len(), + summary.failed.len(), + summary.ignored.len() ); } Ok(()) } - -pub fn check_compiler_path(single_file: bool, path: &Path) -> anyhow::Result<()> { - if path.is_file() { - if !single_file { - anyhow::bail!("The given path is a file, but --single-file was not supplied."); - } - } else if path.is_dir() { - if single_file { - anyhow::bail!("The given path is a directory, but --single-file was supplied."); - } - } else { - anyhow::bail!("The given path does not exist."); - } - Ok(()) -} - -/// Formats the given felts as a panic string. -fn format_for_panic(mut felts: IntoIter) -> String { - let mut items = Vec::new(); - while let Some(item) = format_next_item(&mut felts) { - items.push(item.quote_if_string()); - } - let panic_values_string = if let [item] = &items[..] { - item.clone() - } else { - format!("({})", items.join(", ")) - }; - format!("Panicked with {panic_values_string}.") -} - -/// Filter compiled test cases with user provided arguments. -/// -/// # Arguments -/// * `compiled` - Compiled test cases with metadata. -/// * `include_ignored` - Include ignored tests as well. -/// * `ignored` - Run ignored tests only.l -/// * `filter` - Include only tests containing the filter string. -/// # Returns -/// * (`TestCompilation`, `usize`) - The filtered test cases and the number of filtered out cases. -pub fn filter_test_cases( - compiled: TestCompilation, - include_ignored: bool, - ignored: bool, - filter: String, -) -> (TestCompilation, usize) { - let total_tests_count = compiled.named_tests.len(); - let named_tests = compiled - .named_tests - .into_iter() - .filter(|(name, _)| name.contains(&filter)); - - let named_tests = if include_ignored { - // enable the ignored tests - named_tests - .into_iter() - .map(|(name, mut test)| { - test.ignored = false; - (name, test) - }) - .collect_vec() - } else if ignored { - // filter not ignored tests and enable the remaining ones - named_tests - .into_iter() - .map(|(name, mut test)| { - test.ignored = !test.ignored; - (name, test) - }) - .filter(|(_, test)| !test.ignored) - .collect_vec() - } else { - named_tests.collect_vec() - }; - - let filtered_out = total_tests_count - named_tests.len(); - let tests = TestCompilation { - named_tests, - ..compiled - }; - (tests, filtered_out) -} - -pub fn find_function<'a>( - sierra_program: &'a Program, - name_suffix: &str, -) -> anyhow::Result<&'a Function> { - if let Some(x) = sierra_program.funcs.iter().find(|f| { - if let Some(name) = &f.id.debug_name { - name.ends_with(name_suffix) - } else { - false - } - }) { - Ok(x) - } else { - bail!("test function not found") - } -} - -/// The status of a ran test. -enum TestStatus { - Success, - Fail(RunResultValue), -} - -/// The result of a ran test. -struct TestResult { - /// The status of the run. - status: TestStatus, - /// The gas usage of the run if relevant. - gas_usage: Option, -} - -/// Summary data of the ran tests. -pub struct TestsSummary { - passed: Vec, - failed: Vec, - ignored: Vec, - failed_run_results: Vec, -} - -fn result_to_runresult(result: &ExecutionResult) -> anyhow::Result { - let is_success; - let mut felts: Vec = Vec::new(); - - match &result.return_value { - JitValue::Enum { tag, value, .. } => { - is_success = *tag == 0; - - if !is_success { - match &**value { - JitValue::Struct { fields, .. } => { - for field in fields { - let felt = jitvalue_to_felt(field); - felts.extend(felt); - } - } - _ => bail!( - "unsuported return value in cairo-native (inside enum): {:#?}", - value - ), - } - } - } - value => { - is_success = true; - let felt = jitvalue_to_felt(value); - felts.extend(felt); - } - } - - let return_values = felts - .into_iter() - .map(|x| x.to_bigint().into()) - .collect_vec(); - - Ok(match is_success { - true => RunResultValue::Success(return_values), - false => RunResultValue::Panic(return_values), - }) -} - -fn jitvalue_to_felt(value: &JitValue) -> Vec { - let mut felts = Vec::new(); - match value { - JitValue::Felt252(felt) => vec![*felt], - JitValue::BoundedInt { value, .. } => vec![*value], - JitValue::Bytes31(_) => todo!(), - JitValue::Array(values) => { - for value in values { - let felt = jitvalue_to_felt(value); - felts.extend(felt); - } - felts - } - JitValue::Struct { fields, .. } => { - for field in fields { - let felt = jitvalue_to_felt(field); - felts.extend(felt); - } - felts - } - JitValue::Enum { .. } => todo!(), - JitValue::Felt252Dict { value, .. } => { - for (key, value) in value { - felts.push(*key); - let felt = jitvalue_to_felt(value); - felts.extend(felt); - } - - felts - } - JitValue::Uint8(x) => vec![(*x).into()], - JitValue::Uint16(x) => vec![(*x).into()], - JitValue::Uint32(x) => vec![(*x).into()], - JitValue::Uint64(x) => vec![(*x).into()], - JitValue::Uint128(x) => vec![(*x).into()], - JitValue::Sint8(x) => vec![(*x).into()], - JitValue::Sint16(x) => vec![(*x).into()], - JitValue::Sint32(x) => vec![(*x).into()], - JitValue::Sint64(x) => vec![(*x).into()], - JitValue::Sint128(x) => vec![(*x).into()], - JitValue::EcPoint(_, _) => todo!(), - JitValue::EcState(_, _, _, _) => todo!(), - JitValue::Secp256K1Point { .. } => todo!(), - JitValue::Secp256R1Point { .. } => todo!(), - JitValue::Null => vec![0.into()], - } -} - -/// Runs the tests and process the results for a summary. -fn run_tests( - named_tests: Vec<(String, TestConfig)>, - sierra_program: Program, - function_set_costs: OrderedHashMap>, - _contracts_info: OrderedHashMap, - args: &Args, -) -> anyhow::Result { - let native_context = NativeContext::new(); - - // Compile the sierra program into a MLIR module. - let native_module = native_context - .compile_with_metadata( - &sierra_program, - MetadataComputationConfig { - function_set_costs: function_set_costs.clone(), - linear_ap_change_solver: true, - linear_gas_solver: true, - }, - ) - .unwrap(); - - let native_executor: NativeExecutor = match args.run_mode { - RunMode::Aot => { - AotNativeExecutor::from_native_module(native_module, args.opt_level.into()).into() - } - RunMode::Jit => { - JitNativeExecutor::from_native_module(native_module, args.opt_level.into()).into() - } - }; - - let gas_metadata = GasMetadata::new( - &sierra_program, - Some(MetadataComputationConfig { - function_set_costs, - linear_ap_change_solver: true, - linear_gas_solver: true, - }), - ) - .unwrap(); - - println!("running {} tests", named_tests.len()); - let mut wrapped_summary = Ok(TestsSummary { - passed: vec![], - failed: vec![], - ignored: vec![], - failed_run_results: vec![], - }); - named_tests - .into_iter() - .map( - |(name, test)| -> anyhow::Result<(String, Option)> { - if test.ignored { - return Ok((name, None)); - } - tracing::trace!("running test {name:?}"); - - let func = find_function(&sierra_program, name.as_str())?; - - let initial_gas = test.available_gas.map(|x| x.try_into().unwrap()); - - let result = native_executor - .invoke_dynamic_with_syscall_handler( - &func.id, - &[], - initial_gas, - TestSyscallHandler, - ) - .with_context(|| format!("Failed to run the function `{}`.", name.as_str()))?; - - let run_result = result_to_runresult(&result)?; - Ok(( - name, - Some(TestResult { - status: match &run_result { - RunResultValue::Success(_) => match test.expectation { - TestExpectation::Success => TestStatus::Success, - TestExpectation::Panics(_) => TestStatus::Fail(run_result), - }, - RunResultValue::Panic(value) => match test.expectation { - TestExpectation::Success => TestStatus::Fail(run_result), - TestExpectation::Panics(panic_expectation) => { - match panic_expectation { - PanicExpectation::Exact(expected) if value != &expected => { - TestStatus::Fail(run_result) - } - _ => TestStatus::Success, - } - } - }, - }, - gas_usage: test - .available_gas - .zip(result.remaining_gas) - .map(|(before, after)| { - before.into_or_panic::() - after.to_i64().unwrap() - }) - .or_else(|| { - gas_metadata - .initial_required_gas(&func.id) - .map(|gas| gas.try_into().unwrap()) - }), - }), - )) - }, - ) - .for_each(|r| { - let (name, status) = match r { - Ok((name, status)) => (name, status), - Err(err) => { - wrapped_summary = Err(err); - return; - } - }; - let summary = wrapped_summary.as_mut().unwrap(); - let (res_type, status_str, gas_usage) = match status { - Some(TestResult { - status: TestStatus::Success, - gas_usage, - }) => (&mut summary.passed, "ok".bright_green(), gas_usage), - Some(TestResult { - status: TestStatus::Fail(run_result), - gas_usage, - }) => { - summary.failed_run_results.push(run_result); - (&mut summary.failed, "fail".bright_red(), gas_usage) - } - None => (&mut summary.ignored, "ignored".bright_yellow(), None), - }; - if let Some(gas_usage) = gas_usage { - println!("test {name} ... {status_str} (gas usage est.: {gas_usage})"); - } else { - println!("test {name} ... {status_str}"); - } - res_type.push(name); - }); - wrapped_summary -} - -pub struct TestSyscallHandler; - -impl StarknetSyscallHandler for TestSyscallHandler { - fn get_block_hash( - &mut self, - _block_number: u64, - _remaining_gas: &mut u128, - ) -> SyscallResult { - unimplemented!() - } - - fn get_execution_info( - &mut self, - _remaining_gas: &mut u128, - ) -> SyscallResult { - unimplemented!() - } - - fn get_execution_info_v2( - &mut self, - _remaining_gas: &mut u128, - ) -> SyscallResult { - unimplemented!() - } - - fn deploy( - &mut self, - _class_hash: Felt, - _contract_address_salt: Felt, - _calldata: &[Felt], - _deploy_from_zero: bool, - _remaining_gas: &mut u128, - ) -> SyscallResult<(Felt, Vec)> { - unimplemented!() - } - - fn replace_class(&mut self, _class_hash: Felt, _remaining_gas: &mut u128) -> SyscallResult<()> { - unimplemented!() - } - - fn library_call( - &mut self, - _class_hash: Felt, - _function_selector: Felt, - _calldata: &[Felt], - _remaining_gas: &mut u128, - ) -> SyscallResult> { - unimplemented!() - } - - fn call_contract( - &mut self, - _address: Felt, - _entry_point_selector: Felt, - _calldata: &[Felt], - _remaining_gas: &mut u128, - ) -> SyscallResult> { - unimplemented!() - } - - fn storage_read( - &mut self, - _address_domain: u32, - _address: Felt, - _remaining_gas: &mut u128, - ) -> SyscallResult { - unimplemented!() - } - - fn storage_write( - &mut self, - _address_domain: u32, - _address: Felt, - _value: Felt, - _remaining_gas: &mut u128, - ) -> SyscallResult<()> { - unimplemented!() - } - - fn emit_event( - &mut self, - _keys: &[Felt], - _data: &[Felt], - _remaining_gas: &mut u128, - ) -> SyscallResult<()> { - unimplemented!() - } - - fn send_message_to_l1( - &mut self, - _to_address: Felt, - _payload: &[Felt], - _remaining_gas: &mut u128, - ) -> SyscallResult<()> { - unimplemented!() - } - - fn keccak(&mut self, input: &[u64], gas: &mut u128) -> SyscallResult { - let length = input.len(); - - if length % 17 != 0 { - let error_msg = b"Invalid keccak input size"; - let felt_error = Felt::from_bytes_be_slice(error_msg); - return Err(vec![felt_error]); - } - - let n_chunks = length / 17; - let mut state = [0u64; 25]; - - for i in 0..n_chunks { - if *gas < KECCAK_ROUND_COST { - let error_msg = b"Syscall out of gas"; - let felt_error = Felt::from_bytes_be_slice(error_msg); - return Err(vec![felt_error]); - } - const KECCAK_ROUND_COST: u128 = 180000; - *gas -= KECCAK_ROUND_COST; - let chunk = &input[i * 17..(i + 1) * 17]; //(request.input_start + i * 17)?; - for (i, val) in chunk.iter().enumerate() { - state[i] ^= val; - } - keccak::f1600(&mut state) - } - - // state[0] and state[1] conform the hash_high (u128) - // state[2] and state[3] conform the hash_low (u128) - SyscallResult::Ok(U256 { - lo: state[2] as u128 | ((state[3] as u128) << 64), - hi: state[0] as u128 | ((state[1] as u128) << 64), - }) - } - - fn secp256k1_new( - &mut self, - x: U256, - y: U256, - _remaining_gas: &mut u128, - ) -> SyscallResult> { - // The following unwraps should be unreachable because the iterator we provide has the - // expected number of bytes. - let point = k256::ProjectivePoint::from_encoded_point( - &k256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - x.hi.to_be_bytes().into_iter().chain(x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - y.hi.to_be_bytes().into_iter().chain(y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ); - - if bool::from(point.is_some()) { - Ok(Some(Secp256k1Point { x, y })) - } else { - Ok(None) - } - } - - fn secp256k1_add( - &mut self, - p0: Secp256k1Point, - p1: Secp256k1Point, - _remaining_gas: &mut u128, - ) -> SyscallResult { - // The inner unwraps should be unreachable because the iterator we provide has the expected - // number of bytes. The outer unwraps depend on the felt values, which should be valid since - // they'll be provided by secp256 syscalls. - let p0 = k256::ProjectivePoint::from_encoded_point( - &k256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - p0.x.hi - .to_be_bytes() - .into_iter() - .chain(p0.x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - p0.y.hi - .to_be_bytes() - .into_iter() - .chain(p0.y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ) - .unwrap(); - let p1 = k256::ProjectivePoint::from_encoded_point( - &k256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - p1.x.hi - .to_be_bytes() - .into_iter() - .chain(p1.x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - p1.y.hi - .to_be_bytes() - .into_iter() - .chain(p1.y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ) - .unwrap(); - - let p = p0 + p1; - - let p = p.to_encoded_point(false); - let (x, y) = match p.coordinates() { - Coordinates::Uncompressed { x, y } => (x, y), - _ => { - // This should be unreachable because we explicitly asked for the uncompressed - // encoding. - unreachable!() - } - }; - - // The following two unwraps should be safe because the array always has 32 bytes. The other - // four are definitely safe because the slicing guarantees its length to be the right one. - let x: [u8; 32] = x.as_slice().try_into().unwrap(); - let y: [u8; 32] = y.as_slice().try_into().unwrap(); - Ok(Secp256k1Point { - x: U256 { - hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), - }, - y: U256 { - hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), - }, - }) - } - - fn secp256k1_mul( - &mut self, - p: Secp256k1Point, - m: U256, - _remaining_gas: &mut u128, - ) -> SyscallResult { - // The inner unwrap should be unreachable because the iterator we provide has the expected - // number of bytes. The outer unwrap depends on the felt values, which should be valid since - // they'll be provided by secp256 syscalls. - let p = k256::ProjectivePoint::from_encoded_point( - &k256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - p.x.hi.to_be_bytes().into_iter().chain(p.x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - p.y.hi.to_be_bytes().into_iter().chain(p.y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ) - .unwrap(); - let m: k256::Scalar = k256::elliptic_curve::ScalarPrimitive::from_slice(&{ - let mut buf = [0u8; 32]; - buf[0..16].copy_from_slice(&m.hi.to_be_bytes()); - buf[16..32].copy_from_slice(&m.lo.to_be_bytes()); - buf - }) - .map_err(|_| { - vec![Felt::from_bytes_be( - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0invalid scalar", - )] - })? - .into(); - - let p = p * m; - - let p = p.to_encoded_point(false); - let (x, y) = match p.coordinates() { - Coordinates::Uncompressed { x, y } => (x, y), - _ => { - // This should be unreachable because we explicitly asked for the uncompressed - // encoding. - unreachable!() - } - }; - - // The following two unwraps should be safe because the array always has 32 bytes. The other - // four are definitely safe because the slicing guarantees its length to be the right one. - let x: [u8; 32] = x.as_slice().try_into().unwrap(); - let y: [u8; 32] = y.as_slice().try_into().unwrap(); - Ok(Secp256k1Point { - x: U256 { - hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), - }, - y: U256 { - hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), - }, - }) - } - - fn secp256k1_get_point_from_x( - &mut self, - x: U256, - y_parity: bool, - _remaining_gas: &mut u128, - ) -> SyscallResult> { - // The inner unwrap should be unreachable because the iterator we provide has the expected - // number of bytes. The outer unwrap depends on the encoding format, which should be valid - // since it's hardcoded.. - let point = k256::ProjectivePoint::from_encoded_point( - &k256::EncodedPoint::from_bytes( - k256::CompressedPoint::from_exact_iter( - once(0x02 | y_parity as u8) - .chain(x.hi.to_be_bytes()) - .chain(x.lo.to_be_bytes()), - ) - .unwrap(), - ) - .unwrap(), - ); - - if bool::from(point.is_some()) { - // This unwrap has already been checked in the `if` expression's condition. - let p = point.unwrap(); - - let p = p.to_encoded_point(false); - let y = match p.coordinates() { - Coordinates::Uncompressed { y, .. } => y, - _ => { - // This should be unreachable because we explicitly asked for the uncompressed - // encoding. - unreachable!() - } - }; - - // The following unwrap should be safe because the array always has 32 bytes. The other - // two are definitely safe because the slicing guarantees its length to be the right - // one. - let y: [u8; 32] = y.as_slice().try_into().unwrap(); - Ok(Some(Secp256k1Point { - x, - y: U256 { - hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), - }, - })) - } else { - Ok(None) - } - } - - fn secp256k1_get_xy( - &mut self, - p: Secp256k1Point, - _remaining_gas: &mut u128, - ) -> SyscallResult<(U256, U256)> { - Ok((p.x, p.y)) - } - - fn secp256r1_new( - &mut self, - x: U256, - y: U256, - _remaining_gas: &mut u128, - ) -> SyscallResult> { - // The following unwraps should be unreachable because the iterator we provide has the - // expected number of bytes. - let point = p256::ProjectivePoint::from_encoded_point( - &k256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - x.hi.to_be_bytes().into_iter().chain(x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - y.hi.to_be_bytes().into_iter().chain(y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ); - - if bool::from(point.is_some()) { - Ok(Some(Secp256r1Point { x, y })) - } else { - Ok(None) - } - } - - fn secp256r1_add( - &mut self, - p0: Secp256r1Point, - p1: Secp256r1Point, - _remaining_gas: &mut u128, - ) -> SyscallResult { - // The inner unwraps should be unreachable because the iterator we provide has the expected - // number of bytes. The outer unwraps depend on the felt values, which should be valid since - // they'll be provided by secp256 syscalls. - let p0 = p256::ProjectivePoint::from_encoded_point( - &p256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - p0.x.hi - .to_be_bytes() - .into_iter() - .chain(p0.x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - p0.y.hi - .to_be_bytes() - .into_iter() - .chain(p0.y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ) - .unwrap(); - let p1 = p256::ProjectivePoint::from_encoded_point( - &p256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - p1.x.hi - .to_be_bytes() - .into_iter() - .chain(p1.x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - p1.y.hi - .to_be_bytes() - .into_iter() - .chain(p1.y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ) - .unwrap(); - - let p = p0 + p1; - - let p = p.to_encoded_point(false); - let (x, y) = match p.coordinates() { - Coordinates::Uncompressed { x, y } => (x, y), - _ => { - // This should be unreachable because we explicitly asked for the uncompressed - // encoding. - unreachable!() - } - }; - - // The following two unwraps should be safe because the array always has 32 bytes. The other - // four are definitely safe because the slicing guarantees its length to be the right one. - let x: [u8; 32] = x.as_slice().try_into().unwrap(); - let y: [u8; 32] = y.as_slice().try_into().unwrap(); - Ok(Secp256r1Point { - x: U256 { - hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), - }, - y: U256 { - hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), - }, - }) - } - - fn secp256r1_mul( - &mut self, - p: Secp256r1Point, - m: U256, - _remaining_gas: &mut u128, - ) -> SyscallResult { - // The inner unwrap should be unreachable because the iterator we provide has the expected - // number of bytes. The outer unwrap depends on the felt values, which should be valid since - // they'll be provided by secp256 syscalls. - let p = p256::ProjectivePoint::from_encoded_point( - &p256::EncodedPoint::from_affine_coordinates( - &GenericArray::from_exact_iter( - p.x.hi.to_be_bytes().into_iter().chain(p.x.lo.to_be_bytes()), - ) - .unwrap(), - &GenericArray::from_exact_iter( - p.y.hi.to_be_bytes().into_iter().chain(p.y.lo.to_be_bytes()), - ) - .unwrap(), - false, - ), - ) - .unwrap(); - let m: p256::Scalar = p256::elliptic_curve::ScalarPrimitive::from_slice(&{ - let mut buf = [0u8; 32]; - buf[0..16].copy_from_slice(&m.hi.to_be_bytes()); - buf[16..32].copy_from_slice(&m.lo.to_be_bytes()); - buf - }) - .map_err(|_| { - vec![Felt::from_bytes_be( - b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0invalid scalar", - )] - })? - .into(); - - let p = p * m; - - let p = p.to_encoded_point(false); - let (x, y) = match p.coordinates() { - Coordinates::Uncompressed { x, y } => (x, y), - _ => { - // This should be unreachable because we explicitly asked for the uncompressed - // encoding. - unreachable!() - } - }; - - // The following two unwraps should be safe because the array always has 32 bytes. The other - // four are definitely safe because the slicing guarantees its length to be the right one. - let x: [u8; 32] = x.as_slice().try_into().unwrap(); - let y: [u8; 32] = y.as_slice().try_into().unwrap(); - Ok(Secp256r1Point { - x: U256 { - hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), - }, - y: U256 { - hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), - }, - }) - } - - fn secp256r1_get_point_from_x( - &mut self, - x: U256, - y_parity: bool, - _remaining_gas: &mut u128, - ) -> SyscallResult> { - let point = p256::ProjectivePoint::from_encoded_point( - &p256::EncodedPoint::from_bytes( - p256::CompressedPoint::from_exact_iter( - once(0x02 | y_parity as u8) - .chain(x.hi.to_be_bytes()) - .chain(x.lo.to_be_bytes()), - ) - .unwrap(), - ) - .unwrap(), - ); - - if bool::from(point.is_some()) { - let p = point.unwrap(); - - let p = p.to_encoded_point(false); - let y = match p.coordinates() { - Coordinates::Uncompressed { y, .. } => y, - _ => unreachable!(), - }; - - let y: [u8; 32] = y.as_slice().try_into().unwrap(); - Ok(Some(Secp256r1Point { - x, - y: U256 { - hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), - lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), - }, - })) - } else { - Ok(None) - } - } - - fn secp256r1_get_xy( - &mut self, - p: Secp256r1Point, - _remaining_gas: &mut u128, - ) -> SyscallResult<(U256, U256)> { - Ok((p.x, p.y)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_secp256k1_get_xy() { - let p = Secp256k1Point { - x: U256 { - hi: 331229800296699308591929724809569456681, - lo: 240848751772479376198639683648735950585, - }, - y: U256 { - hi: 75181762170223969696219813306313470806, - lo: 134255467439736302886468555755295925874, - }, - }; - - let mut test_syscall_handler = TestSyscallHandler {}; - - assert_eq!( - test_syscall_handler.secp256k1_get_xy(p, &mut 10).unwrap(), - ( - U256 { - hi: 331229800296699308591929724809569456681, - lo: 240848751772479376198639683648735950585, - }, - U256 { - hi: 75181762170223969696219813306313470806, - lo: 134255467439736302886468555755295925874, - } - ) - ) - } - - #[test] - fn test_secp256k1_secp256k1_new() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let x = U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }; - let y = U256 { - hi: 26163136114030451075775058782541084873, - lo: 68974579539311638391577168388077592842, - }; - - assert_eq!( - test_syscall_handler.secp256k1_new(x, y, &mut 10).unwrap(), - Some(Secp256k1Point { x, y }) - ); - } - - #[test] - fn test_secp256k1_secp256k1_new_none() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let x = U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }; - let y = U256 { hi: 0, lo: 0 }; - - assert!(test_syscall_handler - .secp256k1_new(x, y, &mut 10) - .unwrap() - .is_none()); - } - - #[test] - fn test_secp256k1_ssecp256k1_add() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let p1 = Secp256k1Point { - x: U256 { - hi: 161825202758953104525843685720298294023, - lo: 3468390537006497937951914270391801752, - }, - y: U256 { - hi: 96009999919712310848645357523629574312, - lo: 336417762351022071123394393598455764152, - }, - }; - - let p2 = p1; - - // 2 * P1 - let p3 = test_syscall_handler.secp256k1_add(p1, p2, &mut 10).unwrap(); - - let p1_double = Secp256k1Point { - x: U256 { - hi: 263210499965038831386353541518668627160, - lo: 122909745026270932982812610085084241637, - }, - y: U256 { - hi: 35730324229579385338853513728577301230, - lo: 329597642124196932058042157271922763050, - }, - }; - assert_eq!(p3, p1_double); - assert_eq!( - test_syscall_handler - .secp256k1_mul(p1, U256 { hi: 0, lo: 2 }, &mut 10) - .unwrap(), - p1_double - ); - - // 3 * P1 - let three_p1 = Secp256k1Point { - x: U256 { - hi: 331229800296699308591929724809569456681, - lo: 240848751772479376198639683648735950585, - }, - y: U256 { - hi: 75181762170223969696219813306313470806, - lo: 134255467439736302886468555755295925874, - }, - }; - assert_eq!( - test_syscall_handler.secp256k1_add(p1, p3, &mut 10).unwrap(), - three_p1 - ); - assert_eq!( - test_syscall_handler - .secp256k1_mul(p1, U256 { hi: 0, lo: 3 }, &mut 10) - .unwrap(), - three_p1 - ); - } - - #[test] - fn test_secp256k1_get_point_from_x_false_yparity() { - let mut test_syscall_handler = TestSyscallHandler {}; - - assert_eq!( - test_syscall_handler - .secp256k1_get_point_from_x( - U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }, - false, - &mut 10 - ) - .unwrap() - .unwrap(), - Secp256k1Point { - x: U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }, - y: U256 { - hi: 26163136114030451075775058782541084873, - lo: 68974579539311638391577168388077592842 - }, - } - ); - } - - #[test] - fn test_secp256k1_get_point_from_x_true_yparity() { - let mut test_syscall_handler = TestSyscallHandler {}; - - assert_eq!( - test_syscall_handler - .secp256k1_get_point_from_x( - U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }, - true, - &mut 10 - ) - .unwrap() - .unwrap(), - Secp256k1Point { - x: U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }, - y: U256 { - hi: 314119230806908012387599548649227126582, - lo: 271307787381626825071797439039395650341 - }, - } - ); - } - - #[test] - fn test_secp256k1_get_point_from_x_none() { - let mut test_syscall_handler = TestSyscallHandler {}; - - assert!(test_syscall_handler - .secp256k1_get_point_from_x(U256 { hi: 0, lo: 0 }, true, &mut 10) - .unwrap() - .is_none()); - } - - #[test] - fn test_secp256r1_new() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let x = U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }; - let y = U256 { - hi: 118910939004298029402109603132816090461, - lo: 111045440647474106186537215379882575585, - }; - - assert_eq!( - test_syscall_handler - .secp256r1_new(x, y, &mut 10) - .unwrap() - .unwrap(), - Secp256r1Point { x, y } - ); - } - - #[test] - fn test_secp256r1_new_none() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let x = U256 { hi: 0, lo: 0 }; - let y = U256 { hi: 0, lo: 0 }; - - assert!(test_syscall_handler - .secp256r1_new(x, y, &mut 10) - .unwrap() - .is_none()); - } - - #[test] - fn test_secp256r1_add() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let p1 = Secp256r1Point { - x: U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }, - y: U256 { - hi: 118910939004298029402109603132816090461, - lo: 111045440647474106186537215379882575585, - }, - }; - - let p2 = p1; - - // 2 * P1 - let p3 = test_syscall_handler.secp256r1_add(p1, p2, &mut 10).unwrap(); - - let p1_double = Secp256r1Point { - x: U256 { - hi: 280079427190737520201067412903899817878, - lo: 309339945874468445579793098896656960879, - }, - y: U256 { - hi: 84249534056490759701994051847937833933, - lo: 231570843221643745062297421862629788481, - }, - }; - assert_eq!(p3, p1_double); - assert_eq!( - test_syscall_handler - .secp256r1_mul(p1, U256 { hi: 0, lo: 2 }, &mut 10) - .unwrap(), - p1_double - ); - - // 3 * P1 - let three_p1 = Secp256r1Point { - x: U256 { - hi: 23850518908906170876551962912581992002, - lo: 195259625777021303662291420857740525307, - }, - y: U256 { - hi: 178681203065513270100417145499857169664, - lo: 282344931843342117515389970197013120959, - }, - }; - assert_eq!( - test_syscall_handler.secp256r1_add(p1, p3, &mut 10).unwrap(), - three_p1 - ); - assert_eq!( - test_syscall_handler - .secp256r1_mul(p1, U256 { hi: 0, lo: 3 }, &mut 10) - .unwrap(), - three_p1 - ); - } - - #[test] - fn test_secp256r1_get_point_from_x_true_yparity() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let x = U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }; - - let y = U256 { - hi: 118910939004298029402109603132816090461, - lo: 111045440647474106186537215379882575585, - }; - - assert_eq!( - test_syscall_handler - .secp256r1_get_point_from_x(x, true, &mut 10) - .unwrap() - .unwrap(), - Secp256r1Point { x, y } - ); - } - - #[test] - fn test_secp256r1_get_point_from_x_false_yparity() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let x = U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }; - - let y = U256 { - hi: 221371427837412271565447410779117722274, - lo: 229236926352692519791101729645429586206, - }; - - assert_eq!( - test_syscall_handler - .secp256r1_get_point_from_x(x, false, &mut 10) - .unwrap() - .unwrap(), - Secp256r1Point { x, y } - ); - } - - #[test] - fn test_secp256r1_get_point_from_x_none() { - let mut test_syscall_handler = TestSyscallHandler {}; - - let x = U256 { hi: 0, lo: 10 }; - - assert!(test_syscall_handler - .secp256r1_get_point_from_x(x, true, &mut 10) - .unwrap() - .is_none()); - } - - #[test] - fn test_secp256r1_get_xy() { - let p = Secp256r1Point { - x: U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }, - y: U256 { - hi: 221371427837412271565447410779117722274, - lo: 229236926352692519791101729645429586206, - }, - }; - - let mut test_syscall_handler = TestSyscallHandler {}; - - assert_eq!( - test_syscall_handler.secp256r1_get_xy(p, &mut 10).unwrap(), - ( - U256 { - hi: 97179038819393695679, - lo: 330631467365974629050427735731901850225, - }, - U256 { - hi: 221371427837412271565447410779117722274, - lo: 229236926352692519791101729645429586206, - } - ) - ) - } -} diff --git a/src/bin/scarb-native-test.rs b/src/bin/scarb-native-test.rs new file mode 100644 index 000000000..bbdeb299b --- /dev/null +++ b/src/bin/scarb-native-test.rs @@ -0,0 +1,98 @@ +mod utils; + +use std::{env, fs}; + +use anyhow::Context; +use cairo_lang_test_plugin::TestCompilation; +use clap::Parser; +use scarb_metadata::{Metadata, MetadataCommand, ScarbCommand}; +use scarb_ui::args::PackagesFilter; +use utils::test::{display_tests_summary, filter_test_cases, find_testable_targets, run_tests}; +use utils::{RunArgs, RunMode}; + +/// Compiles all packages from a Scarb project matching `packages_filter` and +/// runs all functions marked with `#[test]`. Exits with 1 if the compilation +/// or run fails, otherwise 0. +#[derive(Parser, Clone, Debug)] +#[command(author, version, verbatim_doc_comment)] +struct Args { + #[command(flatten)] + packages_filter: PackagesFilter, + /// Run only tests whose name contain FILTER. + #[arg(short, long, default_value = "")] + filter: String, + /// Run ignored and not ignored tests. + #[arg(long, default_value_t = false)] + include_ignored: bool, + /// Run only ignored tests. + #[arg(long, default_value_t = false)] + ignored: bool, + /// Run with JIT or AOT (compiled). + #[arg(long, value_enum, default_value_t = RunMode::Jit)] + run_mode: RunMode, + /// Optimization level, Valid: 0, 1, 2, 3. Values higher than 3 are considered as 3. + #[arg(short = 'O', long, default_value_t = 0)] + opt_level: u8, +} + +fn main() -> anyhow::Result<()> { + let args: Args = Args::parse(); + + let metadata = MetadataCommand::new().inherit_stderr().exec()?; + + // Filter packages. + let matched = args.packages_filter.match_many(&metadata)?; + let filter = PackagesFilter::generate_for::(matched.iter()); + + // Build only the filtered packages. + ScarbCommand::new() + .arg("build") + .arg("--test") + .env("SCARB_PACKAGES_FILTER", filter.to_env()) + .run()?; + + // Get `target` directory. + let profile = env::var("SCARB_PROFILE").unwrap_or("dev".into()); + let default_target_dir = metadata.runtime_manifest.join("target"); + let target_dir = metadata + .target_dir + .clone() + .unwrap_or(default_target_dir) + .join(profile); + + // Iterate over the filtered packages. + for package in matched { + println!("testing {} ...", package.name); + + // Iterate over the filtered targets. + for target in find_testable_targets(&package) { + let file_path = target_dir.join(format!("{}.test.json", target.name.clone())); + let compiled = serde_json::from_str::( + &fs::read_to_string(file_path.clone()) + .with_context(|| format!("failed to read file: {file_path}"))?, + ) + .with_context(|| format!("failed to deserialize compiled tests file: {file_path}"))?; + + let (compiled, filtered_out) = filter_test_cases( + compiled, + args.include_ignored, + args.ignored, + args.filter.clone(), + ); + + let summary = run_tests( + compiled.named_tests, + compiled.sierra_program, + compiled.function_set_costs, + RunArgs { + run_mode: args.run_mode.clone(), + opt_level: args.opt_level, + }, + )?; + + display_tests_summary(&summary, filtered_out); + } + } + + Ok(()) +} diff --git a/src/bin/utils/mod.rs b/src/bin/utils/mod.rs new file mode 100644 index 000000000..dd13533ae --- /dev/null +++ b/src/bin/utils/mod.rs @@ -0,0 +1,473 @@ +#![cfg(feature = "build-cli")] +#![allow(dead_code)] + +pub mod test; + +use anyhow::bail; +use cairo_felt::Felt252; +use cairo_lang_runner::{casm_run::format_next_item, RunResultValue}; +use cairo_lang_sierra::program::{Function, Program}; +use cairo_native::{execution_result::ExecutionResult, values::JitValue}; +use clap::ValueEnum; +use itertools::Itertools; +use starknet_types_core::felt::Felt; +use std::vec::IntoIter; + +pub(super) struct RunArgs { + pub run_mode: RunMode, + pub opt_level: u8, +} + +#[derive(Clone, Debug, ValueEnum)] +pub enum RunMode { + Aot, + Jit, +} + +/// Find the function ending with `name_suffix` in the program. +pub fn find_function<'a>( + sierra_program: &'a Program, + name_suffix: &str, +) -> anyhow::Result<&'a Function> { + if let Some(x) = sierra_program.funcs.iter().find(|f| { + if let Some(name) = &f.id.debug_name { + name.ends_with(name_suffix) + } else { + false + } + }) { + Ok(x) + } else { + bail!("test function not found") + } +} + +/// Formats the given felts as a panic string. +pub fn format_for_panic(mut felts: IntoIter) -> String { + let mut items = Vec::new(); + while let Some(item) = format_next_item(&mut felts) { + items.push(item.quote_if_string()); + } + let panic_values_string = if let [item] = &items[..] { + item.clone() + } else { + format!("({})", items.join(", ")) + }; + format!("Panicked with {panic_values_string}.") +} + +/// Convert the execution result to a run result. +pub fn result_to_runresult(result: &ExecutionResult) -> anyhow::Result { + let is_success; + let mut felts: Vec = Vec::new(); + + match &result.return_value { + outer_value @ JitValue::Enum { + tag, + value, + debug_name, + } => { + if debug_name + .as_ref() + .expect("missing debug name") + .starts_with("core::panics::PanicResult::") + { + is_success = *tag == 0; + + if !is_success { + match &**value { + JitValue::Struct { fields, .. } => { + for field in fields { + let felt = jitvalue_to_felt(field); + felts.extend(felt); + } + } + _ => bail!("unsuported return value in cairo-native"), + } + } else { + felts.extend(jitvalue_to_felt(value)); + } + } else { + is_success = true; + felts.extend(jitvalue_to_felt(outer_value)); + } + } + x => { + is_success = true; + felts.extend(jitvalue_to_felt(x)); + } + } + + let return_values = felts + .into_iter() + .map(|x| x.to_bigint().into()) + .collect_vec(); + + Ok(match is_success { + true => RunResultValue::Success(return_values), + false => RunResultValue::Panic(return_values), + }) +} + +/// Convert a JIT value to a felt. +fn jitvalue_to_felt(value: &JitValue) -> Vec { + let mut felts = Vec::new(); + match value { + JitValue::Felt252(felt) => vec![*felt], + JitValue::BoundedInt { value, .. } => vec![*value], + JitValue::Array(fields) | JitValue::Struct { fields, .. } => { + fields.iter().flat_map(jitvalue_to_felt).collect() + } + JitValue::Enum { + value, + tag, + debug_name, + } => { + if let Some(debug_name) = debug_name { + if debug_name == "core::bool" { + vec![(*tag == 1).into()] + } else { + let mut felts = vec![(*tag).into()]; + felts.extend(jitvalue_to_felt(value)); + felts + } + } else { + todo!() + } + } + JitValue::Felt252Dict { value, .. } => { + for (key, value) in value { + felts.push(*key); + let felt = jitvalue_to_felt(value); + felts.extend(felt); + } + + felts + } + JitValue::Uint8(x) => vec![(*x).into()], + JitValue::Uint16(x) => vec![(*x).into()], + JitValue::Uint32(x) => vec![(*x).into()], + JitValue::Uint64(x) => vec![(*x).into()], + JitValue::Uint128(x) => vec![(*x).into()], + JitValue::Sint8(x) => vec![(*x).into()], + JitValue::Sint16(x) => vec![(*x).into()], + JitValue::Sint32(x) => vec![(*x).into()], + JitValue::Sint64(x) => vec![(*x).into()], + JitValue::Sint128(x) => vec![(*x).into()], + JitValue::Bytes31(_) + | JitValue::EcPoint(_, _) + | JitValue::EcState(_, _, _, _) + | JitValue::Secp256K1Point { .. } + | JitValue::Secp256R1Point { .. } => todo!(), + JitValue::Null => vec![0.into()], + } +} + +#[cfg(test)] +mod tests { + use super::*; + use cairo_felt::Felt252; + use cairo_lang_sierra::ProgramParser; + + #[test] + fn test_find_function() { + // Parse a simple program containing a function named "Func2" + let program = ProgramParser::new().parse("Func2@6() -> ();").unwrap(); + + // Assert that the function "Func2" is found and returned correctly + assert_eq!( + find_function(&program, "Func2").unwrap(), + program.funcs.first().unwrap() + ); + + // Assert that an error is returned when trying to find a non-existing function "Func3" + assert!(find_function(&program, "Func3").is_err()); + + // Assert that an error is returned when trying to find a function in an empty program + assert!(find_function(&ProgramParser::new().parse("").unwrap(), "Func2").is_err()); + } + + #[test] + fn test_result_to_runresult_enum_nonpanic() { + // Tests the conversion of a non-panic enum result to a `RunResultValue::Success`. + assert_eq!( + result_to_runresult(&ExecutionResult { + remaining_gas: None, + return_value: JitValue::Enum { + tag: 34, + value: JitValue::Array(vec![ + JitValue::Felt252(42.into()), + JitValue::Uint8(100), + JitValue::Uint128(1000), + ]) + .into(), + debug_name: Some("debug_name".into()), + }, + builtin_stats: Default::default(), + }) + .unwrap(), + RunResultValue::Success(vec![ + Felt252::from(34), + Felt252::from(42), + Felt252::from(100), + Felt252::from(1000) + ]) + ); + } + + #[test] + fn test_result_to_runresult_success() { + // Tests the conversion of a success enum result to a `RunResultValue::Success`. + assert_eq!( + result_to_runresult(&ExecutionResult { + remaining_gas: None, + return_value: JitValue::Enum { + tag: 0, + value: JitValue::Uint64(24).into(), + debug_name: Some("core::panics::PanicResult::Test".into()), + }, + builtin_stats: Default::default(), + }) + .unwrap(), + RunResultValue::Success(vec![Felt252::from(24)]) + ); + } + + #[test] + #[should_panic(expected = "unsuported return value in cairo-native")] + fn test_result_to_runresult_panic() { + // Tests the conversion with unsuported return value. + let _ = result_to_runresult(&ExecutionResult { + remaining_gas: None, + return_value: JitValue::Enum { + tag: 10, + value: JitValue::Uint64(24).into(), + debug_name: Some("core::panics::PanicResult::Test".into()), + }, + builtin_stats: Default::default(), + }) + .unwrap(); + } + + #[test] + #[should_panic(expected = "missing debug name")] + fn test_result_to_runresult_missing_debug_name() { + // Tests the conversion with no debug name. + let _ = result_to_runresult(&ExecutionResult { + remaining_gas: None, + return_value: JitValue::Enum { + tag: 10, + value: JitValue::Uint64(24).into(), + debug_name: None, + }, + builtin_stats: Default::default(), + }) + .unwrap(); + } + + #[test] + fn test_result_to_runresult_return() { + // Tests the conversion of a panic enum result with non-zero tag to a `RunResultValue::Panic`. + assert_eq!( + result_to_runresult(&ExecutionResult { + remaining_gas: None, + return_value: JitValue::Enum { + tag: 10, + value: JitValue::Struct { + fields: vec![ + JitValue::Felt252(42.into()), + JitValue::Uint8(100), + JitValue::Uint128(1000), + ], + debug_name: Some("debug_name".into()), + } + .into(), + debug_name: Some("core::panics::PanicResult::Test".into()), + }, + builtin_stats: Default::default(), + }) + .unwrap(), + RunResultValue::Panic(vec![ + Felt252::from(42), + Felt252::from(100), + Felt252::from(1000) + ]) + ); + } + + #[test] + fn test_result_to_runresult_non_enum() { + // Tests the conversion of a non-enum result to a `RunResultValue::Success`. + assert_eq!( + result_to_runresult(&ExecutionResult { + remaining_gas: None, + return_value: JitValue::Uint8(10), + builtin_stats: Default::default(), + }) + .unwrap(), + RunResultValue::Success(vec![Felt252::from(10)]) + ); + } + + #[test] + fn test_jitvalue_to_felt_felt252() { + let felt_value: Felt = 42.into(); + + assert_eq!( + jitvalue_to_felt(&JitValue::Felt252(felt_value)), + vec![felt_value] + ); + } + + #[test] + fn test_jitvalue_to_felt_array() { + assert_eq!( + jitvalue_to_felt(&JitValue::Array(vec![ + JitValue::Felt252(42.into()), + JitValue::Uint8(100), + JitValue::Uint128(1000), + ])), + vec![Felt::from(42), Felt::from(100), Felt::from(1000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_struct() { + assert_eq!( + jitvalue_to_felt(&JitValue::Struct { + fields: vec![ + JitValue::Felt252(42.into()), + JitValue::Uint8(100), + JitValue::Uint128(1000) + ], + debug_name: Some("debug_name".into()) + }), + vec![Felt::from(42), Felt::from(100), Felt::from(1000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_enum() { + // With debug name + assert_eq!( + jitvalue_to_felt(&JitValue::Enum { + tag: 34, + value: JitValue::Array(vec![ + JitValue::Felt252(42.into()), + JitValue::Uint8(100), + JitValue::Uint128(1000), + ]) + .into(), + debug_name: Some("debug_name".into()) + }), + vec![ + Felt::from(34), + Felt::from(42), + Felt::from(100), + Felt::from(1000) + ] + ); + + // With core::bool debug name and tag 1 + assert_eq!( + jitvalue_to_felt(&JitValue::Enum { + tag: 1, + value: JitValue::Uint128(1000).into(), + debug_name: Some("core::bool".into()) + }), + vec![Felt::ONE] + ); + + // With core::bool debug name and tag not 1 + assert_eq!( + jitvalue_to_felt(&JitValue::Enum { + tag: 10, + value: JitValue::Uint128(1000).into(), + debug_name: Some("core::bool".into()) + }), + vec![Felt::ZERO] + ); + } + + #[test] + fn test_jitvalue_to_felt_u8() { + assert_eq!(jitvalue_to_felt(&JitValue::Uint8(10)), vec![Felt::from(10)]); + } + + #[test] + fn test_jitvalue_to_felt_u16() { + assert_eq!( + jitvalue_to_felt(&JitValue::Uint16(100)), + vec![Felt::from(100)] + ); + } + + #[test] + fn test_jitvalue_to_felt_u32() { + assert_eq!( + jitvalue_to_felt(&JitValue::Uint32(1000)), + vec![Felt::from(1000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_u64() { + assert_eq!( + jitvalue_to_felt(&JitValue::Uint64(10000)), + vec![Felt::from(10000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_u128() { + assert_eq!( + jitvalue_to_felt(&JitValue::Uint128(100000)), + vec![Felt::from(100000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_sint8() { + assert_eq!( + jitvalue_to_felt(&JitValue::Sint8(-10)), + vec![Felt::from(-10)] + ); + } + + #[test] + fn test_jitvalue_to_felt_sint16() { + assert_eq!( + jitvalue_to_felt(&JitValue::Sint16(-100)), + vec![Felt::from(-100)] + ); + } + + #[test] + fn test_jitvalue_to_felt_sint32() { + assert_eq!( + jitvalue_to_felt(&JitValue::Sint32(-1000)), + vec![Felt::from(-1000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_sint64() { + assert_eq!( + jitvalue_to_felt(&JitValue::Sint64(-10000)), + vec![Felt::from(-10000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_sint128() { + assert_eq!( + jitvalue_to_felt(&JitValue::Sint128(-100000)), + vec![Felt::from(-100000)] + ); + } + + #[test] + fn test_jitvalue_to_felt_null() { + assert_eq!(jitvalue_to_felt(&JitValue::Null), vec![Felt::ZERO]); + } +} diff --git a/src/bin/utils/test.rs b/src/bin/utils/test.rs new file mode 100644 index 000000000..12bcc2508 --- /dev/null +++ b/src/bin/utils/test.rs @@ -0,0 +1,1376 @@ +use super::{find_function, format_for_panic, result_to_runresult, RunArgs, RunMode}; +use anyhow::Context; +use cairo_lang_runner::RunResultValue; +use cairo_lang_sierra::program::Program; +use cairo_lang_sierra::{extensions::gas::CostTokenType, ids::FunctionId}; +use cairo_lang_test_plugin::TestCompilation; +use cairo_lang_test_plugin::{ + test_config::{PanicExpectation, TestExpectation}, + TestConfig, +}; +use cairo_lang_utils::casts::IntoOrPanic; +use cairo_lang_utils::ordered_hash_map::OrderedHashMap; +use cairo_native::starknet::{Secp256r1Point, SyscallResult, U256}; +use cairo_native::{ + context::NativeContext, + executor::{AotNativeExecutor, JitNativeExecutor, NativeExecutor}, + metadata::gas::{GasMetadata, MetadataComputationConfig}, + starknet::{Secp256k1Point, StarknetSyscallHandler}, +}; +use colored::Colorize; +use itertools::Itertools; +use k256::elliptic_curve::sec1::ToEncodedPoint; +use k256::elliptic_curve::{generic_array::GenericArray, sec1::FromEncodedPoint}; +use num_traits::ToPrimitive; +#[cfg(feature = "scarb")] +use scarb_metadata::{PackageMetadata, TargetMetadata}; +use sec1::point::Coordinates; +use starknet_types_core::felt::Felt; +use std::{iter::once, sync::Mutex}; + +/// Summary data of the ran tests. +pub struct TestsSummary { + pub passed: Vec, + pub failed: Vec, + pub ignored: Vec, + pub failed_run_results: Vec, +} + +/// The result of a ran test. +struct TestResult { + /// The status of the run. + status: TestStatus, + /// The gas usage of the run if relevant. + gas_usage: Option, +} + +/// The status of a ran test. +enum TestStatus { + Success, + Fail(RunResultValue), +} + +/// Find all testable targets in the Scarb package. +#[cfg(feature = "scarb")] +pub fn find_testable_targets(package: &PackageMetadata) -> Vec<&TargetMetadata> { + package + .targets + .iter() + .filter(|target| target.kind == "test") + .collect() +} + +/// Filter compiled test cases with user provided arguments. +/// +/// # Arguments +/// * `compiled` - Compiled test cases with metadata. +/// * `include_ignored` - Include ignored tests as well. +/// * `ignored` - Run ignored tests only. +/// * `filter` - Include only tests containing the filter string. +/// # Returns +/// * (`TestCompilation`, `usize`) - The filtered test cases and the number of filtered out cases. +pub fn filter_test_cases( + compiled: TestCompilation, + include_ignored: bool, + ignored: bool, + filter: String, +) -> (TestCompilation, usize) { + let total_tests_count = compiled.named_tests.len(); + let named_tests = compiled + .named_tests + .into_iter() + // Filtering unignored tests in `ignored` mode + .filter(|(_, test)| !ignored || test.ignored || include_ignored) + .map(|(func, mut test)| { + // Un-ignoring all the tests in `include-ignored` and `ignored` mode. + if include_ignored || ignored { + test.ignored = false; + } + (func, test) + }) + .filter(|(name, _)| name.contains(&filter)) + .collect_vec(); + let filtered_out = total_tests_count - named_tests.len(); + let tests = TestCompilation { + named_tests, + ..compiled + }; + (tests, filtered_out) +} + +/// Display the summary of the ran tests. +pub fn display_tests_summary(summary: &TestsSummary, filtered_out: usize) { + if summary.failed.is_empty() { + println!( + "test result: {}. {} passed; {} failed; {} ignored; {filtered_out} filtered out;", + "ok".bright_green(), + summary.passed.len(), + summary.failed.len(), + summary.ignored.len() + ); + } else { + println!("failures:"); + for (failure, run_result) in summary + .failed + .iter() + .zip_eq(summary.failed_run_results.clone()) + { + print!(" {failure} - "); + match run_result { + RunResultValue::Success(_) => { + println!("expected panic but finished successfully."); + } + RunResultValue::Panic(values) => { + println!("{}", format_for_panic(values.into_iter())); + } + } + } + println!(); + } +} + +/// Runs the tests and process the results for a summary. +pub fn run_tests( + named_tests: Vec<(String, TestConfig)>, + sierra_program: Program, + function_set_costs: OrderedHashMap>, + args: RunArgs, +) -> anyhow::Result { + let native_context = NativeContext::new(); + + // Compile the sierra program into a MLIR module. + let native_module = native_context + .compile_with_metadata( + &sierra_program, + MetadataComputationConfig { + function_set_costs: function_set_costs.clone(), + linear_ap_change_solver: true, + linear_gas_solver: true, + }, + ) + .unwrap(); + + let native_executor: NativeExecutor = match args.run_mode { + RunMode::Aot => { + AotNativeExecutor::from_native_module(native_module, args.opt_level.into()).into() + } + RunMode::Jit => { + JitNativeExecutor::from_native_module(native_module, args.opt_level.into()).into() + } + }; + + let gas_metadata = GasMetadata::new( + &sierra_program, + Some(MetadataComputationConfig { + function_set_costs, + linear_ap_change_solver: true, + linear_gas_solver: true, + }), + ) + .unwrap(); + + println!("running {} tests", named_tests.len()); + let wrapped_summary = Mutex::new(Ok(TestsSummary { + passed: vec![], + failed: vec![], + ignored: vec![], + failed_run_results: vec![], + })); + named_tests + .into_iter() + .map( + |(name, test)| -> anyhow::Result<(String, Option)> { + if test.ignored { + return Ok((name, None)); + } + tracing::trace!("running test {name:?}"); + + let func = find_function(&sierra_program, name.as_str())?; + + let initial_gas = test.available_gas.map(|x| x.try_into().unwrap()); + + let result = native_executor + .invoke_dynamic_with_syscall_handler( + &func.id, + &[], + initial_gas, + TestSyscallHandler, + ) + .with_context(|| format!("Failed to run the function `{}`.", name.as_str()))?; + + let run_result = result_to_runresult(&result)?; + Ok(( + name, + Some(TestResult { + status: match &run_result { + RunResultValue::Success(_) => match test.expectation { + TestExpectation::Success => TestStatus::Success, + TestExpectation::Panics(_) => TestStatus::Fail(run_result), + }, + RunResultValue::Panic(value) => match test.expectation { + TestExpectation::Success => TestStatus::Fail(run_result), + TestExpectation::Panics(panic_expectation) => { + match panic_expectation { + PanicExpectation::Exact(expected) if value != &expected => { + TestStatus::Fail(run_result) + } + _ => TestStatus::Success, + } + } + }, + }, + gas_usage: test + .available_gas + .zip(result.remaining_gas) + .map(|(before, after)| { + before.into_or_panic::() - after.to_i64().unwrap() + }) + .or_else(|| { + gas_metadata + .initial_required_gas(&func.id) + .map(|gas| gas.try_into().unwrap()) + }), + }), + )) + }, + ) + .for_each(|r| { + let mut wrapped_summary = wrapped_summary.lock().unwrap(); + if wrapped_summary.is_err() { + return; + } + let (name, status) = match r { + Ok((name, status)) => (name, status), + Err(err) => { + *wrapped_summary = Err(err); + return; + } + }; + let summary = wrapped_summary.as_mut().unwrap(); + let (res_type, status_str, gas_usage) = match status { + Some(TestResult { + status: TestStatus::Success, + gas_usage, + }) => (&mut summary.passed, "ok".bright_green(), gas_usage), + Some(TestResult { + status: TestStatus::Fail(run_result), + gas_usage, + }) => { + summary.failed_run_results.push(run_result); + (&mut summary.failed, "fail".bright_red(), gas_usage) + } + None => (&mut summary.ignored, "ignored".bright_yellow(), None), + }; + if let Some(gas_usage) = gas_usage { + println!("test {name} ... {status_str} (gas usage est.: {gas_usage})"); + } else { + println!("test {name} ... {status_str}"); + } + res_type.push(name); + }); + wrapped_summary.into_inner().unwrap() +} + +pub struct TestSyscallHandler; + +impl StarknetSyscallHandler for TestSyscallHandler { + fn get_block_hash( + &mut self, + _block_number: u64, + _remaining_gas: &mut u128, + ) -> SyscallResult { + unimplemented!() + } + + fn get_execution_info( + &mut self, + _remaining_gas: &mut u128, + ) -> SyscallResult { + unimplemented!() + } + + fn get_execution_info_v2( + &mut self, + _remaining_gas: &mut u128, + ) -> SyscallResult { + unimplemented!() + } + + fn deploy( + &mut self, + _class_hash: Felt, + _contract_address_salt: Felt, + _calldata: &[Felt], + _deploy_from_zero: bool, + _remaining_gas: &mut u128, + ) -> SyscallResult<(Felt, Vec)> { + unimplemented!() + } + + fn replace_class(&mut self, _class_hash: Felt, _remaining_gas: &mut u128) -> SyscallResult<()> { + unimplemented!() + } + + fn library_call( + &mut self, + _class_hash: Felt, + _function_selector: Felt, + _calldata: &[Felt], + _remaining_gas: &mut u128, + ) -> SyscallResult> { + unimplemented!() + } + + fn call_contract( + &mut self, + _address: Felt, + _entry_point_selector: Felt, + _calldata: &[Felt], + _remaining_gas: &mut u128, + ) -> SyscallResult> { + unimplemented!() + } + + fn storage_read( + &mut self, + _address_domain: u32, + _address: Felt, + _remaining_gas: &mut u128, + ) -> SyscallResult { + unimplemented!() + } + + fn storage_write( + &mut self, + _address_domain: u32, + _address: Felt, + _value: Felt, + _remaining_gas: &mut u128, + ) -> SyscallResult<()> { + unimplemented!() + } + + fn emit_event( + &mut self, + _keys: &[Felt], + _data: &[Felt], + _remaining_gas: &mut u128, + ) -> SyscallResult<()> { + unimplemented!() + } + + fn send_message_to_l1( + &mut self, + _to_address: Felt, + _payload: &[Felt], + _remaining_gas: &mut u128, + ) -> SyscallResult<()> { + unimplemented!() + } + + fn keccak(&mut self, input: &[u64], gas: &mut u128) -> SyscallResult { + let length = input.len(); + + if length % 17 != 0 { + let error_msg = b"Invalid keccak input size"; + let felt_error = Felt::from_bytes_be_slice(error_msg); + return Err(vec![felt_error]); + } + + let n_chunks = length / 17; + let mut state = [0u64; 25]; + + for i in 0..n_chunks { + if *gas < KECCAK_ROUND_COST { + let error_msg = b"Syscall out of gas"; + let felt_error = Felt::from_bytes_be_slice(error_msg); + return Err(vec![felt_error]); + } + const KECCAK_ROUND_COST: u128 = 180000; + *gas -= KECCAK_ROUND_COST; + let chunk = &input[i * 17..(i + 1) * 17]; //(request.input_start + i * 17)?; + for (i, val) in chunk.iter().enumerate() { + state[i] ^= val; + } + keccak::f1600(&mut state) + } + + // state[0] and state[1] conform the hash_high (u128) + // state[2] and state[3] conform the hash_low (u128) + SyscallResult::Ok(U256 { + lo: state[2] as u128 | ((state[3] as u128) << 64), + hi: state[0] as u128 | ((state[1] as u128) << 64), + }) + } + + fn secp256k1_new( + &mut self, + x: U256, + y: U256, + _remaining_gas: &mut u128, + ) -> SyscallResult> { + // The following unwraps should be unreachable because the iterator we provide has the + // expected number of bytes. + let point = k256::ProjectivePoint::from_encoded_point( + &k256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + x.hi.to_be_bytes().into_iter().chain(x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + y.hi.to_be_bytes().into_iter().chain(y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ); + + if bool::from(point.is_some()) { + Ok(Some(Secp256k1Point { x, y })) + } else { + Ok(None) + } + } + + fn secp256k1_add( + &mut self, + p0: Secp256k1Point, + p1: Secp256k1Point, + _remaining_gas: &mut u128, + ) -> SyscallResult { + // The inner unwraps should be unreachable because the iterator we provide has the expected + // number of bytes. The outer unwraps depend on the felt values, which should be valid since + // they'll be provided by secp256 syscalls. + let p0 = k256::ProjectivePoint::from_encoded_point( + &k256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + p0.x.hi + .to_be_bytes() + .into_iter() + .chain(p0.x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + p0.y.hi + .to_be_bytes() + .into_iter() + .chain(p0.y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ) + .unwrap(); + let p1 = k256::ProjectivePoint::from_encoded_point( + &k256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + p1.x.hi + .to_be_bytes() + .into_iter() + .chain(p1.x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + p1.y.hi + .to_be_bytes() + .into_iter() + .chain(p1.y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ) + .unwrap(); + + let p = p0 + p1; + + let p = p.to_encoded_point(false); + let (x, y) = match p.coordinates() { + Coordinates::Uncompressed { x, y } => (x, y), + _ => { + // This should be unreachable because we explicitly asked for the uncompressed + // encoding. + unreachable!() + } + }; + + // The following two unwraps should be safe because the array always has 32 bytes. The other + // four are definitely safe because the slicing guarantees its length to be the right one. + let x: [u8; 32] = x.as_slice().try_into().unwrap(); + let y: [u8; 32] = y.as_slice().try_into().unwrap(); + Ok(Secp256k1Point { + x: U256 { + hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), + }, + y: U256 { + hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), + }, + }) + } + + fn secp256k1_mul( + &mut self, + p: Secp256k1Point, + m: U256, + _remaining_gas: &mut u128, + ) -> SyscallResult { + // The inner unwrap should be unreachable because the iterator we provide has the expected + // number of bytes. The outer unwrap depends on the felt values, which should be valid since + // they'll be provided by secp256 syscalls. + let p = k256::ProjectivePoint::from_encoded_point( + &k256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + p.x.hi.to_be_bytes().into_iter().chain(p.x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + p.y.hi.to_be_bytes().into_iter().chain(p.y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ) + .unwrap(); + let m: k256::Scalar = k256::elliptic_curve::ScalarPrimitive::from_slice(&{ + let mut buf = [0u8; 32]; + buf[0..16].copy_from_slice(&m.hi.to_be_bytes()); + buf[16..32].copy_from_slice(&m.lo.to_be_bytes()); + buf + }) + .map_err(|_| { + vec![Felt::from_bytes_be( + b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0invalid scalar", + )] + })? + .into(); + + let p = p * m; + + let p = p.to_encoded_point(false); + let (x, y) = match p.coordinates() { + Coordinates::Uncompressed { x, y } => (x, y), + _ => { + // This should be unreachable because we explicitly asked for the uncompressed + // encoding. + unreachable!() + } + }; + + // The following two unwraps should be safe because the array always has 32 bytes. The other + // four are definitely safe because the slicing guarantees its length to be the right one. + let x: [u8; 32] = x.as_slice().try_into().unwrap(); + let y: [u8; 32] = y.as_slice().try_into().unwrap(); + Ok(Secp256k1Point { + x: U256 { + hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), + }, + y: U256 { + hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), + }, + }) + } + + fn secp256k1_get_point_from_x( + &mut self, + x: U256, + y_parity: bool, + _remaining_gas: &mut u128, + ) -> SyscallResult> { + // The inner unwrap should be unreachable because the iterator we provide has the expected + // number of bytes. The outer unwrap depends on the encoding format, which should be valid + // since it's hardcoded.. + let point = k256::ProjectivePoint::from_encoded_point( + &k256::EncodedPoint::from_bytes( + k256::CompressedPoint::from_exact_iter( + once(0x02 | y_parity as u8) + .chain(x.hi.to_be_bytes()) + .chain(x.lo.to_be_bytes()), + ) + .unwrap(), + ) + .unwrap(), + ); + + if bool::from(point.is_some()) { + // This unwrap has already been checked in the `if` expression's condition. + let p = point.unwrap(); + + let p = p.to_encoded_point(false); + let y = match p.coordinates() { + Coordinates::Uncompressed { y, .. } => y, + _ => { + // This should be unreachable because we explicitly asked for the uncompressed + // encoding. + unreachable!() + } + }; + + // The following unwrap should be safe because the array always has 32 bytes. The other + // two are definitely safe because the slicing guarantees its length to be the right + // one. + let y: [u8; 32] = y.as_slice().try_into().unwrap(); + Ok(Some(Secp256k1Point { + x, + y: U256 { + hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), + }, + })) + } else { + Ok(None) + } + } + + fn secp256k1_get_xy( + &mut self, + p: Secp256k1Point, + _remaining_gas: &mut u128, + ) -> SyscallResult<(U256, U256)> { + Ok((p.x, p.y)) + } + + fn secp256r1_new( + &mut self, + x: U256, + y: U256, + _remaining_gas: &mut u128, + ) -> SyscallResult> { + // The following unwraps should be unreachable because the iterator we provide has the + // expected number of bytes. + let point = p256::ProjectivePoint::from_encoded_point( + &k256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + x.hi.to_be_bytes().into_iter().chain(x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + y.hi.to_be_bytes().into_iter().chain(y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ); + + if bool::from(point.is_some()) { + Ok(Some(Secp256r1Point { x, y })) + } else { + Ok(None) + } + } + + fn secp256r1_add( + &mut self, + p0: Secp256r1Point, + p1: Secp256r1Point, + _remaining_gas: &mut u128, + ) -> SyscallResult { + // The inner unwraps should be unreachable because the iterator we provide has the expected + // number of bytes. The outer unwraps depend on the felt values, which should be valid since + // they'll be provided by secp256 syscalls. + let p0 = p256::ProjectivePoint::from_encoded_point( + &p256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + p0.x.hi + .to_be_bytes() + .into_iter() + .chain(p0.x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + p0.y.hi + .to_be_bytes() + .into_iter() + .chain(p0.y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ) + .unwrap(); + let p1 = p256::ProjectivePoint::from_encoded_point( + &p256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + p1.x.hi + .to_be_bytes() + .into_iter() + .chain(p1.x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + p1.y.hi + .to_be_bytes() + .into_iter() + .chain(p1.y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ) + .unwrap(); + + let p = p0 + p1; + + let p = p.to_encoded_point(false); + let (x, y) = match p.coordinates() { + Coordinates::Uncompressed { x, y } => (x, y), + _ => { + // This should be unreachable because we explicitly asked for the uncompressed + // encoding. + unreachable!() + } + }; + + // The following two unwraps should be safe because the array always has 32 bytes. The other + // four are definitely safe because the slicing guarantees its length to be the right one. + let x: [u8; 32] = x.as_slice().try_into().unwrap(); + let y: [u8; 32] = y.as_slice().try_into().unwrap(); + Ok(Secp256r1Point { + x: U256 { + hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), + }, + y: U256 { + hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), + }, + }) + } + + fn secp256r1_mul( + &mut self, + p: Secp256r1Point, + m: U256, + _remaining_gas: &mut u128, + ) -> SyscallResult { + // The inner unwrap should be unreachable because the iterator we provide has the expected + // number of bytes. The outer unwrap depends on the felt values, which should be valid since + // they'll be provided by secp256 syscalls. + let p = p256::ProjectivePoint::from_encoded_point( + &p256::EncodedPoint::from_affine_coordinates( + &GenericArray::from_exact_iter( + p.x.hi.to_be_bytes().into_iter().chain(p.x.lo.to_be_bytes()), + ) + .unwrap(), + &GenericArray::from_exact_iter( + p.y.hi.to_be_bytes().into_iter().chain(p.y.lo.to_be_bytes()), + ) + .unwrap(), + false, + ), + ) + .unwrap(); + let m: p256::Scalar = p256::elliptic_curve::ScalarPrimitive::from_slice(&{ + let mut buf = [0u8; 32]; + buf[0..16].copy_from_slice(&m.hi.to_be_bytes()); + buf[16..32].copy_from_slice(&m.lo.to_be_bytes()); + buf + }) + .map_err(|_| { + vec![Felt::from_bytes_be( + b"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0invalid scalar", + )] + })? + .into(); + + let p = p * m; + + let p = p.to_encoded_point(false); + let (x, y) = match p.coordinates() { + Coordinates::Uncompressed { x, y } => (x, y), + _ => { + // This should be unreachable because we explicitly asked for the uncompressed + // encoding. + unreachable!() + } + }; + + // The following two unwraps should be safe because the array always has 32 bytes. The other + // four are definitely safe because the slicing guarantees its length to be the right one. + let x: [u8; 32] = x.as_slice().try_into().unwrap(); + let y: [u8; 32] = y.as_slice().try_into().unwrap(); + Ok(Secp256r1Point { + x: U256 { + hi: u128::from_be_bytes(x[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(x[16..32].try_into().unwrap()), + }, + y: U256 { + hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), + }, + }) + } + + fn secp256r1_get_point_from_x( + &mut self, + x: U256, + y_parity: bool, + _remaining_gas: &mut u128, + ) -> SyscallResult> { + let point = p256::ProjectivePoint::from_encoded_point( + &p256::EncodedPoint::from_bytes( + p256::CompressedPoint::from_exact_iter( + once(0x02 | y_parity as u8) + .chain(x.hi.to_be_bytes()) + .chain(x.lo.to_be_bytes()), + ) + .unwrap(), + ) + .unwrap(), + ); + + if bool::from(point.is_some()) { + let p = point.unwrap(); + + let p = p.to_encoded_point(false); + let y = match p.coordinates() { + Coordinates::Uncompressed { y, .. } => y, + _ => unreachable!(), + }; + + let y: [u8; 32] = y.as_slice().try_into().unwrap(); + Ok(Some(Secp256r1Point { + x, + y: U256 { + hi: u128::from_be_bytes(y[0..16].try_into().unwrap()), + lo: u128::from_be_bytes(y[16..32].try_into().unwrap()), + }, + })) + } else { + Ok(None) + } + } + + fn secp256r1_get_xy( + &mut self, + p: Secp256r1Point, + _remaining_gas: &mut u128, + ) -> SyscallResult<(U256, U256)> { + Ok((p.x, p.y)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_compilation() -> TestCompilation { + TestCompilation { + named_tests: vec![ + ( + String::from("test1"), + TestConfig { + available_gas: None, + expectation: TestExpectation::Success, + ignored: false, + }, + ), + ( + String::from("test2"), + TestConfig { + available_gas: None, + expectation: TestExpectation::Success, + ignored: true, + }, + ), + ( + String::from("test3"), + TestConfig { + available_gas: None, + expectation: TestExpectation::Success, + ignored: false, + }, + ), + ], + sierra_program: Program { + type_declarations: vec![], + libfunc_declarations: vec![], + statements: vec![], + funcs: vec![], + }, + statements_functions: Default::default(), + contracts_info: Default::default(), + function_set_costs: Default::default(), + } + } + + fn assert_named_test(lhs: &(String, TestConfig), rhs: &(String, TestConfig)) -> bool { + lhs.0 == rhs.0 + && lhs.1.available_gas == rhs.1.available_gas + && lhs.1.expectation == rhs.1.expectation + && lhs.1.ignored == rhs.1.ignored + } + + #[test] + fn test_filter_test_cases() { + let compiled = test_compilation(); + + let (filtered, filtered_out) = + filter_test_cases(compiled.clone(), false, false, String::from("test")); + + // Nothing should be filtered out. + assert_eq!(filtered_out, 0); + assert!(filtered + .named_tests + .iter() + .enumerate() + .all(|(i, x)| assert_named_test(x, &compiled.named_tests[i]))); + } + + #[test] + fn test_filter_test_cases_include_ignored() { + let compiled = test_compilation(); + + let (filtered, filtered_out) = + filter_test_cases(compiled.clone(), true, false, String::from("test")); + + // All tests should be included, even the ignored ones. + let expected = compiled + .named_tests + .into_iter() + .map(|mut x| { + x.1.ignored = false; + x + }) + .collect_vec(); + + assert_eq!(filtered_out, 0); + assert!(filtered + .named_tests + .iter() + .enumerate() + .all(|(i, x)| assert_named_test(x, &expected[i]))); + } + + #[test] + fn test_filter_test_cases_ignored() { + let compiled = test_compilation(); + + let (filtered, filtered_out) = + filter_test_cases(compiled.clone(), false, true, String::from("test")); + + // Only the ignored tests should be included. + let expected = compiled + .named_tests + .into_iter() + .filter(|x| x.1.ignored) + .map(|mut x| { + x.1.ignored = false; + x + }) + .collect_vec(); + + assert_eq!(filtered_out, 2); + assert!(filtered + .named_tests + .iter() + .enumerate() + .all(|(i, x)| assert_named_test(x, &expected[i]))); + } + + #[test] + fn test_filter_test_cases_include_ignored_and_ignored() { + let compiled = test_compilation(); + + let (filtered, filtered_out) = + filter_test_cases(compiled.clone(), true, true, String::from("test")); + + // All tests should be included, even the ignored ones. + let expected = compiled + .named_tests + .into_iter() + .map(|mut x| { + x.1.ignored = false; + x + }) + .collect_vec(); + + assert_eq!(filtered_out, 0); + assert!(filtered + .named_tests + .iter() + .enumerate() + .all(|(i, x)| assert_named_test(x, &expected[i]))); + } + + #[test] + fn test_secp256k1_get_xy() { + let p = Secp256k1Point { + x: U256 { + hi: 331229800296699308591929724809569456681, + lo: 240848751772479376198639683648735950585, + }, + y: U256 { + hi: 75181762170223969696219813306313470806, + lo: 134255467439736302886468555755295925874, + }, + }; + + let mut test_syscall_handler = TestSyscallHandler {}; + + assert_eq!( + test_syscall_handler.secp256k1_get_xy(p, &mut 10).unwrap(), + ( + U256 { + hi: 331229800296699308591929724809569456681, + lo: 240848751772479376198639683648735950585, + }, + U256 { + hi: 75181762170223969696219813306313470806, + lo: 134255467439736302886468555755295925874, + } + ) + ) + } + + #[test] + fn test_secp256k1_secp256k1_new() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let x = U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }; + let y = U256 { + hi: 26163136114030451075775058782541084873, + lo: 68974579539311638391577168388077592842, + }; + + assert_eq!( + test_syscall_handler.secp256k1_new(x, y, &mut 10).unwrap(), + Some(Secp256k1Point { x, y }) + ); + } + + #[test] + fn test_secp256k1_secp256k1_new_none() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let x = U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }; + let y = U256 { hi: 0, lo: 0 }; + + assert!(test_syscall_handler + .secp256k1_new(x, y, &mut 10) + .unwrap() + .is_none()); + } + + #[test] + fn test_secp256k1_ssecp256k1_add() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let p1 = Secp256k1Point { + x: U256 { + hi: 161825202758953104525843685720298294023, + lo: 3468390537006497937951914270391801752, + }, + y: U256 { + hi: 96009999919712310848645357523629574312, + lo: 336417762351022071123394393598455764152, + }, + }; + + let p2 = p1; + + // 2 * P1 + let p3 = test_syscall_handler.secp256k1_add(p1, p2, &mut 10).unwrap(); + + let p1_double = Secp256k1Point { + x: U256 { + hi: 263210499965038831386353541518668627160, + lo: 122909745026270932982812610085084241637, + }, + y: U256 { + hi: 35730324229579385338853513728577301230, + lo: 329597642124196932058042157271922763050, + }, + }; + assert_eq!(p3, p1_double); + assert_eq!( + test_syscall_handler + .secp256k1_mul(p1, U256 { hi: 0, lo: 2 }, &mut 10) + .unwrap(), + p1_double + ); + + // 3 * P1 + let three_p1 = Secp256k1Point { + x: U256 { + hi: 331229800296699308591929724809569456681, + lo: 240848751772479376198639683648735950585, + }, + y: U256 { + hi: 75181762170223969696219813306313470806, + lo: 134255467439736302886468555755295925874, + }, + }; + assert_eq!( + test_syscall_handler.secp256k1_add(p1, p3, &mut 10).unwrap(), + three_p1 + ); + assert_eq!( + test_syscall_handler + .secp256k1_mul(p1, U256 { hi: 0, lo: 3 }, &mut 10) + .unwrap(), + three_p1 + ); + } + + #[test] + fn test_secp256k1_get_point_from_x_false_yparity() { + let mut test_syscall_handler = TestSyscallHandler {}; + + assert_eq!( + test_syscall_handler + .secp256k1_get_point_from_x( + U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }, + false, + &mut 10 + ) + .unwrap() + .unwrap(), + Secp256k1Point { + x: U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }, + y: U256 { + hi: 26163136114030451075775058782541084873, + lo: 68974579539311638391577168388077592842 + }, + } + ); + } + + #[test] + fn test_secp256k1_get_point_from_x_true_yparity() { + let mut test_syscall_handler = TestSyscallHandler {}; + + assert_eq!( + test_syscall_handler + .secp256k1_get_point_from_x( + U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }, + true, + &mut 10 + ) + .unwrap() + .unwrap(), + Secp256k1Point { + x: U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }, + y: U256 { + hi: 314119230806908012387599548649227126582, + lo: 271307787381626825071797439039395650341 + }, + } + ); + } + + #[test] + fn test_secp256k1_get_point_from_x_none() { + let mut test_syscall_handler = TestSyscallHandler {}; + + assert!(test_syscall_handler + .secp256k1_get_point_from_x(U256 { hi: 0, lo: 0 }, true, &mut 10) + .unwrap() + .is_none()); + } + + #[test] + fn test_secp256r1_new() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let x = U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }; + let y = U256 { + hi: 118910939004298029402109603132816090461, + lo: 111045440647474106186537215379882575585, + }; + + assert_eq!( + test_syscall_handler + .secp256r1_new(x, y, &mut 10) + .unwrap() + .unwrap(), + Secp256r1Point { x, y } + ); + } + + #[test] + fn test_secp256r1_new_none() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let x = U256 { hi: 0, lo: 0 }; + let y = U256 { hi: 0, lo: 0 }; + + assert!(test_syscall_handler + .secp256r1_new(x, y, &mut 10) + .unwrap() + .is_none()); + } + + #[test] + fn test_secp256r1_add() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let p1 = Secp256r1Point { + x: U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }, + y: U256 { + hi: 118910939004298029402109603132816090461, + lo: 111045440647474106186537215379882575585, + }, + }; + + let p2 = p1; + + // 2 * P1 + let p3 = test_syscall_handler.secp256r1_add(p1, p2, &mut 10).unwrap(); + + let p1_double = Secp256r1Point { + x: U256 { + hi: 280079427190737520201067412903899817878, + lo: 309339945874468445579793098896656960879, + }, + y: U256 { + hi: 84249534056490759701994051847937833933, + lo: 231570843221643745062297421862629788481, + }, + }; + assert_eq!(p3, p1_double); + assert_eq!( + test_syscall_handler + .secp256r1_mul(p1, U256 { hi: 0, lo: 2 }, &mut 10) + .unwrap(), + p1_double + ); + + // 3 * P1 + let three_p1 = Secp256r1Point { + x: U256 { + hi: 23850518908906170876551962912581992002, + lo: 195259625777021303662291420857740525307, + }, + y: U256 { + hi: 178681203065513270100417145499857169664, + lo: 282344931843342117515389970197013120959, + }, + }; + assert_eq!( + test_syscall_handler.secp256r1_add(p1, p3, &mut 10).unwrap(), + three_p1 + ); + assert_eq!( + test_syscall_handler + .secp256r1_mul(p1, U256 { hi: 0, lo: 3 }, &mut 10) + .unwrap(), + three_p1 + ); + } + + #[test] + fn test_secp256r1_get_point_from_x_true_yparity() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let x = U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }; + + let y = U256 { + hi: 118910939004298029402109603132816090461, + lo: 111045440647474106186537215379882575585, + }; + + assert_eq!( + test_syscall_handler + .secp256r1_get_point_from_x(x, true, &mut 10) + .unwrap() + .unwrap(), + Secp256r1Point { x, y } + ); + } + + #[test] + fn test_secp256r1_get_point_from_x_false_yparity() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let x = U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }; + + let y = U256 { + hi: 221371427837412271565447410779117722274, + lo: 229236926352692519791101729645429586206, + }; + + assert_eq!( + test_syscall_handler + .secp256r1_get_point_from_x(x, false, &mut 10) + .unwrap() + .unwrap(), + Secp256r1Point { x, y } + ); + } + + #[test] + fn test_secp256r1_get_point_from_x_none() { + let mut test_syscall_handler = TestSyscallHandler {}; + + let x = U256 { hi: 0, lo: 10 }; + + assert!(test_syscall_handler + .secp256r1_get_point_from_x(x, true, &mut 10) + .unwrap() + .is_none()); + } + + #[test] + fn test_secp256r1_get_xy() { + let p = Secp256r1Point { + x: U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }, + y: U256 { + hi: 221371427837412271565447410779117722274, + lo: 229236926352692519791101729645429586206, + }, + }; + + let mut test_syscall_handler = TestSyscallHandler {}; + + assert_eq!( + test_syscall_handler.secp256r1_get_xy(p, &mut 10).unwrap(), + ( + U256 { + hi: 97179038819393695679, + lo: 330631467365974629050427735731901850225, + }, + U256 { + hi: 221371427837412271565447410779117722274, + lo: 229236926352692519791101729645429586206, + } + ) + ) + } +}