2024-07-18

Author

Witek ten Hove

Discussed with Joost:

Setup and load data

from schedule_class import NewSchedule, generate_schedules, service_time_with_no_shows
import pickle
import random
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, GridSearchCV
from sklearn.metrics import mean_squared_error, make_scorer, accuracy_score
from sklearn.base import clone
from itertools import chain, combinations
class ScheduleData:
  def __init__(self, N: int, T: int, samples, labels):
    self.N = N
    self.T = T
    self.samples = samples
    self.labels = labels
  
  def describe_data(self):
    print(f'N = {self.N}', f'T = {self.T}', '\nSamples',self.samples.tail(10), '\nLabels', self.labels.tail(10), sep = "\n")
  
  def create_pair_list(self, n): # Create a set of randomly selected pairs of schedules
    S = list(range(len(self.samples)))
    Q = random.choices(S, k=n)
    P = []
    
    for q in Q:
        # Create a list of possible choices excluding q itself
        possible_choices = [s for s in S if s != q]
        
        # Choose a random element from the possible choices
        p = random.choice(possible_choices)
        
        # Add the chosen element to the result list
        P.append(p)
    
    samples_s1 = self.samples.iloc[Q, :].values.tolist()
    samples_s2 = self.samples.iloc[P, :].values.tolist()
    self.pair_list = list(zip(samples_s1, samples_s2))
    self.labels_s1 = self.labels.loc[Q, 'obj'].values.tolist()
    self.labels_s2 = self.labels.loc[P, 'obj'].values.tolist()
    # Rank label: 1 if the first schedule of the pair has the higher objective
    self.labels_rank = [1 * (self.labels_s1[i] > self.labels_s2[i]) for i in range(len(self.labels_s1))]
    print(self.pair_list[:15], "\n", self.labels_s1[:15], "\n", self.labels_s2[:15], "\n", self.labels_rank[:15])
    
  def create_neighbors_list(self, n): # Create a set of pairs of schedules that are from the same neighborhood
    # Build a subset of random schedules with length n
    S = list(range(len(self.samples)))
    Q = random.choices(S, k=n)
    samples_sub = self.samples.iloc[Q, :].values.tolist()
    labels_sub = self.labels.iloc[Q, 7] # Column 7 is the 'obj' column
    self.neighbors_list = []
    
    # For each schedule in the subset, choose two random intervals i and j and move one patient from j to i
    for s in samples_sub:
      i = random.choice(range(len(s)))  # Ensure i is a valid index in s
      j = [index for index, element in enumerate(s) if element > 0 and index != i]
      
      if not j:  # Ensure j is not empty
          continue
      
      j = random.choice(j)  # Choose a random valid index from j
      
      s_pair = s.copy()  # Create a copy of s to modify
      s_pair[i] = s[i] + 1
      s_pair[j] = s[j] - 1
      
      self.neighbors_list.append((s, s_pair))
    print(samples_sub, "\n", self.neighbors_list, "\n", labels_sub)
  
  def calculate_objective(self, schedule, s, d, q):
    s = service_time_with_no_shows(s, q) # Adjust the service-time distribution for the no-show probability q
    sp = np.array([1], dtype=np.int64) # Set probability of first spillover time being zero to 1
    wt_list = [] # Initialize wt_list for saving all waiting times for all patients in the schedule
    ewt = 0 # Initialize sum of expected waiting times
    for x in schedule: # For each interval -
      if x == 0: # In case there are no patients,
        wt_temp = [np.array(sp)] # the spillover from the previous interval is recorded,
        wt_list.append([]) # but there are no waiting times.
        sp = [] # Re-initialize the spillover time distribution
        sp.append(np.sum(wt_temp[-1][:d+1])) # Spillover work that finishes within the interval length d collapses into zero spillover
        sp[1:] = wt_temp[-1][d+1:] # Work exceeding d carries over to the next interval
      else:
        wt_temp = [np.array(sp)] # Initialize wt_temp for saving all waiting times for all patients in the interval. The first patient has to wait for the spillover work from the previous period.
        ewt += np.dot(range(len(sp)), sp) # Add the expected waiting time of the first patient in the interval
        for i in range(x-1): # For each subsequent patient in the interval
          wt = np.convolve(wt_temp[i], s) # Calculate the waiting time distribution
          wt_temp.append(wt)
          ewt += np.dot(range(len(wt)), wt)
        wt_list.append(wt_temp)
        sp = []
        sp.append(np.sum(np.convolve(wt_temp[-1],s)[:d+1])) # Calculate the spillover
        sp[1:] = np.convolve(wt_temp[-1],s)[d+1:]
    print(f"Schedule: {schedule}, ewt = {ewt}")


    
with open('./experiments/data.pickle', 'rb') as file:
  sch_data: ScheduleData = pickle.load(file)
  
sch_data.describe_data()
sch_data.create_pair_list(16000)
sch_data.create_neighbors_list(10)
test_sch = sch_data.neighbors_list[1][0]
sch_data.calculate_objective(test_sch, [0.0, 0.27, 0.28, 0.2, 0.15, 0.1], 3, 0.2)
N = 10
T = 7

Samples
      x_0  x_1  x_2  x_3  x_4  x_5  x_6
7998    8    1    0    1    0    0    0
7999    8    1    1    0    0    0    0
8000    8    2    0    0    0    0    0
8001    9    0    0    0    0    0    1
8002    9    0    0    0    0    1    0
8003    9    0    0    0    1    0    0
8004    9    0    0    1    0    0    0
8005    9    0    1    0    0    0    0
8006    9    1    0    0    0    0    0
8007   10    0    0    0    0    0    0

Labels
        ew_0       ew_1       ew_2      ew_3      ew_4      ew_5    ew_6  \
7998  56.672  13.192158   0.000000  9.241482  0.000000  0.000000  0.0000   
7999  56.672  13.192158  12.217911  0.000000  0.000000  0.000000  0.0000   
8000  56.672  28.408317   0.000000  0.000000  0.000000  0.000000  0.0000   
8001  72.864   0.000000   0.000000  0.000000  0.000000  0.000000  1.9547   
8002  72.864   0.000000   0.000000  0.000000  0.000000  3.858871  0.0000   
8003  72.864   0.000000   0.000000  0.000000  6.380606  0.000000  0.0000   
8004  72.864   0.000000   0.000000  9.241482  0.000000  0.000000  0.0000   
8005  72.864   0.000000  12.217883  0.000000  0.000000  0.000000  0.0000   
8006  72.864  15.216038   0.000000  0.000000  0.000000  0.000000  0.0000   
8007  91.080   0.000000   0.000000  0.000000  0.000000  0.000000  0.0000   

            obj  obj_rank  
7998  79.105640    7964.5  
7999  82.082070    7978.0  
8000  85.080317    7987.5  
8001  74.818700    7911.0  
8002  76.722871    7945.5  
8003  79.244606    7968.0  
8004  82.105482    7982.5  
8005  85.081883    7993.0  
8006  88.080038    7998.5  
8007  91.080000    8005.0  
[([5, 1, 0, 0, 4, 0, 0], [0, 0, 0, 1, 3, 1, 5]), ([1, 1, 1, 5, 0, 1, 1], [0, 0, 0, 4, 2, 0, 4]), ([1, 5, 2, 1, 1, 0, 0], [2, 0, 1, 3, 4, 0, 0]), ([5, 0, 0, 0, 2, 2, 1], [1, 2, 5, 0, 1, 1, 0]), ([1, 3, 2, 2, 0, 0, 2], [1, 0, 3, 0, 3, 1, 2]), ([4, 2, 0, 2, 0, 0, 2], [0, 0, 2, 6, 1, 0, 1]), ([2, 2, 1, 0, 0, 5, 0], [1, 0, 0, 1, 0, 4, 4]), ([0, 2, 5, 0, 0, 2, 1], [0, 7, 0, 2, 0, 1, 0]), ([0, 4, 1, 3, 1, 1, 0], [0, 0, 3, 5, 2, 0, 0]), ([0, 4, 0, 1, 0, 2, 3], [0, 0, 0, 0, 7, 2, 1]), ([1, 0, 3, 0, 0, 6, 0], [1, 4, 0, 0, 1, 1, 3]), ([0, 5, 3, 0, 0, 0, 2], [1, 3, 1, 0, 0, 4, 1]), ([1, 4, 1, 1, 0, 3, 0], [0, 6, 1, 0, 0, 2, 1]), ([1, 0, 0, 4, 0, 5, 0], [0, 3, 4, 0, 0, 1, 2]), ([1, 1, 4, 0, 4, 0, 0], [0, 0, 1, 1, 0, 5, 3])] 
 [45.83469008087551, 31.925612362191632, 54.48618002021776, 32.93638214007037, 31.903947786090264, 37.57795844587057, 30.66425452134402, 42.27419483282949, 50.09276646866054, 30.186601175812214, 37.626163200000036, 54.50807016895987, 34.52406351149628, 45.021314621440034, 37.594862272023136] 
 [44.48115841679363, 50.77680848523785, 35.95409687035906, 42.260359046783755, 28.748981569010112, 58.61060969191143, 44.81995776000003, 67.28540959597044, 64.99914126947522, 79.08352442141943, 23.652114129588547, 28.693674219611005, 51.50006045332245, 40.97090533376224, 48.03287604838404] 
 [1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0]
[[2, 2, 1, 3, 1, 0, 1], [0, 4, 2, 0, 1, 3, 0], [1, 0, 0, 0, 2, 7, 0], [3, 0, 0, 0, 2, 5, 0], [0, 6, 0, 0, 1, 0, 3], [0, 0, 0, 5, 4, 0, 1], [0, 0, 0, 0, 1, 1, 8], [3, 1, 2, 3, 0, 0, 1], [3, 0, 1, 1, 1, 4, 0], [1, 3, 1, 4, 1, 0, 0]] 
 [([2, 2, 1, 3, 1, 0, 1], [2, 2, 1, 3, 0, 0, 2]), ([0, 4, 2, 0, 1, 3, 0], [0, 5, 1, 0, 1, 3, 0]), ([1, 0, 0, 0, 2, 7, 0], [1, 0, 0, 0, 3, 6, 0]), ([3, 0, 0, 0, 2, 5, 0], [2, 0, 0, 1, 2, 5, 0]), ([0, 6, 0, 0, 1, 0, 3], [1, 6, 0, 0, 0, 0, 3]), ([0, 0, 0, 5, 4, 0, 1], [0, 0, 0, 6, 4, 0, 0]), ([0, 0, 0, 0, 1, 1, 8], [0, 0, 0, 0, 2, 0, 8]), ([3, 1, 2, 3, 0, 0, 1], [3, 1, 2, 2, 0, 0, 2]), ([3, 0, 1, 1, 1, 4, 0], [3, 0, 1, 1, 1, 3, 1]), ([1, 3, 1, 4, 1, 0, 0], [2, 3, 1, 4, 0, 0, 0])] 
 5963    29.935620
2689    42.942470
3029    54.867392
6312    35.795831
2882    43.738361
248     70.151718
12      60.141760
6793    35.779291
6450    24.214863
4681    41.492811
Name: obj, dtype: float64
Schedule: [0, 4, 2, 0, 1, 3, 0], ewt = 42.94246990549784

Flow chart of objective calculation method

flowchart TD
    A[Start] --> B[Adjust service times distribution for no-shows]
    B --> C[Set probability of first spillover time being zero to 1]
    C --> D[Initialize wt_list and ewt <br/> for saving waiting-time distributions and <br/> accumulated expected waiting time]
    D --> E[Loop through each interval in schedule]
    
    E --> F{Is x == 0?}
    
    F -->|Yes| G[Record spillover from previous interval]
    G --> H[Append empty list to wt_list]
    H --> I[Initialize spillover time distribution]
    I --> J[Calculate spillover for next interval]
    J --> K[Continue to next interval]

    F -->|No| L[Initialize wt_temp for current interval]
    L --> M[Add spillover from previous interval <br/>as waiting time for first patient in current interval]
    M --> N[Loop through each patient in interval]
    
    N --> O[Calculate waiting time distribution for current patient]
    O --> P[Append waiting time distribution to wt_temp]
    P --> Q[Add expected waiting time for current patient to ewt]
    Q --> R[Append wt_temp to wt_list]
    
    R --> S[Initialize spillover time distribution]
    S --> T[Calculate spillover for next interval]
    T --> K
    
    K --> U[After the loop: print schedule and ewt]
    U --> V[End]
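
To make the recursion concrete: each subsequent patient's waiting time is the previous patient's waiting time plus one service time, so its pmf is the convolution of the two pmfs. A minimal standalone sketch of that single step (using the same illustrative service-time pmf passed to calculate_objective above):

import numpy as np

s = np.array([0.0, 0.27, 0.28, 0.2, 0.15, 0.1]) # Service-time pmf over durations 0..5
sp = np.array([1.0]) # Spillover pmf: zero spillover with probability 1

wt_first = sp # The first patient waits exactly the spillover time
ewt = np.dot(range(len(wt_first)), wt_first) # Expected value, 0 here

# The pmf of a sum of independent variables is the convolution of their pmfs
wt_second = np.convolve(wt_first, s)
ewt += np.dot(range(len(wt_second)), wt_second)
print(ewt) # 2.53: expected waiting time accumulated over the two patients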

Prepare data

# Prepare the dataset
X = []
for pair in sch_data.pair_list:
    X.append(pair[0] + pair[1])

X = np.array(X)
y = np.array(sch_data.labels_rank)

# Split the dataset into training and test sets (note: the CV loop below re-splits X and y itself)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
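
Each training row is simply the two T-vectors of a pair concatenated into one 2T-vector, labelled 1 when the first schedule has the higher objective (see labels_rank above). For instance, taking the first pair from the printed pair_list:

example_pair = ([5, 1, 0, 0, 4, 0, 0], [0, 0, 0, 1, 3, 1, 5]) # First pair from the printout above
example_row = example_pair[0] + example_pair[1] # Flatten into one feature row of length 2T
print(example_row) # [5, 1, 0, 0, 4, 0, 0, 0, 0, 0, 1, 3, 1, 5]
# Its label is 1, because 45.83 > 44.48: the first schedule has the higher objective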

Train and evaluate model

flowchart TD
    A[Start] --> B[Initialize StratifiedKFold]
    B --> C[Initialize XGBClassifier]
    C --> D[Set results as empty list]
    D --> E[Loop through each split of cv split]
    E --> F[Get train and test indices]
    F --> G[Split X and y into X_train, X_test, y_train, y_test]
    G --> H[Clone the classifier]
    H --> I[Call fit_and_score function]
    I --> J[Fit the estimator]
    J --> K[Score on training set]
    J --> L[Score on test set]
    K --> M[Return estimator, train_score, test_score]
    L --> M
    M --> N[Append the results]
    N --> E
    E --> O[Loop ends]
    O --> P[Print results]
    P --> Q[End]

def fit_and_score(estimator, X_train, X_test, y_train, y_test):
    """Fit the estimator on the train set and score it on both sets"""
    estimator.fit(X_train, y_train, eval_set=[(X_test, y_test)])

    train_score = estimator.score(X_train, y_train)
    test_score = estimator.score(X_test, y_test)

    return estimator, train_score, test_score


cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=94)

# Initialize the XGBClassifier with early stopping (early_stopping_rounds=10)
clf = xgb.XGBClassifier(
    tree_method="hist",
    max_depth=6,
    min_child_weight=1,
    gamma=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    learning_rate=0.1,
    n_estimators=100,
    early_stopping_rounds=10
)

results = []

for train_idx, test_idx in cv.split(X, y):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    
    est, train_score, test_score = fit_and_score(
        clone(clf), X_train, X_test, y_train, y_test
    )
    results.append((est, train_score, test_score))

# Print results
for i, (est, train_score, test_score) in enumerate(results):
    print(f"Fold {i+1} - Train Score: {train_score:.4f}, Test Score: {test_score:.4f}")
[0] validation_0-logloss:0.66704
[1] validation_0-logloss:0.64197
[2] validation_0-logloss:0.62004
[3] validation_0-logloss:0.60129
[4] validation_0-logloss:0.58462
[5] validation_0-logloss:0.56863
[6] validation_0-logloss:0.55463
[7] validation_0-logloss:0.54278
[8] validation_0-logloss:0.53183
[9] validation_0-logloss:0.52202
[10]    validation_0-logloss:0.51265
[11]    validation_0-logloss:0.50330
[12]    validation_0-logloss:0.49556
[13]    validation_0-logloss:0.48886
[14]    validation_0-logloss:0.48178
[15]    validation_0-logloss:0.47352
[16]    validation_0-logloss:0.46686
[17]    validation_0-logloss:0.45823
[18]    validation_0-logloss:0.45144
[19]    validation_0-logloss:0.44638
[20]    validation_0-logloss:0.44170
[21]    validation_0-logloss:0.43606
[22]    validation_0-logloss:0.43121
[23]    validation_0-logloss:0.42626
[24]    validation_0-logloss:0.42026
[25]    validation_0-logloss:0.41398
[26]    validation_0-logloss:0.40977
[27]    validation_0-logloss:0.40508
[28]    validation_0-logloss:0.40069
[29]    validation_0-logloss:0.39565
[30]    validation_0-logloss:0.39144
[31]    validation_0-logloss:0.38700
[32]    validation_0-logloss:0.38395
[33]    validation_0-logloss:0.38027
[34]    validation_0-logloss:0.37639
[35]    validation_0-logloss:0.37305
[36]    validation_0-logloss:0.36950
[37]    validation_0-logloss:0.36542
[38]    validation_0-logloss:0.36197
[39]    validation_0-logloss:0.35907
[40]    validation_0-logloss:0.35659
[41]    validation_0-logloss:0.35246
[42]    validation_0-logloss:0.35002
[43]    validation_0-logloss:0.34677
[44]    validation_0-logloss:0.34350
[45]    validation_0-logloss:0.34061
[46]    validation_0-logloss:0.33761
[47]    validation_0-logloss:0.33511
[48]    validation_0-logloss:0.33288
[49]    validation_0-logloss:0.33082
[50]    validation_0-logloss:0.32858
[51]    validation_0-logloss:0.32658
[52]    validation_0-logloss:0.32332
[53]    validation_0-logloss:0.32046
[54]    validation_0-logloss:0.31715
[55]    validation_0-logloss:0.31561
[56]    validation_0-logloss:0.31377
[57]    validation_0-logloss:0.31075
[58]    validation_0-logloss:0.30913
[59]    validation_0-logloss:0.30729
[60]    validation_0-logloss:0.30555
[61]    validation_0-logloss:0.30361
[62]    validation_0-logloss:0.30166
[63]    validation_0-logloss:0.29909
[64]    validation_0-logloss:0.29692
[65]    validation_0-logloss:0.29539
[66]    validation_0-logloss:0.29365
[67]    validation_0-logloss:0.29163
[68]    validation_0-logloss:0.28929
[69]    validation_0-logloss:0.28798
[70]    validation_0-logloss:0.28646
[71]    validation_0-logloss:0.28538
[72]    validation_0-logloss:0.28410
[73]    validation_0-logloss:0.28237
[74]    validation_0-logloss:0.28089
[75]    validation_0-logloss:0.27960
[76]    validation_0-logloss:0.27821
[77]    validation_0-logloss:0.27676
[78]    validation_0-logloss:0.27519
[79]    validation_0-logloss:0.27396
[80]    validation_0-logloss:0.27272
[81]    validation_0-logloss:0.27075
[82]    validation_0-logloss:0.26935
[83]    validation_0-logloss:0.26826
[84]    validation_0-logloss:0.26589
[85]    validation_0-logloss:0.26479
[86]    validation_0-logloss:0.26373
[87]    validation_0-logloss:0.26270
[88]    validation_0-logloss:0.26089
[89]    validation_0-logloss:0.25961
[90]    validation_0-logloss:0.25764
[91]    validation_0-logloss:0.25653
[92]    validation_0-logloss:0.25540
[93]    validation_0-logloss:0.25509
[94]    validation_0-logloss:0.25369
[95]    validation_0-logloss:0.25280
[96]    validation_0-logloss:0.25173
[97]    validation_0-logloss:0.25119
[98]    validation_0-logloss:0.24985
[99]    validation_0-logloss:0.24888
[0] validation_0-logloss:0.66754
[1] validation_0-logloss:0.64366
[2] validation_0-logloss:0.62159
[3] validation_0-logloss:0.60301
[4] validation_0-logloss:0.58811
[5] validation_0-logloss:0.57304
[6] validation_0-logloss:0.55969
[7] validation_0-logloss:0.54666
[8] validation_0-logloss:0.53685
[9] validation_0-logloss:0.52678
[10]    validation_0-logloss:0.51696
[11]    validation_0-logloss:0.50823
[12]    validation_0-logloss:0.49949
[13]    validation_0-logloss:0.49091
[14]    validation_0-logloss:0.48345
[15]    validation_0-logloss:0.47603
[16]    validation_0-logloss:0.46972
[17]    validation_0-logloss:0.46359
[18]    validation_0-logloss:0.45784
[19]    validation_0-logloss:0.45178
[20]    validation_0-logloss:0.44530
[21]    validation_0-logloss:0.43673
[22]    validation_0-logloss:0.43163
[23]    validation_0-logloss:0.42648
[24]    validation_0-logloss:0.42185
[25]    validation_0-logloss:0.41540
[26]    validation_0-logloss:0.41039
[27]    validation_0-logloss:0.40655
[28]    validation_0-logloss:0.40221
[29]    validation_0-logloss:0.39804
[30]    validation_0-logloss:0.39286
[31]    validation_0-logloss:0.38754
[32]    validation_0-logloss:0.38334
[33]    validation_0-logloss:0.37859
[34]    validation_0-logloss:0.37467
[35]    validation_0-logloss:0.37123
[36]    validation_0-logloss:0.36770
[37]    validation_0-logloss:0.36266
[38]    validation_0-logloss:0.35911
[39]    validation_0-logloss:0.35462
[40]    validation_0-logloss:0.35186
[41]    validation_0-logloss:0.34792
[42]    validation_0-logloss:0.34556
[43]    validation_0-logloss:0.34182
[44]    validation_0-logloss:0.33925
[45]    validation_0-logloss:0.33623
[46]    validation_0-logloss:0.33320
[47]    validation_0-logloss:0.32991
[48]    validation_0-logloss:0.32819
[49]    validation_0-logloss:0.32630
[50]    validation_0-logloss:0.32388
[51]    validation_0-logloss:0.32128
[52]    validation_0-logloss:0.31842
[53]    validation_0-logloss:0.31606
[54]    validation_0-logloss:0.31387
[55]    validation_0-logloss:0.31268
[56]    validation_0-logloss:0.30981
[57]    validation_0-logloss:0.30799
[58]    validation_0-logloss:0.30569
[59]    validation_0-logloss:0.30370
[60]    validation_0-logloss:0.30175
[61]    validation_0-logloss:0.29983
[62]    validation_0-logloss:0.29820
[63]    validation_0-logloss:0.29610
[64]    validation_0-logloss:0.29431
[65]    validation_0-logloss:0.29239
[66]    validation_0-logloss:0.29053
[67]    validation_0-logloss:0.28896
[68]    validation_0-logloss:0.28771
[69]    validation_0-logloss:0.28484
[70]    validation_0-logloss:0.28368
[71]    validation_0-logloss:0.28234
[72]    validation_0-logloss:0.28114
[73]    validation_0-logloss:0.28019
[74]    validation_0-logloss:0.27826
[75]    validation_0-logloss:0.27630
[76]    validation_0-logloss:0.27476
[77]    validation_0-logloss:0.27332
[78]    validation_0-logloss:0.27228
[79]    validation_0-logloss:0.27049
[80]    validation_0-logloss:0.26843
[81]    validation_0-logloss:0.26696
[82]    validation_0-logloss:0.26552
[83]    validation_0-logloss:0.26416
[84]    validation_0-logloss:0.26295
[85]    validation_0-logloss:0.26173
[86]    validation_0-logloss:0.26066
[87]    validation_0-logloss:0.25959
[88]    validation_0-logloss:0.25843
[89]    validation_0-logloss:0.25704
[90]    validation_0-logloss:0.25605
[91]    validation_0-logloss:0.25484
[92]    validation_0-logloss:0.25361
[93]    validation_0-logloss:0.25264
[94]    validation_0-logloss:0.25173
[95]    validation_0-logloss:0.25067
[96]    validation_0-logloss:0.24935
[97]    validation_0-logloss:0.24842
[98]    validation_0-logloss:0.24678
[99]    validation_0-logloss:0.24569
[0] validation_0-logloss:0.66716
[1] validation_0-logloss:0.64197
[2] validation_0-logloss:0.61964
[3] validation_0-logloss:0.60098
[4] validation_0-logloss:0.58559
[5] validation_0-logloss:0.57006
[6] validation_0-logloss:0.55697
[7] validation_0-logloss:0.54281
[8] validation_0-logloss:0.53207
[9] validation_0-logloss:0.52178
[10]    validation_0-logloss:0.51093
[11]    validation_0-logloss:0.50144
[12]    validation_0-logloss:0.49250
[13]    validation_0-logloss:0.48505
[14]    validation_0-logloss:0.47757
[15]    validation_0-logloss:0.46979
[16]    validation_0-logloss:0.46208
[17]    validation_0-logloss:0.45400
[18]    validation_0-logloss:0.44816
[19]    validation_0-logloss:0.44282
[20]    validation_0-logloss:0.43705
[21]    validation_0-logloss:0.43032
[22]    validation_0-logloss:0.42467
[23]    validation_0-logloss:0.41924
[24]    validation_0-logloss:0.41367
[25]    validation_0-logloss:0.40847
[26]    validation_0-logloss:0.40533
[27]    validation_0-logloss:0.40131
[28]    validation_0-logloss:0.39791
[29]    validation_0-logloss:0.39193
[30]    validation_0-logloss:0.38833
[31]    validation_0-logloss:0.38346
[32]    validation_0-logloss:0.37982
[33]    validation_0-logloss:0.37538
[34]    validation_0-logloss:0.37042
[35]    validation_0-logloss:0.36736
[36]    validation_0-logloss:0.36423
[37]    validation_0-logloss:0.35971
[38]    validation_0-logloss:0.35530
[39]    validation_0-logloss:0.35203
[40]    validation_0-logloss:0.34851
[41]    validation_0-logloss:0.34415
[42]    validation_0-logloss:0.34082
[43]    validation_0-logloss:0.33693
[44]    validation_0-logloss:0.33408
[45]    validation_0-logloss:0.33168
[46]    validation_0-logloss:0.32832
[47]    validation_0-logloss:0.32464
[48]    validation_0-logloss:0.32291
[49]    validation_0-logloss:0.32065
[50]    validation_0-logloss:0.31809
[51]    validation_0-logloss:0.31611
[52]    validation_0-logloss:0.31384
[53]    validation_0-logloss:0.31161
[54]    validation_0-logloss:0.30955
[55]    validation_0-logloss:0.30677
[56]    validation_0-logloss:0.30409
[57]    validation_0-logloss:0.30121
[58]    validation_0-logloss:0.29908
[59]    validation_0-logloss:0.29649
[60]    validation_0-logloss:0.29427
[61]    validation_0-logloss:0.29233
[62]    validation_0-logloss:0.29096
[63]    validation_0-logloss:0.28932
[64]    validation_0-logloss:0.28800
[65]    validation_0-logloss:0.28596
[66]    validation_0-logloss:0.28452
[67]    validation_0-logloss:0.28181
[68]    validation_0-logloss:0.28003
[69]    validation_0-logloss:0.27823
[70]    validation_0-logloss:0.27653
[71]    validation_0-logloss:0.27482
[72]    validation_0-logloss:0.27305
[73]    validation_0-logloss:0.27103
[74]    validation_0-logloss:0.26926
[75]    validation_0-logloss:0.26845
[76]    validation_0-logloss:0.26723
[77]    validation_0-logloss:0.26573
[78]    validation_0-logloss:0.26405
[79]    validation_0-logloss:0.26299
[80]    validation_0-logloss:0.26201
[81]    validation_0-logloss:0.26009
[82]    validation_0-logloss:0.25872
[83]    validation_0-logloss:0.25748
[84]    validation_0-logloss:0.25512
[85]    validation_0-logloss:0.25418
[86]    validation_0-logloss:0.25303
[87]    validation_0-logloss:0.25181
[88]    validation_0-logloss:0.25062
[89]    validation_0-logloss:0.24941
[90]    validation_0-logloss:0.24796
[91]    validation_0-logloss:0.24648
[92]    validation_0-logloss:0.24520
[93]    validation_0-logloss:0.24342
[94]    validation_0-logloss:0.24220
[95]    validation_0-logloss:0.24141
[96]    validation_0-logloss:0.24023
[97]    validation_0-logloss:0.23914
[98]    validation_0-logloss:0.23845
[99]    validation_0-logloss:0.23752
[0] validation_0-logloss:0.66855
[1] validation_0-logloss:0.64549
[2] validation_0-logloss:0.62624
[3] validation_0-logloss:0.60934
[4] validation_0-logloss:0.59387
[5] validation_0-logloss:0.57887
[6] validation_0-logloss:0.56657
[7] validation_0-logloss:0.55343
[8] validation_0-logloss:0.54458
[9] validation_0-logloss:0.53453
[10]    validation_0-logloss:0.52476
[11]    validation_0-logloss:0.51582
[12]    validation_0-logloss:0.50666
[13]    validation_0-logloss:0.49987
[14]    validation_0-logloss:0.49205
[15]    validation_0-logloss:0.48464
[16]    validation_0-logloss:0.47812
[17]    validation_0-logloss:0.46834
[18]    validation_0-logloss:0.46099
[19]    validation_0-logloss:0.45513
[20]    validation_0-logloss:0.45025
[21]    validation_0-logloss:0.44236
[22]    validation_0-logloss:0.43700
[23]    validation_0-logloss:0.43229
[24]    validation_0-logloss:0.42635
[25]    validation_0-logloss:0.42189
[26]    validation_0-logloss:0.41804
[27]    validation_0-logloss:0.41100
[28]    validation_0-logloss:0.40673
[29]    validation_0-logloss:0.40305
[30]    validation_0-logloss:0.39805
[31]    validation_0-logloss:0.39435
[32]    validation_0-logloss:0.39012
[33]    validation_0-logloss:0.38675
[34]    validation_0-logloss:0.38231
[35]    validation_0-logloss:0.37796
[36]    validation_0-logloss:0.37341
[37]    validation_0-logloss:0.36886
[38]    validation_0-logloss:0.36559
[39]    validation_0-logloss:0.36265
[40]    validation_0-logloss:0.35960
[41]    validation_0-logloss:0.35668
[42]    validation_0-logloss:0.35365
[43]    validation_0-logloss:0.34970
[44]    validation_0-logloss:0.34689
[45]    validation_0-logloss:0.34424
[46]    validation_0-logloss:0.34120
[47]    validation_0-logloss:0.33856
[48]    validation_0-logloss:0.33524
[49]    validation_0-logloss:0.33194
[50]    validation_0-logloss:0.32913
[51]    validation_0-logloss:0.32666
[52]    validation_0-logloss:0.32481
[53]    validation_0-logloss:0.32262
[54]    validation_0-logloss:0.31989
[55]    validation_0-logloss:0.31722
[56]    validation_0-logloss:0.31539
[57]    validation_0-logloss:0.31337
[58]    validation_0-logloss:0.31084
[59]    validation_0-logloss:0.30813
[60]    validation_0-logloss:0.30521
[61]    validation_0-logloss:0.30331
[62]    validation_0-logloss:0.30162
[63]    validation_0-logloss:0.29953
[64]    validation_0-logloss:0.29724
[65]    validation_0-logloss:0.29560
[66]    validation_0-logloss:0.29387
[67]    validation_0-logloss:0.29233
[68]    validation_0-logloss:0.28957
[69]    validation_0-logloss:0.28768
[70]    validation_0-logloss:0.28600
[71]    validation_0-logloss:0.28472
[72]    validation_0-logloss:0.28296
[73]    validation_0-logloss:0.28075
[74]    validation_0-logloss:0.27859
[75]    validation_0-logloss:0.27731
[76]    validation_0-logloss:0.27532
[77]    validation_0-logloss:0.27388
[78]    validation_0-logloss:0.27285
[79]    validation_0-logloss:0.27144
[80]    validation_0-logloss:0.27034
[81]    validation_0-logloss:0.26838
[82]    validation_0-logloss:0.26693
[83]    validation_0-logloss:0.26556
[84]    validation_0-logloss:0.26442
[85]    validation_0-logloss:0.26327
[86]    validation_0-logloss:0.26220
[87]    validation_0-logloss:0.26138
[88]    validation_0-logloss:0.26030
[89]    validation_0-logloss:0.25889
[90]    validation_0-logloss:0.25731
[91]    validation_0-logloss:0.25634
[92]    validation_0-logloss:0.25515
[93]    validation_0-logloss:0.25355
[94]    validation_0-logloss:0.25144
[95]    validation_0-logloss:0.25061
[96]    validation_0-logloss:0.24973
[97]    validation_0-logloss:0.24850
[98]    validation_0-logloss:0.24769
[99]    validation_0-logloss:0.24648
[0] validation_0-logloss:0.66795
[1] validation_0-logloss:0.64307
[2] validation_0-logloss:0.62180
[3] validation_0-logloss:0.60377
[4] validation_0-logloss:0.58766
[5] validation_0-logloss:0.57282
[6] validation_0-logloss:0.55964
[7] validation_0-logloss:0.54693
[8] validation_0-logloss:0.53543
[9] validation_0-logloss:0.52383
[10]    validation_0-logloss:0.51443
[11]    validation_0-logloss:0.50582
[12]    validation_0-logloss:0.49701
[13]    validation_0-logloss:0.48842
[14]    validation_0-logloss:0.48054
[15]    validation_0-logloss:0.47320
[16]    validation_0-logloss:0.46709
[17]    validation_0-logloss:0.46060
[18]    validation_0-logloss:0.45414
[19]    validation_0-logloss:0.44762
[20]    validation_0-logloss:0.44270
[21]    validation_0-logloss:0.43368
[22]    validation_0-logloss:0.42713
[23]    validation_0-logloss:0.42058
[24]    validation_0-logloss:0.41481
[25]    validation_0-logloss:0.41125
[26]    validation_0-logloss:0.40636
[27]    validation_0-logloss:0.40197
[28]    validation_0-logloss:0.39744
[29]    validation_0-logloss:0.39349
[30]    validation_0-logloss:0.38787
[31]    validation_0-logloss:0.38425
[32]    validation_0-logloss:0.37907
[33]    validation_0-logloss:0.37483
[34]    validation_0-logloss:0.37053
[35]    validation_0-logloss:0.36611
[36]    validation_0-logloss:0.36280
[37]    validation_0-logloss:0.35829
[38]    validation_0-logloss:0.35528
[39]    validation_0-logloss:0.35034
[40]    validation_0-logloss:0.34683
[41]    validation_0-logloss:0.34295
[42]    validation_0-logloss:0.33934
[43]    validation_0-logloss:0.33620
[44]    validation_0-logloss:0.33224
[45]    validation_0-logloss:0.32995
[46]    validation_0-logloss:0.32745
[47]    validation_0-logloss:0.32488
[48]    validation_0-logloss:0.32174
[49]    validation_0-logloss:0.31918
[50]    validation_0-logloss:0.31699
[51]    validation_0-logloss:0.31453
[52]    validation_0-logloss:0.31189
[53]    validation_0-logloss:0.30938
[54]    validation_0-logloss:0.30604
[55]    validation_0-logloss:0.30330
[56]    validation_0-logloss:0.30019
[57]    validation_0-logloss:0.29851
[58]    validation_0-logloss:0.29711
[59]    validation_0-logloss:0.29530
[60]    validation_0-logloss:0.29295
[61]    validation_0-logloss:0.29074
[62]    validation_0-logloss:0.28918
[63]    validation_0-logloss:0.28726
[64]    validation_0-logloss:0.28528
[65]    validation_0-logloss:0.28291
[66]    validation_0-logloss:0.28100
[67]    validation_0-logloss:0.27956
[68]    validation_0-logloss:0.27750
[69]    validation_0-logloss:0.27565
[70]    validation_0-logloss:0.27258
[71]    validation_0-logloss:0.27093
[72]    validation_0-logloss:0.26875
[73]    validation_0-logloss:0.26720
[74]    validation_0-logloss:0.26488
[75]    validation_0-logloss:0.26316
[76]    validation_0-logloss:0.26194
[77]    validation_0-logloss:0.26030
[78]    validation_0-logloss:0.25878
[79]    validation_0-logloss:0.25753
[80]    validation_0-logloss:0.25643
[81]    validation_0-logloss:0.25464
[82]    validation_0-logloss:0.25307
[83]    validation_0-logloss:0.25178
[84]    validation_0-logloss:0.25069
[85]    validation_0-logloss:0.24891
[86]    validation_0-logloss:0.24698
[87]    validation_0-logloss:0.24570
[88]    validation_0-logloss:0.24463
[89]    validation_0-logloss:0.24359
[90]    validation_0-logloss:0.24207
[91]    validation_0-logloss:0.24063
[92]    validation_0-logloss:0.23935
[93]    validation_0-logloss:0.23778
[94]    validation_0-logloss:0.23666
[95]    validation_0-logloss:0.23569
[96]    validation_0-logloss:0.23482
[97]    validation_0-logloss:0.23359
[98]    validation_0-logloss:0.23215
[99]    validation_0-logloss:0.23141
Fold 1 - Train Score: 0.9662, Test Score: 0.9116
Fold 2 - Train Score: 0.9666, Test Score: 0.9175
Fold 3 - Train Score: 0.9677, Test Score: 0.9187
Fold 4 - Train Score: 0.9670, Test Score: 0.9113
Fold 5 - Train Score: 0.9663, Test Score: 0.9291
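
Since early_stopping_rounds=10 is set, each fitted estimator records the round with the best validation score; the logs above show the logloss still decreasing at round 99, so early stopping never triggered here. A quick way to verify (a sketch against the results list collected above, using xgboost's best_iteration attribute):

for i, (est, _, _) in enumerate(results):
    print(f"Fold {i+1} - best iteration: {est.best_iteration}")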

Test

The model seems to have trouble consistently ranking pairs of closely similar schedules, which could indicate overfitting. Possible remedies: build training sets from close pairs, or add close pairs to the existing training data, as sketched below.
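
One way to generate such close pairs is the single-patient swap already used in create_neighbors_list. A sketch of that idea as a standalone helper (the perturbed schedules would still need objective values, e.g. via calculate_objective, before they can be labelled):

def make_close_pair(s):
    """Move one patient from a random nonempty interval j to another interval i."""
    i = random.randrange(len(s))
    donors = [k for k, v in enumerate(s) if v > 0 and k != i]
    if not donors: # No patient available to move
        return None
    j = random.choice(donors)
    s_pair = s.copy()
    s_pair[i] += 1
    s_pair[j] -= 1
    return (s, s_pair)

print(make_close_pair([3, 1, 1, 1, 0, 1, 2])) # e.g. ([3, 1, 1, 1, 0, 1, 2], [3, 1, 1, 1, 1, 1, 1])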

# Fit the model on the entire dataset
# Initialize the XGBClassifier without early stopping here
clf = xgb.XGBClassifier(
    tree_method="hist",
    max_depth=6,
    min_child_weight=1,
    gamma=0.1,
    subsample=0.8,
    colsample_bytree=0.8,
    learning_rate=0.1,
    n_estimators=100
)

clf.fit(X, y)

input_X = [
            ([9, 1, 0, 0, 0, 0, 0], [8, 2, 0, 0, 0, 0, 0]),
            ([8, 2, 0, 0, 0, 0, 0], [9, 1, 0, 0, 0, 0, 0]),
            ([3, 1, 1, 1, 0, 1, 2], [2, 1, 1, 1, 0, 1, 3]),
            ([2, 1, 1, 1, 0, 1, 3], [3, 1, 1, 1, 0, 1, 2]),
            ([3, 1, 1, 1, 0, 1, 2], [9, 1, 0, 0, 0, 0, 0]),
            ([8, 2, 0, 0, 0, 0, 0], [3, 1, 1, 1, 0, 1, 2])
            ]
X_new = []
for pair in input_X:
    X_new.append(pair[0] + pair[1])
X_new = np.array(X_new) # Convert to an array before prediction
    
# Predict the target for new data
y_pred = clf.predict(X_new)

# If you want to get the probability estimates
y_pred_proba = clf.predict_proba(X_new)

print(f"y_pred = {y_pred}, \ny_pred_proba = \n{y_pred_proba}")
y_pred = [0 0 1 0 0 1], 
y_pred_proba = 
[[0.76242477 0.23757525]
 [0.69124234 0.30875763]
 [0.3242296  0.6757704 ]
 [0.5743403  0.42565972]
 [0.99786097 0.00213903]
 [0.00121516 0.99878484]]
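
The first two rows illustrate the inconsistency: both orderings of the same pair ([9, 1, ...] vs. [8, 2, ...]) are predicted 0, which cannot both be right. One possible hedge is to symmetrize the comparator by averaging the probabilities from both orderings (a sketch, not part of the pipeline above):

def prob_worse(clf, s1, s2):
    """Symmetrized estimate of P(obj(s1) > obj(s2)) by averaging both orderings."""
    a = clf.predict_proba(np.array([s1 + s2]))[0, 1] # P(label 1) = P(s1 worse) for order (s1, s2)
    b = clf.predict_proba(np.array([s2 + s1]))[0, 0] # P(label 0) = P(s1 worse) for order (s2, s1)
    return (a + b) / 2

p = prob_worse(clf, [9, 1, 0, 0, 0, 0, 0], [8, 2, 0, 0, 0, 0, 0])
print(p) # > 0.5 means the first schedule is predicted to have the higher objective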