Browse Source

Changed requirement for windows, added percentage splitting methods

master
Bianca Steffes 5 days ago
parent
commit
dbbdcd0078
  1. 1
      .gitignore
  2. 32
      main.py
  3. 2
      requirements.txt

1
.gitignore

@ -138,3 +138,4 @@ dmypy.json
# Cython debug symbols
cython_debug/
.idea

32
main.py

@ -1,3 +1,6 @@
import numpy as np
import pandas as pd
from pipeline import (
load_dataset,
filter_data,
@ -9,6 +12,10 @@ from pipeline import (
display_warnings_for_scenarios
)
year_str = 'Year'
month_str = 'Month'
user_str = 'user'
# === Configurable Parameters ===
DATA_PATH = './Datasets/ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx'
OUTPUT_EXCEL_PATH = './working/evaluation_results.xlsx'
@ -27,6 +34,28 @@ predefined_validation_scenarios = {
"Scenario A": {"years_months": [(2019, [10, 11, 12])]}
}
def remove_covid_data(df):
df = df[~((df[year_str]==2020) & (df[month_str]>2))]
return df
def split_data_by_month_percentage(df, percentages):
train_p, valid_p, test_p = percentages
ids = df[[year_str, month_str]].drop_duplicates().sort_values([year_str, month_str])
tr, va, te = np.split(ids, [int((train_p/100) * len(ids)), int(((train_p + valid_p)/100) * len(ids))])
return df.merge(tr, on=[year_str, month_str], how='inner'), df.merge(va, on=[year_str, month_str], how='inner'), df.merge(te, on=[year_str, month_str], how='inner')
def split_data_by_userdata_percentage(df, percentages):
train_p, valid_p, test_p = percentages
tr, va, te = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
for user_id in df[user_str].unique():
user_data = df[df[user_str]==user_id].sort_values([year_str, month_str])
u_tr, u_va, u_te = np.split(user_data, [int((train_p/100)*len(user_data)), int(((train_p+valid_p)/100)*len(user_data))])
tr = pd.concat([tr, u_tr], ignore_index=True)
va = pd.concat([va, u_va], ignore_index=True)
te = pd.concat([te, u_te], ignore_index=True)
return tr, va, te
def main():
# print("=== Training Scenario Setup ===")
# display_warning_about_2020_data()
@ -38,6 +67,9 @@ def main():
# === Load and preprocess ===
df = load_dataset(DATA_PATH)
removed = remove_covid_data(df)
tr,val,te = split_data_by_userdata_percentage(df, (80,10,10))
tr_2, val_2, te_2 = split_data_by_month_percentage(df, (80, 10, 10))
ALLUSERS32_15MIN_WITHOUTTHREHOLD = False
if('ALLUSERS32_15MIN_WITHOUTTHREHOLD.xlsx' in DATA_PATH):

2
requirements.txt

@ -38,7 +38,7 @@ six==1.17.0
tensorboard==2.19.0
tensorboard-data-server==0.7.2
tensorflow==2.19.0
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-io-gcs-filesystem==0.31.0
termcolor==3.1.0
threadpoolctl==3.6.0
typing_extensions==4.14.1

Loading…
Cancel
Save