
Bug when epochs in LSTMPredictStock/config.json is greater than 1 #23

@DanielWen-Takuya

Description


How I found this problem

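  • After reading the README carefully, I started training the model myself. I noticed that the provided data trained each stock for only 1 epoch, with a loss of roughly 0.0020 to 0.1, so I naturally wanted to raise the number of epochs.
  • After I set epochs under training in LSTMPredictStock/config.json to 2, the program failed with this error: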
[Model] Training Started
[Model] 2 epochs, 8 batch size, 29 batches per epoch
2025-03-09 22:12:11.276718: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-03-09 22:12:11.276967: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 1996205000 Hz
Epoch 1/2
29/29 [==============================] - 2s 30ms/step - loss: 0.0024
Epoch 2/2
2025-03-09 22:12:13.764939: W tensorflow/core/framework/op_kernel.cc:1751] Invalid argument: TypeError: `generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was [array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.00276253, -0.00033106,  0.00230376,  0.00154365],
       [ 0.00455639,  0.00548025,  0.00636496,  0.00553156],
       [ 0.00534874,  0.00714018,  0.00698772,  0.00492981],
       [ 0.00480491,  0.00315387,  0.00463258,  0.00367912],
       [-0.00418113, -0.00538486, -0.00376593, -0.00524791],
       [ 0.00094419,  0.00176553,  0.00281446,  0.00227668],
       [ 0.00790458,  0.00602065,  0.00821311,  0.00788086],
       [ 0.00727708,  0.00890354,  0.01185306,  0.00880289],
       [ 0.00073471, -0.00148558,  0.00091246, -0.00060734],
       [ 0.01349254,  0.01021229,  0.01257114,  0.0108043 ],
       [ 0.02372086,  0.02297548,  0.02319355,  0.01668251],
       [ 0.02945446,  0.02684999,  0.02926323,  0.02902177],
       [ 0.02817995,  0.02719994,  0.02788556,  0.02784433],
       [ 0.03695807,  0.02922592,  0.03577212,  0.03119914],
       [ 0.0326502 ,  0.03730632,  0.03973274,  0.03421444],
       [ 0.03706311,  0.03287932,  0.03660752,  0.03450554],
       [ 0.03988483,  0.03439958,  0.03887289,  0.03616712],
       [ 0.03017368,  0.03424257,  0.03532363,  0.02857177],
       [ 0.03855516,  0.03455752,  0.03746458,  0.0367456 ],
       [ 0.03832089,  0.0379483 ,  0.03854882,  0.03796184],
       [ 0.04709962,  0.04396957,  0.04603146,  0.04606553],
       [ 0.04521404,  0.04305011,  0.04424058,  0.04147213],
       [ 0.03685116,  0.0401923 ,  0.04021992,  0.0358785 ],
       [ 0.0474408 ,  0.04420525,  0.04624286,  0.04387761],
       [ 0.04987269,  0.04401355,  0.04889452,  0.04502898],
       [ 0.02905999,  0.03601027,  0.03544558,  0.02994194],
       [ 0.02782917,  0.02909307,  0.0291289 ,  0.02555896],
       [ 0.03008661,  0.02874931,  0.02939694,  0.02995715]])
 array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.00178892,  0.00581323,  0.00405187,  0.00398177],
       [ 0.00257908,  0.00747371,  0.0046732 ,  0.00338094],
       [ 0.00203675,  0.00348608,  0.00232347,  0.00213218],
       [-0.00692453, -0.00505548, -0.00605573, -0.00678109],
       [-0.00181333,  0.00209729,  0.00050953,  0.0007319 ],
       [ 0.00512789,  0.00635381,  0.00589577,  0.00632745],
       [ 0.00450212,  0.00923766,  0.00952735,  0.00724806],
       [-0.00202223, -0.0011549 , -0.00138809, -0.00214768],
       [ 0.01070045,  0.01054684,  0.01024379,  0.00924638],
       [ 0.0209006 ,  0.02331426,  0.02084178,  0.01511553],
       [ 0.0266184 ,  0.02719005,  0.02689751,  0.02743577],
       [ 0.0253474 ,  0.02754012,  0.02552301,  0.02626014],
       [ 0.03410133,  0.02956677,  0.03339144,  0.02960979],
       [ 0.02980533,  0.03764984,  0.03734295,  0.03262044],
       [ 0.03420609,  0.03322137,  0.03422491,  0.03291109],
       [ 0.03702003,  0.03474214,  0.03648508,  0.03457011],
       [ 0.02733564,  0.03458508,  0.03294398,  0.02698647],
       [ 0.03569402,  0.03490013,  0.03508   ,  0.0351477 ],
       [ 0.0354604 ,  0.03829204,  0.03616176,  0.03636206],
       [ 0.04421495,  0.0443153 ,  0.0436272 ,  0.04445326],
       [ 0.04233456,  0.04339553,  0.04184044,  0.03986694],
       [ 0.03399472,  0.04053678,  0.03782902,  0.03428193],
       [ 0.04455518,  0.04455105,  0.04383812,  0.04226871],
       [ 0.04698038,  0.04435929,  0.04648368,  0.04341831],
       [ 0.02622502,  0.03635336,  0.03306565,  0.02835452],
       [ 0.02499759,  0.02943387,  0.02676349,  0.0239783 ],
       [ 0.0272488 ,  0.02909   ,  0.02703092,  0.02836971],
       [ 0.03273546,  0.03418049,  0.03225781,  0.03323893]])
 array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.00078875,  0.00165088,  0.00061881, -0.00059845],
       [ 0.00024739, -0.0023137 , -0.00172143, -0.00184225],
       [-0.00869789, -0.01080589, -0.01006681, -0.01072017],
       [-0.00359582, -0.00369447, -0.00352804, -0.00323698],
       [ 0.003333  ,  0.00053746,  0.00183645,  0.00233638],
       [ 0.00270835,  0.00340464,  0.00545338,  0.00325333],
       [-0.00380434, -0.00692786, -0.00541801, -0.00610513],
       [ 0.00889562,  0.00470625,  0.00616693,  0.00524373],
       [ 0.01907755,  0.01739988,  0.01672215,  0.01108961],
       [ 0.02478514,  0.02125327,  0.02275344,  0.02336099],
       [ 0.02351641,  0.02160131,  0.02138448,  0.02219002],
       [ 0.03225471,  0.02361625,  0.02922117,  0.02552638],
       [ 0.02796638,  0.0316526 ,  0.03315673,  0.02852509],
       [ 0.03235928,  0.02724974,  0.03005128,  0.02881459],
       [ 0.0351682 ,  0.02876171,  0.03230232,  0.03046703],
       [ 0.0255011 ,  0.02860556,  0.02877551,  0.02291346],
       [ 0.03384456,  0.02891879,  0.03090292,  0.03104233],
       [ 0.03361136,  0.03229109,  0.03198031,  0.03225187],
       [ 0.04235027,  0.03827955,  0.03941562,  0.04031098],
       [ 0.04047324,  0.03736509,  0.03763607,  0.03574286],
       [ 0.03214829,  0.03452286,  0.03364083,  0.03018   ],
       [ 0.0426899 ,  0.03851393,  0.03962568,  0.0381351 ],
       [ 0.04511076,  0.03832328,  0.04226057,  0.03928014],
       [ 0.02439246,  0.03036362,  0.02889669,  0.02427609],
       [ 0.02316723,  0.02348412,  0.02261996,  0.01991722],
       [ 0.02541442,  0.02314224,  0.02288631,  0.02429122],
       [ 0.03089128,  0.02820331,  0.02809211,  0.02914113],
       [ 0.04296258,  0.04005147,  0.04172941,  0.04219798]])
 array([[ 0.        ,  0.        ,  0.        ,  0.        ],
       [-0.00054094, -0.00395805, -0.00233879, -0.00124455],
       [-0.00947917, -0.01243625, -0.01067902, -0.01012779],
       [-0.00438112, -0.00533654, -0.0041443 , -0.00264011],
       [ 0.00254224, -0.00111159,  0.00121689,  0.00293658],
       [ 0.00191809,  0.00175086,  0.00483158,  0.00385409],
       [-0.00458948, -0.0085646 , -0.0060331 , -0.00550999],
       [ 0.00810047,  0.00305033,  0.00554468,  0.00584567],
       [ 0.01827438,  0.01572304,  0.01609338,  0.01169505],
       [ 0.02397748,  0.01957008,  0.02212094,  0.02397378],
       [ 0.02270975,  0.01991755,  0.02075283,  0.02280211],
       [ 0.03144116,  0.02192917,  0.02858466,  0.02614047],
       [ 0.02715621,  0.02995227,  0.03251779,  0.02914097],
       [ 0.03154565,  0.02555666,  0.02941426,  0.02943065],
       [ 0.03435235,  0.02706615,  0.03166391,  0.03108408],
       [ 0.02469287,  0.02691025,  0.02813928,  0.02352599],
       [ 0.03302975,  0.02722297,  0.03026537,  0.03165972],
       [ 0.03279673,  0.03058971,  0.0313421 ,  0.03286999],
       [ 0.04152876,  0.03656829,  0.03877281,  0.04093392],
       [ 0.03965321,  0.03565535,  0.03699436,  0.03636306],
       [ 0.03133482,  0.0328178 ,  0.0330016 ,  0.03079687],
       [ 0.04186812,  0.03680229,  0.03898275,  0.03875674],
       [ 0.04428707,  0.03661196,  0.041616  ,  0.03990246],
       [ 0.02358511,  0.02866542,  0.02826039,  0.02488943],
       [ 0.02236083,  0.02179725,  0.02198754,  0.02052795],
       [ 0.02460626,  0.02145594,  0.02225372,  0.02490457],
       [ 0.0300788 ,  0.02650866,  0.0274563 ,  0.02975738],
       [ 0.04214059,  0.0383373 ,  0.04108517,  0.04282206]])].
TypeError: only size-1 arrays can be converted to Python scalars


The above exception was the direct cause of the following exception:
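The element in the log above is a list of four windows. Three of them have 29 rows and the last one has only 28. NumPy cannot stack windows of different lengths into one float32 array, and the model's float32 input needs exactly that. The minimal sketch below reproduces the failure with zero-filled arrays of the same shapes. `try_stack` is a hypothetical helper and is not part of the project. The exact exception type and message vary by NumPy version. A direct conversion with `dtype=np.float32` raises `ValueError`, which is what the sketch catches.

```python
import numpy as np


def try_stack(windows):
    """Stack (rows, features) windows into one float32 array.

    Returns None when the windows have different lengths (a ragged batch).
    """
    try:
        return np.array(windows, dtype=np.float32)
    except ValueError:
        return None


full = [np.zeros((29, 4)) for _ in range(3)]  # three 29-row windows
ragged = full + [np.zeros((28, 4))]           # plus one 28-row window, as in the log

print(try_stack(full).shape)  # (3, 29, 4)
print(try_stack(ragged))      # None
```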

The fix

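  • The problem is clearly in the generator that supplies the training data. run.py defines the model's initialization and training in its train_model function, which gets its data from generate_train_batch (in data_preprocessor.py).
  • At first glance the function looks fine. It even thoughtfully handles data that does not divide evenly into batches by yielding a smaller final batch: "# stop-condition for a smaller final batch if data doesn't divide evenly"
  • But that is exactly where the problem lies. The first epoch uses up almost all of the training windows, so the second epoch runs out of data right at the start. The stop-condition yields the partial batch, but it never resets the index or leaves the loop. The following windows therefore run past the end of the training data and come out shorter: in the log above, the last array has 28 rows instead of 29. The resulting ragged batch cannot be converted to float32, and training fails with a confusing error.
  • Two small changes fix it. First, when fewer than batch_size windows remain, stop early and start the next pass over the data. Second, wrap the whole body in while True so the generator never runs dry: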
def generate_train_batch(self, seq_len, batch_size, normalise):
    '''Yield an endless generator of training batches from the training split.
    Relies on numpy being imported as np at the top of data_preprocessor.py.'''
    while True:  # restart from the beginning so that every epoch gets data
        i = 0
        while i < (self.len_train - seq_len):
            x_batch = []
            y_batch = []
            # not enough windows left for a full batch: drop them and start a new pass
            if i + batch_size >= (self.len_train - seq_len):
                break
            for b in range(batch_size):
                # original stop-condition, removed because it kept reading past the end of the data:
                # if i >= (self.len_train - seq_len):
                #     # stop-condition for a smaller final batch if data doesn't divide evenly
                #     yield np.array(x_batch), np.array(y_batch)
                x, y = self._next_window(i, seq_len, normalise)
                x_batch.append(x)
                y_batch.append(y)
                i += 1
            yield np.array(x_batch), np.array(y_batch)
  • Problem solved:
[Model] Training Started
[Model] 2 epochs, 8 batch size, 29 batches per epoch
2025-03-09 22:30:46.705440: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-03-09 22:30:46.705698: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 1996205000 Hz
Epoch 1/2
29/29 [==============================] - 3s 31ms/step - loss: 0.0024
Epoch 2/2
29/29 [==============================] - 1s 27ms/step - loss: 0.0016
[Model] Training Completed.
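You can check the fixed generator without TensorFlow by driving it with a toy loader. This is a sketch under stated assumptions. `ToyLoader` is hypothetical, and its `_next_window` is assumed to slice `data_train[i:i + seq_len]` and split the window into x and y, with normalisation skipped. The sizes (264 rows, seq_len 30, batch size 8) are chosen only to match the 29-row windows and the 29 batches per epoch in the logs. They are not read from the project's config.

```python
import itertools

import numpy as np


class ToyLoader:
    """Hypothetical stand-in for the project's data loader (not its real class).

    ASSUMPTION: _next_window slices data_train[i:i + seq_len] and splits the
    window into x (all rows but the last) and y (first column of the last row).
    Normalisation is ignored here.
    """

    def __init__(self, n_rows, n_cols=4, seed=0):
        self.data_train = np.random.default_rng(seed).random((n_rows, n_cols))
        self.len_train = n_rows

    def _next_window(self, i, seq_len, normalise):
        window = self.data_train[i:i + seq_len]
        return window[:-1], window[-1, [0]]

    def generate_train_batch(self, seq_len, batch_size, normalise):
        """Same logic as the fixed generator in this issue."""
        while True:  # start a new pass once the data is exhausted
            i = 0
            while i < (self.len_train - seq_len):
                x_batch, y_batch = [], []
                if i + batch_size >= (self.len_train - seq_len):
                    break  # not enough windows left for a full batch
                for _ in range(batch_size):
                    x, y = self._next_window(i, seq_len, normalise)
                    x_batch.append(x)
                    y_batch.append(y)
                    i += 1
                yield np.array(x_batch), np.array(y_batch)


# 264 rows and seq_len=30 allow start indices 0..233; the fix turns them into 29 full batches of 8.
loader = ToyLoader(n_rows=264)
gen = loader.generate_train_batch(seq_len=30, batch_size=8, normalise=False)

# Pull three epochs' worth of steps and collect every batch shape seen.
batches = list(itertools.islice(gen, 3 * 29))
x_shapes = {x.shape for x, _ in batches}
y_shapes = {y.shape for _, y in batches}
print(x_shapes, y_shapes)  # {(8, 29, 4)} {(8, 1)}
```

The outer while True restarts at index 0, so batch 30 is identical to batch 1. That restart is what lets Keras keep drawing steps in epoch 2 and later. The trade-off is that a few trailing windows per pass are never used for training (here the ones starting at indices 232 and 233).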

Postscript

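This code is somewhat dated (for example, the API it uses to fetch stock data no longer works), but overall it is still a fairly complete project.
I saw that someone in the issues asked why new data could not be downloaded, so here is my solution as well: use the Sina API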
http://money.finance.sina.com.cn/quotes_service/api/json_v2.php/CN_MarketData.getKLineData?symbol=[market][stock code]&scale=[period]&ma=no&datalen=[length]
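Market: sh for Shanghai, sz for Shenzhen
Period: 5, 10, 30, 60
Length: at most 1023
The response is JSON, so there is no need to convert text as the original author did, and it can be saved directly as a CSV file. To stay compatible with the rest of the code, though, you need to rename the keys of each record in the JSON to the column names the original author uses: Date, Code, Name, Open, Close, High, Low, Volume. The JSON also has no Code or Name fields, so just add them yourself.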
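The conversion described above can be sketched with the Python standard library. This is a hedged sketch that has not been checked against the live service. The URL template and the limits (sh/sz, datalen up to 1023) come from this issue. The response key names in `KEY_MAP` (`day`, `open`, `close`, `high`, `low`, `volume`) are an assumption about the reply format, so verify them against a real response. The sketch also assumes the body is valid JSON, as stated above. All helper names are hypothetical.

```python
import csv
import json
import urllib.parse
import urllib.request

# URL template from this issue.
BASE_URL = ("http://money.finance.sina.com.cn/quotes_service/api/json_v2.php/"
            "CN_MarketData.getKLineData")

# Column order expected by the original project's CSV files.
COLUMNS = ["Date", "Code", "Name", "Open", "Close", "High", "Low", "Volume"]

# ASSUMPTION: key names in each JSON record; check them against a real response.
KEY_MAP = {"day": "Date", "open": "Open", "close": "Close",
           "high": "High", "low": "Low", "volume": "Volume"}


def build_url(market, code, scale, datalen):
    """Fill in the query parameters described in this issue."""
    if market not in ("sh", "sz"):
        raise ValueError("market must be 'sh' or 'sz'")
    if not 1 <= datalen <= 1023:
        raise ValueError("datalen must be between 1 and 1023")
    query = urllib.parse.urlencode({"symbol": market + code, "scale": scale,
                                    "ma": "no", "datalen": datalen})
    return BASE_URL + "?" + query


def to_rows(records, code, name):
    """Rename keys to the project's columns and add the missing Code/Name."""
    rows = []
    for rec in records:
        row = {KEY_MAP[k]: v for k, v in rec.items() if k in KEY_MAP}
        row["Code"] = code
        row["Name"] = name
        rows.append(row)
    return rows


def fetch_to_csv(market, code, name, scale, datalen, path):
    """Download one stock's K-line data and save it with the project's columns."""
    with urllib.request.urlopen(build_url(market, code, scale, datalen)) as resp:
        records = json.loads(resp.read())
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=COLUMNS)
        writer.writeheader()
        writer.writerows(to_rows(records, code, name))
```

The renaming is kept in a separate function so it can be adjusted in one place if the real key names differ from the assumed ones.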

Metadata


Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
