Coverage for src/edelweiss/tf_utils.py: 100%
18 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-18 17:09 +0000
1# Copyright (C) 2024 ETH Zurich
2# Institute for Particle Physics and Astrophysics
3# Author: Silvan Fischbacher
4# created: Tue Feb 27 2024
7import tensorflow as tf
8from tqdm import tqdm
class EpochProgressCallback(tf.keras.callbacks.Callback):
    """
    Keras callback that shows a tqdm progress bar advancing once per epoch.

    Originally written by ChatGPT, provided by Arne Thomsen.

    After every epoch the bar's postfix is updated with the training loss,
    the validation loss and the learning rate, for whichever of these are
    present in the epoch logs.

    :param total_epochs: total length of the progress bar; should match the
        number of epochs the training runs for
    """

    def __init__(self, total_epochs):
        super().__init__()
        self.total_epochs = total_epochs
        # created in on_train_begin so the bar only appears once training starts
        self.pbar = None

    def on_train_begin(self, logs=None):
        """Create a fresh progress bar at the start of training."""
        # close a bar left behind by a previous fit() that did not reach
        # on_train_end, so two bars are never alive at once
        if self.pbar is not None:
            self.pbar.close()
        self.pbar = tqdm(total=self.total_epochs, desc="epoch")

    def on_epoch_end(self, epoch, logs=None):
        """
        Advance the bar by one epoch and display the latest metrics.

        :param epoch: index of the epoch that just finished (unused)
        :param logs: dict of metrics for the epoch, or None
        """
        # nothing to update if training was never started through this callback
        if self.pbar is None:
            return
        self.pbar.update(1)

        if logs is not None:
            # tf.keras 2.x callbacks log the learning rate as "lr", while Keras 3
            # callbacks use "learning_rate"; accept either
            lr = logs.get("lr")
            if lr is None:
                lr = logs.get("learning_rate")

            postfix = {
                "loss": logs.get("loss"),
                "val_loss": logs.get("val_loss"),
                "lr": lr,
            }
            # only show metrics that were actually logged instead of printing
            # e.g. "val_loss=None" when no validation data is used
            postfix = {key: value for key, value in postfix.items() if value is not None}
            self.pbar.set_postfix(postfix)

    def on_train_end(self, logs=None):
        """Close the progress bar at the end of training."""
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None