Skip to content

Commit

Permalink
WIP: Added support for exogenous variables and improved plot function.
Browse files: browse the repository at this point in the history
  • Loading branch information
MMenchero committed Nov 10, 2023
1 parent 8895908 commit 266ab9a
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 41 deletions.
2 changes: 1 addition & 1 deletion R/date_conversion.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ date_conversion <- function(df){
}else if(cls == "Date"){
freq <- "D"

}else if(cls %in% c("POSIXct", "POSIXt")){
}else if(cls %in% c("POSIXt", "POSIXct", "POSIXlt")){
freq <- "H"

}else{
Expand Down
16 changes: 12 additions & 4 deletions R/timegpt_anomaly_detection.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
#' @param id_col Column that identifies each series.
#' @param time_col Column that identifies each timestep.
#' @param target_col Column that contains the target variable.
#' @param X_df A tsibble or a data frame with future exogenous variables.
#' @param level The confidence level (0-100) for the prediction interval used in anomaly detection. Default is 99.
#' @param clean_ex_first Clean exogenous signal before making the forecasts using TimeGPT.
#'
#' @return A tsibble or a data frame with the anomalies detected in the historical period.
#' @export
#'
timegpt_anomaly_detection <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", X_df=NULL, level=c(99), clean_ex_first=TRUE){
timegpt_anomaly_detection <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", level=c(99), clean_ex_first=TRUE){

# Validation ----
token <- get("NIXTLAR_TOKEN", envir = nixtlaR_env)
Expand All @@ -38,8 +37,17 @@ timegpt_anomaly_detection <- function(df, freq=NULL, id_col=NULL, time_col="ds",
clean_ex_first = clean_ex_first
)

# Add exogenous regressors here
# ----------------------------*
if(any(!(names(df) %in% c("unique_id", "ds", "y")))){
exogenous <- df |>
dplyr::select(-y)

x <- list(
columns = names(exogenous),
data = lapply(1:nrow(exogenous), function(i) as.list(exogenous[i,]))
)

timegpt_data[['x']] <- x
}

if(length(level) > 1){
message("Multiple levels are not allowed for anomaly detection. Will use the largest.")
Expand Down
20 changes: 18 additions & 2 deletions R/timegpt_cross_validation.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,24 @@ timegpt_cross_validation <- function(df, h=8, freq=NULL, id_col=NULL, time_col="
clean_ex_first = clean_ex_first
)

# Add exogenous regressors here
# ----------------------------*
if(!is.null(X_df)){
names(X_df)[which(names(X_df) == time_col)] <- "ds"
if(!is.null(id_col)){
names(X_df)[which(names(X_df) == id_col)] <- "unique_id"
}

exogenous <- df |>
dplyr::select(-y)

exogenous <- rbind(exogenous, X_df)

x <- list(
columns = names(exogenous),
data = lapply(1:nrow(exogenous), function(i) as.list(exogenous[i,]))
)

timegpt_data[['x']] <- x
}

if(!is.null(level)){
level <- as.list(level)
Expand Down
31 changes: 19 additions & 12 deletions R/timegpt_forecast.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,24 @@ timegpt_forecast <- function(df, h=8, freq=NULL, id_col=NULL, time_col="ds", tar
clean_ex_first = clean_ex_first
)

# if(!is.null(X_df)){
# names(X_df)[which(names(X_df) == time_col)] <- "ds"
# if(!is.null(id_col)){
# names(X_df)[which(names(X_df) == id_col)] <- "unique_id"
# }
# x <- list(
# columns = names(X_df),
# data = lapply(1:nrow(X_df), function(i) as.list(X_df[i,]))
# )
# timegpt_data[["x"]] <- x
# }
if(!is.null(X_df)){
names(X_df)[which(names(X_df) == time_col)] <- "ds"
if(!is.null(id_col)){
names(X_df)[which(names(X_df) == id_col)] <- "unique_id"
}

exogenous <- df |>
dplyr::select(-y)

exogenous <- rbind(exogenous, X_df)

x <- list(
columns = names(exogenous),
data = lapply(1:nrow(exogenous), function(i) as.list(exogenous[i,]))
)

timegpt_data[['x']] <- x
}

if(!is.null(level)){
level <- as.list(level)
Expand Down Expand Up @@ -126,7 +133,7 @@ timegpt_forecast <- function(df, h=8, freq=NULL, id_col=NULL, time_col="ds", tar

# Generate fitted values ----
if(add_history){
fitted <- timegpt_historic(df, freq=freq, id_col=id_col, time_col=time_col, target_col=target_col, X_df=X_df, level=level, finetune_steps=finetune_steps, clean_ex_first=clean_ex_first)
fitted <- timegpt_historic(df, freq=freq, id_col=id_col, time_col=time_col, target_col=target_col, level=level, finetune_steps=finetune_steps, clean_ex_first=clean_ex_first)
if(tsibble::is_tsibble(df)){
fcst <- dplyr::bind_rows(fitted, fcst)
}else{
Expand Down
16 changes: 12 additions & 4 deletions R/timegpt_historic.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
#' @param id_col Column that identifies each series.
#' @param time_col Column that identifies each timestep.
#' @param target_col Column that contains the target variable.
#' @param X_df A tsibble or a data frame with future exogenous variables.
#' @param level The confidence levels (0-100) for the prediction intervals.
#' @param finetune_steps Number of steps used to finetune TimeGPT in the new data.
#' @param clean_ex_first Clean exogenous signal before making the forecasts using TimeGPT.
#'
#' @return TimeGPT's forecast for the in-sample period.
#' @export
#'
timegpt_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", X_df=NULL, level=NULL, finetune_steps=0, clean_ex_first=TRUE){
timegpt_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_col="y", level=NULL, finetune_steps=0, clean_ex_first=TRUE){

# Validation ----
token <- get("NIXTLAR_TOKEN", envir = nixtlaR_env)
Expand All @@ -40,8 +39,17 @@ timegpt_historic <- function(df, freq=NULL, id_col=NULL, time_col="ds", target_c
clean_ex_first = clean_ex_first
)

# Add exogenous regressors here
# ----------------------------*
if(any(!(names(df) %in% c("unique_id", "ds", "y")))){
exogenous <- df |>
dplyr::select(-y)

x <- list(
columns = names(exogenous),
data = lapply(1:nrow(exogenous), function(i) as.list(exogenous[i,]))
)

timegpt_data[['x']] <- x
}

if(!is.null(level)){
level <- as.list(level)
Expand Down
58 changes: 46 additions & 12 deletions R/timegpt_plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
#'
timegpt_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds", target_col="y", unique_ids = NULL, max_insample_length=NULL, plot_anomalies=FALSE){

if(!tsibble::is_tsibble(df) & !is.data.frame(df)){
stop("Only tsibbles or data frames are allowed.")
}

# Select facets ----
nrow <- 4
ncol <- 2
Expand All @@ -25,14 +29,16 @@ timegpt_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds", targ
if(!is.null(id_col)){
names(df)[which(names(df) == id_col)] <- "unique_id"

ids <- unique(df$unique_id)
if(length(ids) == 2){ # reshape for better viz
nrow <- 2
ncol <- 1
}

## Select time series if there are more than 8 ----
if(length(unique(df$unique_id)) > 8){
if(length(ids) > 8){
if(!is.null(unique_ids)){
ids <- unique_ids[1:min(length(unique_ids), 8)]
if(length(ids) == 2){ # reshape for better viz
nrow = 2
ncol = 1
}
}else{
ids <- sample(unique(df$unique_id), size=8, replace=FALSE)
}
Expand All @@ -47,13 +53,27 @@ timegpt_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds", targ
}
}

# Check for cross validation output
cross_validation <- FALSE
if("cutoff" %in% names(fcst)){
cross_validation <- TRUE
if(plot_anomalies){
message("Can't plot anomalies and cross validation output at the same time. Setting plot_anomalies=FALSE")
plot_anomalies <- FALSE
# Convert dates if necessary ----
# ggplot2 requires ds to be a Date or POSIXt (date-time) object, while TimeGPT's API requires it to be a character string
cls <- class(df$ds)[1]
if(!(cls %in% c("Date", "POSIXt", "POSIXct", "POSIXlt"))){

if(tsibble::is_tsibble(df)){
df_list <- nixtlaR::date_conversion(df)
df <- df_list$df
freq <- df_list$freq
}else{
freq <- nixtlaR::infer_frequency(df)
}

if(is.null(freq)){
stop("Can't figure out the frequency of the data. Please convert time_col to Date or POSIXt.")
}

if(freq == "H"){
df$ds <- lubridate::ymd_hms(df$ds)
}else{
df$ds <- lubridate::ymd(df$ds)
}
}

Expand All @@ -74,6 +94,10 @@ timegpt_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds", targ

}else{
# Plot historical values and forecast ----
if(!tsibble::is_tsibble(fcst) & !is.data.frame(fcst)){
stop("fcst needs to be the output of timegpt_forecast, timegpt_historic, timegpt_anomaly_detection or timegpt_cross_validation.")
}

color_vals <- c("#B5838D", "steelblue")

# Rename forecast columns ----
Expand All @@ -83,6 +107,16 @@ timegpt_plot <- function(df, fcst=NULL, h=NULL, id_col=NULL, time_col="ds", targ
names(fcst)[which(names(fcst) == id_col)] <- "unique_id"
}

# Check for cross validation output
cross_validation <- FALSE
if("cutoff" %in% names(fcst)){
cross_validation <- TRUE
if(plot_anomalies){
message("Can't plot anomalies and cross validation output at the same time. Setting plot_anomalies=FALSE")
plot_anomalies <- FALSE
}
}

if(!is.null(max_insample_length)){
df <- df |>
dplyr::group_by(.data$unique_id) |>
Expand Down
3 changes: 0 additions & 3 deletions man/timegpt_anomaly_detection.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions man/timegpt_historic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 266ab9a

Please sign in to comment.