Coverage for src/edelweiss/tf_utils.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.7, created at 2024-11-18 17:09 +0000

1# Copyright (C) 2024 ETH Zurich 

2# Institute for Particle Physics and Astrophysics 

3# Author: Silvan Fischbacher 

4# created: Tue Feb 27 2024 

5 

6 

7import tensorflow as tf 

8from tqdm import tqdm 

9 

10 

class EpochProgressCallback(tf.keras.callbacks.Callback):
    """
    Keras callback showing a tqdm progress bar that advances once per epoch.

    Originally written by ChatGPT, provided by Arne Thomsen.

    After every epoch the bar's postfix shows the training loss, the validation
    loss and the learning rate, each only if it is present in the epoch logs.
    Metrics missing from the logs are omitted rather than displayed as ``None``.

    :param total_epochs: number of epochs the bar counts up to (normally the
        ``epochs`` argument passed to ``model.fit``).
    """

    def __init__(self, total_epochs):
        super().__init__()
        self.total_epochs = total_epochs
        # The bar is created in on_train_begin and discarded in on_train_end,
        # so one callback instance can be reused across several ``fit`` calls.
        self.pbar = None

    def _close_pbar(self):
        """Close the current progress bar, if there is one, and forget it."""
        if self.pbar is not None:
            self.pbar.close()
            self.pbar = None

    def on_train_begin(self, logs=None):
        # A bar may be left over from a previous run that never reached
        # on_train_end (e.g. interrupted); close it instead of leaking it.
        self._close_pbar()
        self.pbar = tqdm(total=self.total_epochs, desc="epoch")

    def on_epoch_end(self, epoch, logs=None):
        # Nothing to update if training was not started through this callback.
        if self.pbar is None:
            return

        if logs is not None:
            lr = logs.get("lr")
            if lr is None:
                # Keras 3 learning-rate callbacks log under "learning_rate".
                lr = logs.get("learning_rate")

            postfix = {
                "loss": logs.get("loss"),
                "val_loss": logs.get("val_loss"),
                "lr": lr,
            }
            # Drop unavailable metrics; refresh=False defers the redraw to
            # update() below so the bar is drawn once per epoch.
            self.pbar.set_postfix(
                {key: value for key, value in postfix.items() if value is not None},
                refresh=False,
            )

        self.pbar.update(1)

    def on_train_end(self, logs=None):
        self._close_pbar()