From f5a567f01a730dce04d09a945e47d7f849dd082b Mon Sep 17 00:00:00 2001 From: Dylan Martin Date: Sat, 16 Nov 2024 01:02:43 +0100 Subject: [PATCH] feat(flags): add support for matching static cohort membership (#25942) --- rust/feature-flags/src/api.rs | 5 + rust/feature-flags/src/flag_matching.rs | 649 +++++++++++++++++++--- rust/feature-flags/src/request_handler.rs | 1 + rust/feature-flags/src/test_utils.rs | 62 ++- 4 files changed, 637 insertions(+), 80 deletions(-) diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 9d6b649719bd2..be21c1c37f550 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -108,6 +108,8 @@ pub enum FlagError { CohortFiltersParsingError, #[error("Cohort dependency cycle")] CohortDependencyCycle(String), + #[error("Person not found")] + PersonNotFound, } impl IntoResponse for FlagError { @@ -212,6 +214,9 @@ impl IntoResponse for FlagError { tracing::error!("Cohort dependency cycle: {}", msg); (StatusCode::BAD_REQUEST, msg) } + FlagError::PersonNotFound => { + (StatusCode::BAD_REQUEST, "Person not found. Please check your distinct_id and try again.".to_string()) + } } .into_response() } diff --git a/rust/feature-flags/src/flag_matching.rs b/rust/feature-flags/src/flag_matching.rs index 571fe9c84b40a..d9332fce4e495 100644 --- a/rust/feature-flags/src/flag_matching.rs +++ b/rust/feature-flags/src/flag_matching.rs @@ -15,7 +15,7 @@ use petgraph::algo::{is_cyclic_directed, toposort}; use petgraph::graph::DiGraph; use serde_json::Value; use sha1::{Digest, Sha1}; -use sqlx::{postgres::PgQueryResult, Acquire, FromRow}; +use sqlx::{postgres::PgQueryResult, Acquire, FromRow, Row}; use std::fmt::Write; use std::sync::Arc; use std::{ @@ -26,6 +26,7 @@ use tokio::time::{sleep, timeout}; use tracing::{error, info}; pub type TeamId = i32; +pub type PersonId = i32; pub type GroupTypeIndex = i32; pub type PostgresReader = Arc; pub type PostgresWriter = Arc; @@ -176,6 +177,7 @@ impl GroupTypeMappingCache { /// to fetch the properties from the DB each time. #[derive(Clone, Default, Debug)] pub struct PropertiesCache { + person_id: Option, person_properties: Option>, group_properties: HashMap>, } @@ -217,9 +219,18 @@ impl FeatureFlagMatcher { } } - /// Evaluate feature flags for a given distinct_id - /// - Returns a map of feature flag keys to their values - /// - If an error occurs while evaluating a flag, it will be logged and the flag will be omitted from the result + /// Evaluates all feature flags for the current matcher context. + /// + /// ## Arguments + /// + /// * `feature_flags` - The list of feature flags to evaluate. + /// * `person_property_overrides` - Any overrides for person properties. + /// * `group_property_overrides` - Any overrides for group properties. + /// * `hash_key_override` - Optional hash key overrides for experience continuity. + /// + /// ## Returns + /// + /// * `FlagsResponse` - The result containing flag evaluations and any errors. pub async fn evaluate_all_feature_flags( &mut self, feature_flags: FeatureFlagList, @@ -746,22 +757,29 @@ impl FeatureFlagMatcher { .partition(|prop| prop.is_cohort()); // Get the properties we need to check for in this condition match from the flag + any overrides - let target_properties = self + let person_or_group_properties = self .get_properties_to_check(feature_flag, property_overrides, &non_cohort_filters) .await?; // Evaluate non-cohort filters first, since they're cheaper to evaluate and we can return early if they don't match - if !all_properties_match(&non_cohort_filters, &target_properties) { + if !all_properties_match(&non_cohort_filters, &person_or_group_properties) { return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); } // Evaluate cohort filters, if any. - if !cohort_filters.is_empty() - && !self - .evaluate_cohort_filters(&cohort_filters, &target_properties) + if !cohort_filters.is_empty() { + // Get the person ID for the current distinct ID – this value should be cached at this point, but as a fallback we fetch from the database + let person_id = self.get_person_id().await?; + if !self + .evaluate_cohort_filters( + &cohort_filters, + &person_or_group_properties, + person_id, + ) .await? - { - return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); + { + return Ok((false, FeatureFlagMatchReason::NoConditionMatch)); + } } } @@ -809,6 +827,31 @@ impl FeatureFlagMatcher { } } + /// Retrieves the `PersonId` from the properties cache. + /// If the cache does not contain a `PersonId`, it fetches it from the database + /// and updates the cache accordingly. + async fn get_person_id(&mut self) -> Result { + match self.properties_cache.person_id { + Some(id) => Ok(id), + None => { + let id = self.get_person_id_from_db().await?; + self.properties_cache.person_id = Some(id); + Ok(id) + } + } + } + + /// Fetches the `PersonId` from the database based on the current `distinct_id` and `team_id`. + /// This method is called when the `PersonId` is not present in the properties cache. + async fn get_person_id_from_db(&mut self) -> Result { + let postgres_reader = self.postgres_reader.clone(); + let distinct_id = self.distinct_id.clone(); + let team_id = self.team_id; + fetch_person_properties_from_db(postgres_reader, distinct_id, team_id) + .await + .map(|(_, person_id)| person_id) + } + /// Get person properties from cache or database. /// /// This function attempts to retrieve person properties either from a cache or directly from the database. @@ -836,26 +879,45 @@ impl FeatureFlagMatcher { &self, cohort_property_filters: &[PropertyFilter], target_properties: &HashMap, + person_id: PersonId, ) -> Result { // At the start of the request, fetch all of the cohorts for the team from the cache - // This method also caches the cohorts in memory for the duration of the application, so we don't need to fetch from - // the database again until we restart the application. + // This method also caches any cohorts for a given team in memory for the duration of the application, so we don't need to fetch from + // the database again until we restart the application. See the CohortCacheManager for more details. let cohorts = self.cohort_cache.get_cohorts_for_team(self.team_id).await?; - // Store cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times, - // since the same cohort could appear in multiple property filters. This is especially important - // because evaluating a cohort requires evaluating all of its dependencies, which can be expensive. + // Split the cohorts into static and dynamic, since the dynamic ones have property filters + // and we need to evaluate them based on the target properties, whereas the static ones are + // purely based on person properties and are membership-based. + let (static_cohorts, dynamic_cohorts): (Vec<_>, Vec<_>) = + cohorts.iter().partition(|c| c.is_static); + + // Store all cohort match results in a HashMap to avoid re-evaluating the same cohort multiple times, + // since the same cohort could appear in multiple property filters. let mut cohort_matches = HashMap::new(); - for filter in cohort_property_filters { - let cohort_id = filter - .get_cohort_id() - .ok_or(FlagError::CohortFiltersParsingError)?; - let match_result = - evaluate_cohort_dependencies(cohort_id, target_properties, cohorts.clone())?; - cohort_matches.insert(cohort_id, match_result); + + if !static_cohorts.is_empty() { + let results = evaluate_static_cohorts( + self.postgres_reader.clone(), + person_id, + static_cohorts.iter().map(|c| c.id).collect(), + ) + .await?; + cohort_matches.extend(results); } - // Apply cohort membership logic (IN|NOT_IN) + if !dynamic_cohorts.is_empty() { + for filter in cohort_property_filters { + let cohort_id = filter + .get_cohort_id() + .ok_or(FlagError::CohortFiltersParsingError)?; + let match_result = + evaluate_dynamic_cohorts(cohort_id, target_properties, cohorts.clone())?; + cohort_matches.insert(cohort_id, match_result); + } + } + + // Apply cohort membership logic (IN|NOT_IN) to the cohort match results apply_cohort_membership_logic(cohort_property_filters, &cohort_matches) } @@ -971,11 +1033,12 @@ impl FeatureFlagMatcher { let postgres_reader = self.postgres_reader.clone(); let distinct_id = self.distinct_id.clone(); let team_id = self.team_id; - let db_properties = + let (db_properties, person_id) = fetch_person_properties_from_db(postgres_reader, distinct_id, team_id).await?; - // once the properties are fetched, cache them so we don't need to fetch again in a given request + // once the properties and person ID are fetched, cache them so we don't need to fetch again in a given request self.properties_cache.person_properties = Some(db_properties.clone()); + self.properties_cache.person_id = Some(person_id); Ok(db_properties) } @@ -1102,10 +1165,49 @@ impl FeatureFlagMatcher { } } -/// Evaluates a single cohort and its dependencies. +/// Evaluate static cohort filters by checking if the person is in each cohort. +async fn evaluate_static_cohorts( + postgres_reader: PostgresReader, + person_id: i32, // Change this parameter from distinct_id to person_id + cohort_ids: Vec, +) -> Result, FlagError> { + let mut conn = postgres_reader.get_connection().await?; + + let query = r#" + WITH cohort_membership AS ( + SELECT c.cohort_id, + CASE WHEN pc.cohort_id IS NOT NULL THEN true ELSE false END AS is_member + FROM unnest($1::integer[]) AS c(cohort_id) + LEFT JOIN posthog_cohortpeople AS pc + ON pc.person_id = $2 + AND pc.cohort_id = c.cohort_id + ) + SELECT cohort_id, is_member + FROM cohort_membership + "#; + + let rows = sqlx::query(query) + .bind(&cohort_ids) + .bind(person_id) // Bind person_id directly + .fetch_all(&mut *conn) + .await?; + + let result = rows + .into_iter() + .map(|row| { + let cohort_id: CohortId = row.get("cohort_id"); + let is_member: bool = row.get("is_member"); + (cohort_id, is_member) + }) + .collect(); + + Ok(result) +} + +/// Evaluates a dynamic cohort and its dependencies. /// This uses a topological sort to evaluate dependencies first, which is necessary /// because a cohort can depend on another cohort, and we need to respect the dependency order. -fn evaluate_cohort_dependencies( +fn evaluate_dynamic_cohorts( initial_cohort_id: CohortId, target_properties: &HashMap, cohorts: Vec, @@ -1221,6 +1323,16 @@ fn build_cohort_dependency_graph( let mut graph = DiGraph::new(); let mut node_map = HashMap::new(); let mut queue = VecDeque::new(); + + let initial_cohort = cohorts + .iter() + .find(|c| c.id == initial_cohort_id) + .ok_or(FlagError::CohortNotFound(initial_cohort_id.to_string()))?; + + if initial_cohort.is_static { + return Ok(graph); + } + // This implements a breadth-first search (BFS) traversal to build a directed graph of cohort dependencies. // Starting from the initial cohort, we: // 1. Add each cohort as a node in the graph @@ -1283,32 +1395,52 @@ async fn fetch_and_locally_cache_all_properties( let query = r#" SELECT - (SELECT "posthog_person"."properties" - FROM "posthog_person" - INNER JOIN "posthog_persondistinctid" - ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") - WHERE ("posthog_persondistinctid"."distinct_id" = $1 - AND "posthog_persondistinctid"."team_id" = $2 - AND "posthog_person"."team_id" = $2) - LIMIT 1) as person_properties, - - (SELECT json_object_agg("posthog_group"."group_type_index", "posthog_group"."group_properties") - FROM "posthog_group" - WHERE ("posthog_group"."team_id" = $2 - AND "posthog_group"."group_type_index" = ANY($3))) as group_properties + person.person_id, + person.person_properties, + group_properties.group_properties + FROM ( + SELECT + "posthog_person"."id" AS person_id, + "posthog_person"."properties" AS person_properties + FROM "posthog_person" + INNER JOIN "posthog_persondistinctid" + ON "posthog_person"."id" = "posthog_persondistinctid"."person_id" + WHERE + "posthog_persondistinctid"."distinct_id" = $1 + AND "posthog_persondistinctid"."team_id" = $2 + AND "posthog_person"."team_id" = $2 + LIMIT 1 + ) AS person, + ( + SELECT + json_object_agg( + "posthog_group"."group_type_index", + "posthog_group"."group_properties" + ) AS group_properties + FROM "posthog_group" + WHERE + "posthog_group"."team_id" = $2 + AND "posthog_group"."group_type_index" = ANY($3) + ) AS group_properties "#; let group_type_indexes_vec: Vec = group_type_indexes.iter().cloned().collect(); - let row: (Option, Option) = sqlx::query_as(query) + let row: (Option, Option, Option) = sqlx::query_as(query) .bind(&distinct_id) .bind(team_id) .bind(&group_type_indexes_vec) .fetch_optional(&mut *conn) .await? - .unwrap_or((None, None)); + .unwrap_or((None, None, None)); - if let Some(person_props) = row.0 { + let (person_id, person_props, group_props) = row; + + if let Some(person_id) = person_id { + properties_cache.person_id = Some(person_id); + } + + if let Some(person_props) = person_props { properties_cache.person_properties = Some( person_props .as_object() @@ -1319,7 +1451,7 @@ async fn fetch_and_locally_cache_all_properties( ); } - if let Some(group_props) = row.1 { + if let Some(group_props) = group_props { let group_props_map: HashMap> = group_props .as_object() .unwrap_or(&serde_json::Map::new()) @@ -1342,7 +1474,7 @@ async fn fetch_and_locally_cache_all_properties( Ok(()) } -/// Fetch person properties from the database for a given distinct ID and team ID. +/// Fetch person properties and person ID from the database for a given distinct ID and team ID. /// /// This function constructs and executes a SQL query to fetch the person properties for a specified distinct ID and team ID. /// It returns the fetched properties as a HashMap. @@ -1350,31 +1482,37 @@ async fn fetch_person_properties_from_db( postgres_reader: PostgresReader, distinct_id: String, team_id: TeamId, -) -> Result, FlagError> { +) -> Result<(HashMap, i32), FlagError> { let mut conn = postgres_reader.as_ref().get_connection().await?; let query = r#" - SELECT "posthog_person"."properties" as person_properties - FROM "posthog_person" - INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") - WHERE ("posthog_persondistinctid"."distinct_id" = $1 - AND "posthog_persondistinctid"."team_id" = $2 - AND "posthog_person"."team_id" = $2) - LIMIT 1 - "#; - - let row: Option = sqlx::query_scalar(query) + SELECT "posthog_person"."id" as person_id, "posthog_person"."properties" as person_properties + FROM "posthog_person" + INNER JOIN "posthog_persondistinctid" ON ("posthog_person"."id" = "posthog_persondistinctid"."person_id") + WHERE ("posthog_persondistinctid"."distinct_id" = $1 + AND "posthog_persondistinctid"."team_id" = $2 + AND "posthog_person"."team_id" = $2) + LIMIT 1 + "#; + + let row: Option<(i32, Value)> = sqlx::query_as(query) .bind(&distinct_id) .bind(team_id) .fetch_optional(&mut *conn) .await?; - Ok(row - .and_then(|v| v.as_object().cloned()) - .unwrap_or_default() - .into_iter() - .map(|(k, v)| (k, v.clone())) - .collect()) + match row { + Some((person_id, person_props)) => { + let properties_map = person_props + .as_object() + .unwrap_or(&serde_json::Map::new()) + .iter() + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + Ok((properties_map, person_id)) + } + None => Err(FlagError::PersonNotFound), + } } /// Fetch group properties from the database for a given team ID and group type index. @@ -1436,11 +1574,11 @@ fn locally_computable_property_overrides( /// Check if all properties match the given filters fn all_properties_match( flag_condition_properties: &[PropertyFilter], - target_properties: &HashMap, + matching_property_values: &HashMap, ) -> bool { flag_condition_properties .iter() - .all(|property| match_property(property, target_properties, false).unwrap_or(false)) + .all(|property| match_property(property, matching_property_values, false).unwrap_or(false)) } async fn get_feature_flag_hash_key_overrides( @@ -1663,8 +1801,9 @@ mod tests { OperatorType, }, test_utils::{ - insert_cohort_for_team_in_pg, insert_flag_for_team_in_pg, insert_new_team_in_pg, - insert_person_for_team_in_pg, setup_pg_reader_client, setup_pg_writer_client, + add_person_to_cohort, get_person_id_by_distinct_id, insert_cohort_for_team_in_pg, + insert_flag_for_team_in_pg, insert_new_team_in_pg, insert_person_for_team_in_pg, + setup_pg_reader_client, setup_pg_writer_client, }, }; @@ -1750,6 +1889,7 @@ mod tests { )) .unwrap(); + // Matcher for a matching distinct_id let mut matcher = FeatureFlagMatcher::new( distinct_id.clone(), team.id, @@ -1763,6 +1903,7 @@ mod tests { assert!(match_result.matches); assert_eq!(match_result.variant, None); + // Matcher for a non-matching distinct_id let mut matcher = FeatureFlagMatcher::new( not_matching_distinct_id.clone(), team.id, @@ -1776,6 +1917,7 @@ mod tests { assert!(!match_result.matches); assert_eq!(match_result.variant, None); + // Matcher for a distinct_id that does not exist let mut matcher = FeatureFlagMatcher::new( "other_distinct_id".to_string(), team.id, @@ -1785,9 +1927,10 @@ mod tests { None, None, ); - let match_result = matcher.get_match(&flag, None, None).await.unwrap(); - assert!(!match_result.matches); - assert_eq!(match_result.variant, None); + let match_result = matcher.get_match(&flag, None, None).await; + + // Expecting an error for non-existent distinct_id + assert!(match_result.is_err()); } #[tokio::test] @@ -3106,6 +3249,19 @@ mod tests { .await .unwrap(); + insert_person_for_team_in_pg(postgres_reader.clone(), team.id, "lil_id".to_string(), None) + .await + .unwrap(); + + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "another_id".to_string(), + None, + ) + .await + .unwrap(); + let mut matcher_test_id = FeatureFlagMatcher::new( "test_id".to_string(), team.id, @@ -3265,6 +3421,19 @@ mod tests { .await .unwrap(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + "another_id".to_string(), + None, + ) + .await + .unwrap(); + + insert_person_for_team_in_pg(postgres_reader.clone(), team.id, "lil_id".to_string(), None) + .await + .unwrap(); + let flag = create_test_flag( Some(1), Some(team.id), @@ -3852,6 +4021,344 @@ mod tests { assert!(!result.matches); } + #[tokio::test] + async fn test_static_cohort_matching_user_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Static Cohort".to_string()), + json!({}), // Static cohorts don't have property filters + true, // is_static = true + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "static@user.com"})), + ) + .await + .unwrap(); + + // Retrieve the person's ID + let person_id = + get_person_id_by_distinct_id(postgres_reader.clone(), team.id, &distinct_id) + .await + .unwrap(); + + // Associate the person with the static cohort + add_person_to_cohort(postgres_reader.clone(), person_id, cohort.id) + .await + .unwrap(); + + // Define a flag with an 'In' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + result.matches, + "User should match the static cohort and flag" + ); + } + + #[tokio::test] + async fn test_static_cohort_matching_user_not_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Another Static Cohort".to_string()), + json!({}), // Static cohorts don't have property filters + true, + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "non_static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "nonstatic@user.com"})), + ) + .await + .unwrap(); + + // Note: Do NOT associate the person with the static cohort + + // Define a flag with an 'In' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::In), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + !result.matches, + "User should not match the static cohort and flag" + ); + } + + #[tokio::test] + async fn test_static_cohort_not_in_matching_user_not_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Static Cohort NotIn".to_string()), + json!({}), // Static cohorts don't have property filters + true, // is_static = true + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "not_in_static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "notinstatic@user.com"})), + ) + .await + .unwrap(); + + // No association with the static cohort + + // Define a flag with a 'NotIn' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + result.matches, + "User not in the static cohort should match the 'NotIn' flag" + ); + } + + #[tokio::test] + async fn test_static_cohort_not_in_matching_user_in_cohort() { + let postgres_reader = setup_pg_reader_client(None).await; + let postgres_writer = setup_pg_writer_client(None).await; + let cohort_cache = Arc::new(CohortCacheManager::new(postgres_reader.clone(), None, None)); + let team = insert_new_team_in_pg(postgres_reader.clone(), None) + .await + .unwrap(); + + // Insert a static cohort + let cohort = insert_cohort_for_team_in_pg( + postgres_reader.clone(), + team.id, + Some("Static Cohort NotIn User In".to_string()), + json!({}), // Static cohorts don't have property filters + true, // is_static = true + ) + .await + .unwrap(); + + // Insert a person + let distinct_id = "in_not_in_static_user".to_string(); + insert_person_for_team_in_pg( + postgres_reader.clone(), + team.id, + distinct_id.clone(), + Some(json!({"email": "innotinstatic@user.com"})), + ) + .await + .unwrap(); + + // Retrieve the person's ID + let person_id = + get_person_id_by_distinct_id(postgres_reader.clone(), team.id, &distinct_id) + .await + .unwrap(); + + // Associate the person with the static cohort + add_person_to_cohort(postgres_reader.clone(), person_id, cohort.id) + .await + .unwrap(); + + // Define a flag with a 'NotIn' cohort filter + let flag = create_test_flag( + None, + Some(team.id), + None, + None, + Some(FlagFilters { + groups: vec![FlagGroupType { + properties: Some(vec![PropertyFilter { + key: "id".to_string(), + value: json!(cohort.id), + operator: Some(OperatorType::NotIn), + prop_type: "cohort".to_string(), + group_type_index: None, + negation: Some(false), + }]), + rollout_percentage: Some(100.0), + variant: None, + }], + multivariate: None, + aggregation_group_type_index: None, + payloads: None, + super_groups: None, + }), + None, + None, + None, + ); + + let mut matcher = FeatureFlagMatcher::new( + distinct_id.clone(), + team.id, + postgres_reader.clone(), + postgres_writer.clone(), + cohort_cache.clone(), + None, + None, + ); + + let result = matcher.get_match(&flag, None, None).await.unwrap(); + + assert!( + !result.matches, + "User in the static cohort should not match the 'NotIn' flag" + ); + } + #[tokio::test] async fn test_set_feature_flag_hash_key_overrides_success() { let postgres_reader = setup_pg_reader_client(None).await; @@ -4095,7 +4602,6 @@ mod tests { .unwrap(); let distinct_id = "user4".to_string(); - // Insert person insert_person_for_team_in_pg( postgres_reader.clone(), team.id, @@ -4168,7 +4674,6 @@ mod tests { .unwrap(); let distinct_id = "user5".to_string(); - // Insert person insert_person_for_team_in_pg( postgres_reader.clone(), team.id, diff --git a/rust/feature-flags/src/request_handler.rs b/rust/feature-flags/src/request_handler.rs index 538c6845d2a02..5ef43896e641f 100644 --- a/rust/feature-flags/src/request_handler.rs +++ b/rust/feature-flags/src/request_handler.rs @@ -97,6 +97,7 @@ pub async fn process_request(context: RequestContext) -> Result, -) -> Result<(), Error> { +) -> Result { + // Changed return type to Result let payload = match properties { Some(value) => value, None => json!({ @@ -329,7 +330,7 @@ pub async fn insert_person_for_team_in_pg( let uuid = Uuid::now_v7(); let mut conn = client.get_connection().await?; - let res = sqlx::query( + let row = sqlx::query( r#" WITH inserted_person AS ( INSERT INTO posthog_person ( @@ -337,10 +338,11 @@ pub async fn insert_person_for_team_in_pg( properties_last_operation, team_id, is_user_id, is_identified, uuid, version ) VALUES ('2023-04-05', $1, '{}', '{}', $2, NULL, true, $3, 0) - RETURNING * + RETURNING id ) INSERT INTO posthog_persondistinctid (distinct_id, person_id, team_id, version) VALUES ($4, (SELECT id FROM inserted_person), $5, 0) + RETURNING person_id "#, ) .bind(&payload) @@ -348,12 +350,11 @@ pub async fn insert_person_for_team_in_pg( .bind(uuid) .bind(&distinct_id) .bind(team_id) - .execute(&mut *conn) + .fetch_one(&mut *conn) .await?; - assert_eq!(res.rows_affected(), 1); - - Ok(()) + let person_id: i32 = row.get::("person_id"); + Ok(person_id) } pub async fn insert_cohort_for_team_in_pg( @@ -410,3 +411,48 @@ pub async fn insert_cohort_for_team_in_pg( Ok(Cohort { id, ..cohort }) } + +pub async fn get_person_id_by_distinct_id( + client: Arc, + team_id: i32, + distinct_id: &str, +) -> Result { + let mut conn = client.get_connection().await?; + let row: (i32,) = sqlx::query_as( + r#"SELECT id FROM posthog_person + WHERE team_id = $1 AND id = ( + SELECT person_id FROM posthog_persondistinctid + WHERE team_id = $1 AND distinct_id = $2 + LIMIT 1 + ) + LIMIT 1"#, + ) + .bind(team_id) + .bind(distinct_id) + .fetch_one(&mut *conn) + .await + .map_err(|_| anyhow::anyhow!("Person not found"))?; + + Ok(row.0) +} + +pub async fn add_person_to_cohort( + client: Arc, + person_id: i32, + cohort_id: i32, +) -> Result<(), Error> { + let mut conn = client.get_connection().await?; + let res = sqlx::query( + r#"INSERT INTO posthog_cohortpeople (cohort_id, person_id) + VALUES ($1, $2) + ON CONFLICT DO NOTHING"#, + ) + .bind(cohort_id) + .bind(person_id) + .execute(&mut *conn) + .await?; + + assert!(res.rows_affected() > 0, "Failed to add person to cohort"); + + Ok(()) +}