Tôi chắc rằng có nhiều cách để làm điều này, nhưng tôi đã mày mò và nghĩ ra một phiên bản của riêng mình.
Đầu tiên, một lệnh gọi lại tùy chỉnh cho phép lấy và cập nhật lịch sử vào cuối mỗi kỷ nguyên. Trong đó tôi cũng có một cuộc gọi lại để lưu mô hình. Cả hai điều này đều tiện dụng vì nếu bạn gặp sự cố hoặc tắt máy, bạn có thể tiếp tục đào tạo vào thời điểm hoàn thành cuối cùng.
class LossHistory(Callback):
# https://stackoverflow.com/a/53653154/852795
def on_epoch_end(self, epoch, logs = None):
new_history = {}
for k, v in logs.items(): # compile new history from logs
new_history[k] = [v] # convert values into lists
current_history = loadHist(history_filename) # load history from current training
current_history = appendHist(current_history, new_history) # append the logs
saveHist(history_filename, current_history) # save history from current training
model_checkpoint = ModelCheckpoint(model_filename, verbose = 0, period = 1)
history_checkpoint = LossHistory()
callbacks_list = [model_checkpoint, history_checkpoint]
Thứ hai, đây là một số chức năng của 'helper' để thực hiện chính xác những điều họ nói là họ làm. Đây là tất cả được gọi từ cuộc LossHistory()
gọi lại.
# https://stackoverflow.com/a/54092401/852795
import json, codecs
def saveHist(path, history):
with codecs.open(path, 'w', encoding='utf-8') as f:
json.dump(history, f, separators=(',', ':'), sort_keys=True, indent=4)
def loadHist(path):
n = {} # set history to empty
if os.path.exists(path): # reload history if it exists
with codecs.open(path, 'r', encoding='utf-8') as f:
n = json.loads(f.read())
return n
def appendHist(h1, h2):
if h1 == {}:
return h2
else:
dest = {}
for key, value in h1.items():
dest[key] = value + h2[key]
return dest
Sau đó, tất cả những gì bạn cần là đặt history_filename
thành một cái gì đó giống như data/model-history.json
, cũng như đặt model_filesname
một cái gì đó giống như vậy data/model.h5
. Một tinh chỉnh cuối cùng để đảm bảo không làm rối lịch sử của bạn khi kết thúc quá trình đào tạo, giả sử bạn dừng lại và bắt đầu, cũng như dính vào các lệnh gọi lại, là thực hiện điều này:
new_history = model.fit(X_train, y_train,
batch_size = batch_size,
nb_epoch = nb_epoch,
validation_data=(X_test, y_test),
callbacks=callbacks_list)
history = appendHist(history, new_history.history)
Bất cứ khi nào bạn muốn, hãy history = loadHist(history_filename)
lấy lại lịch sử của bạn.
Sự thú vị đến từ json và các danh sách nhưng tôi không thể làm cho nó hoạt động mà không chuyển đổi nó bằng cách lặp lại. Dù sao, tôi biết rằng điều này hiệu quả bởi vì tôi đã tập trung vào nó trong nhiều ngày nay. Câu pickled.dump
trả lời tại https://stackoverflow.com/a/44674337/852795 có thể tốt hơn, nhưng tôi không biết đó là gì. Nếu tôi bỏ lỡ bất cứ điều gì ở đây hoặc bạn không thể làm cho nó hoạt động, hãy cho tôi biết.