How I found the problem
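- After reading the README in detail, I started training the model myself. I noticed that the provided data trains each stock for only 1 epoch, with a loss of roughly 0.0020-0.1, so naturally I wanted to increase the number of epochs.
- After changing the epochs value under training in LSTMPredictStock/config.json to 2, the program raised an error: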
[Model] Training Started
[Model] 2 epochs, 8 batch size, 29 batches per epoch
2025-03-09 22:12:11.276718: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-03-09 22:12:11.276967: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 1996205000 Hz
Epoch 1/2
29/29 [==============================] - 2s 30ms/step - loss: 0.0024
Epoch 2/2
2025-03-09 22:12:13.764939: W tensorflow/core/framework/op_kernel.cc:1751] Invalid argument: TypeError: `generator` yielded an element that could not be converted to the expected type. The expected type was float32, but the yielded element was [array([[ 0. , 0. , 0. , 0. ],
[ 0.00276253, -0.00033106, 0.00230376, 0.00154365],
[ 0.00455639, 0.00548025, 0.00636496, 0.00553156],
[ 0.00534874, 0.00714018, 0.00698772, 0.00492981],
[ 0.00480491, 0.00315387, 0.00463258, 0.00367912],
[-0.00418113, -0.00538486, -0.00376593, -0.00524791],
[ 0.00094419, 0.00176553, 0.00281446, 0.00227668],
[ 0.00790458, 0.00602065, 0.00821311, 0.00788086],
[ 0.00727708, 0.00890354, 0.01185306, 0.00880289],
[ 0.00073471, -0.00148558, 0.00091246, -0.00060734],
[ 0.01349254, 0.01021229, 0.01257114, 0.0108043 ],
[ 0.02372086, 0.02297548, 0.02319355, 0.01668251],
[ 0.02945446, 0.02684999, 0.02926323, 0.02902177],
[ 0.02817995, 0.02719994, 0.02788556, 0.02784433],
[ 0.03695807, 0.02922592, 0.03577212, 0.03119914],
[ 0.0326502 , 0.03730632, 0.03973274, 0.03421444],
[ 0.03706311, 0.03287932, 0.03660752, 0.03450554],
[ 0.03988483, 0.03439958, 0.03887289, 0.03616712],
[ 0.03017368, 0.03424257, 0.03532363, 0.02857177],
[ 0.03855516, 0.03455752, 0.03746458, 0.0367456 ],
[ 0.03832089, 0.0379483 , 0.03854882, 0.03796184],
[ 0.04709962, 0.04396957, 0.04603146, 0.04606553],
[ 0.04521404, 0.04305011, 0.04424058, 0.04147213],
[ 0.03685116, 0.0401923 , 0.04021992, 0.0358785 ],
[ 0.0474408 , 0.04420525, 0.04624286, 0.04387761],
[ 0.04987269, 0.04401355, 0.04889452, 0.04502898],
[ 0.02905999, 0.03601027, 0.03544558, 0.02994194],
[ 0.02782917, 0.02909307, 0.0291289 , 0.02555896],
[ 0.03008661, 0.02874931, 0.02939694, 0.02995715]])
array([[ 0. , 0. , 0. , 0. ],
[ 0.00178892, 0.00581323, 0.00405187, 0.00398177],
[ 0.00257908, 0.00747371, 0.0046732 , 0.00338094],
[ 0.00203675, 0.00348608, 0.00232347, 0.00213218],
[-0.00692453, -0.00505548, -0.00605573, -0.00678109],
[-0.00181333, 0.00209729, 0.00050953, 0.0007319 ],
[ 0.00512789, 0.00635381, 0.00589577, 0.00632745],
[ 0.00450212, 0.00923766, 0.00952735, 0.00724806],
[-0.00202223, -0.0011549 , -0.00138809, -0.00214768],
[ 0.01070045, 0.01054684, 0.01024379, 0.00924638],
[ 0.0209006 , 0.02331426, 0.02084178, 0.01511553],
[ 0.0266184 , 0.02719005, 0.02689751, 0.02743577],
[ 0.0253474 , 0.02754012, 0.02552301, 0.02626014],
[ 0.03410133, 0.02956677, 0.03339144, 0.02960979],
[ 0.02980533, 0.03764984, 0.03734295, 0.03262044],
[ 0.03420609, 0.03322137, 0.03422491, 0.03291109],
[ 0.03702003, 0.03474214, 0.03648508, 0.03457011],
[ 0.02733564, 0.03458508, 0.03294398, 0.02698647],
[ 0.03569402, 0.03490013, 0.03508 , 0.0351477 ],
[ 0.0354604 , 0.03829204, 0.03616176, 0.03636206],
[ 0.04421495, 0.0443153 , 0.0436272 , 0.04445326],
[ 0.04233456, 0.04339553, 0.04184044, 0.03986694],
[ 0.03399472, 0.04053678, 0.03782902, 0.03428193],
[ 0.04455518, 0.04455105, 0.04383812, 0.04226871],
[ 0.04698038, 0.04435929, 0.04648368, 0.04341831],
[ 0.02622502, 0.03635336, 0.03306565, 0.02835452],
[ 0.02499759, 0.02943387, 0.02676349, 0.0239783 ],
[ 0.0272488 , 0.02909 , 0.02703092, 0.02836971],
[ 0.03273546, 0.03418049, 0.03225781, 0.03323893]])
array([[ 0. , 0. , 0. , 0. ],
[ 0.00078875, 0.00165088, 0.00061881, -0.00059845],
[ 0.00024739, -0.0023137 , -0.00172143, -0.00184225],
[-0.00869789, -0.01080589, -0.01006681, -0.01072017],
[-0.00359582, -0.00369447, -0.00352804, -0.00323698],
[ 0.003333 , 0.00053746, 0.00183645, 0.00233638],
[ 0.00270835, 0.00340464, 0.00545338, 0.00325333],
[-0.00380434, -0.00692786, -0.00541801, -0.00610513],
[ 0.00889562, 0.00470625, 0.00616693, 0.00524373],
[ 0.01907755, 0.01739988, 0.01672215, 0.01108961],
[ 0.02478514, 0.02125327, 0.02275344, 0.02336099],
[ 0.02351641, 0.02160131, 0.02138448, 0.02219002],
[ 0.03225471, 0.02361625, 0.02922117, 0.02552638],
[ 0.02796638, 0.0316526 , 0.03315673, 0.02852509],
[ 0.03235928, 0.02724974, 0.03005128, 0.02881459],
[ 0.0351682 , 0.02876171, 0.03230232, 0.03046703],
[ 0.0255011 , 0.02860556, 0.02877551, 0.02291346],
[ 0.03384456, 0.02891879, 0.03090292, 0.03104233],
[ 0.03361136, 0.03229109, 0.03198031, 0.03225187],
[ 0.04235027, 0.03827955, 0.03941562, 0.04031098],
[ 0.04047324, 0.03736509, 0.03763607, 0.03574286],
[ 0.03214829, 0.03452286, 0.03364083, 0.03018 ],
[ 0.0426899 , 0.03851393, 0.03962568, 0.0381351 ],
[ 0.04511076, 0.03832328, 0.04226057, 0.03928014],
[ 0.02439246, 0.03036362, 0.02889669, 0.02427609],
[ 0.02316723, 0.02348412, 0.02261996, 0.01991722],
[ 0.02541442, 0.02314224, 0.02288631, 0.02429122],
[ 0.03089128, 0.02820331, 0.02809211, 0.02914113],
[ 0.04296258, 0.04005147, 0.04172941, 0.04219798]])
array([[ 0. , 0. , 0. , 0. ],
[-0.00054094, -0.00395805, -0.00233879, -0.00124455],
[-0.00947917, -0.01243625, -0.01067902, -0.01012779],
[-0.00438112, -0.00533654, -0.0041443 , -0.00264011],
[ 0.00254224, -0.00111159, 0.00121689, 0.00293658],
[ 0.00191809, 0.00175086, 0.00483158, 0.00385409],
[-0.00458948, -0.0085646 , -0.0060331 , -0.00550999],
[ 0.00810047, 0.00305033, 0.00554468, 0.00584567],
[ 0.01827438, 0.01572304, 0.01609338, 0.01169505],
[ 0.02397748, 0.01957008, 0.02212094, 0.02397378],
[ 0.02270975, 0.01991755, 0.02075283, 0.02280211],
[ 0.03144116, 0.02192917, 0.02858466, 0.02614047],
[ 0.02715621, 0.02995227, 0.03251779, 0.02914097],
[ 0.03154565, 0.02555666, 0.02941426, 0.02943065],
[ 0.03435235, 0.02706615, 0.03166391, 0.03108408],
[ 0.02469287, 0.02691025, 0.02813928, 0.02352599],
[ 0.03302975, 0.02722297, 0.03026537, 0.03165972],
[ 0.03279673, 0.03058971, 0.0313421 , 0.03286999],
[ 0.04152876, 0.03656829, 0.03877281, 0.04093392],
[ 0.03965321, 0.03565535, 0.03699436, 0.03636306],
[ 0.03133482, 0.0328178 , 0.0330016 , 0.03079687],
[ 0.04186812, 0.03680229, 0.03898275, 0.03875674],
[ 0.04428707, 0.03661196, 0.041616 , 0.03990246],
[ 0.02358511, 0.02866542, 0.02826039, 0.02488943],
[ 0.02236083, 0.02179725, 0.02198754, 0.02052795],
[ 0.02460626, 0.02145594, 0.02225372, 0.02490457],
[ 0.0300788 , 0.02650866, 0.0274563 , 0.02975738],
[ 0.04214059, 0.0383373 , 0.04108517, 0.04282206]])].
TypeError: only size-1 arrays can be converted to Python scalars
The above exception was the direct cause of the following exception:
The fix
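- The problem clearly lies in the generator that feeds training data to the model. The model is initialized and trained in the train_model function of run.py, which pulls its data from generate_train_batch (in data_preprocessor.py).
- At first glance the function looks fine. It even thoughtfully handles the case where the data does not divide evenly into batches by yielding a final batch with fewer than batch_size windows: "# stop-condition for a smaller final batch if data doesn't divide evenly".
- But that is exactly where the problem is: by the time the second epoch starts, the generator has already run out of data, so the model gets nothing and TensorFlow throws a confusing error. Keras expects a generator passed to fit to loop over its data indefinitely; an epoch ends after steps_per_epoch batches, not when the generator is exhausted.
- Two small changes: when there are not enough windows left for a full batch, stop early and start the next pass over the data; and wrap the whole function body in while True so that it yields data indefinitely: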
def generate_train_batch(self, seq_len, batch_size, normalise):
    '''Yield a generator of training data from filename on given list of cols split for train/test'''
    while True:  # loop forever so the generator never runs out of data
        i = 0
        while i < (self.len_train - seq_len):
            x_batch = []
            y_batch = []
            # a batch starting at i needs i + batch_size <= len_train - seq_len,
            # otherwise start the next pass instead of yielding a smaller batch
            if i + batch_size > (self.len_train - seq_len):
                break
            for _ in range(batch_size):
                # if i >= (self.len_train - seq_len):
                #     # stop-condition for a smaller final batch if data doesn't divide evenly
                #     yield np.array(x_batch), np.array(y_batch)
                x, y = self._next_window(i, seq_len, normalise)
                x_batch.append(x)
                y_batch.append(y)
                i += 1
            yield np.array(x_batch), np.array(y_batch)  # np = numpy, imported at module level
- Problem solved:
[Model] Training Started
[Model] 2 epochs, 8 batch size, 29 batches per epoch
2025-03-09 22:30:46.705440: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:116] None of the MLIR optimization passes are enabled (registered 2)
2025-03-09 22:30:46.705698: I tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 1996205000 Hz
Epoch 1/2
29/29 [==============================] - 3s 31ms/step - loss: 0.0024
Epoch 2/2
29/29 [==============================] - 1s 27ms/step - loss: 0.0016
[Model] Training Completed.
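The generator behavior can also be checked without TensorFlow. The sketch below is a minimal, hypothetical reproduction: ToyData, single_pass_batches, run_epochs and the toy sizes are made up for illustration, _next_window is simplified (no normalisation), and run_epochs only approximates how Keras pulls steps_per_epoch batches per epoch. The single-pass generator (a simplified model of the original) runs out during the second epoch, while the looping version keeps producing full batches.

```python
import math

import numpy as np


class ToyData:
    """Hypothetical stand-in for the repo's data class; only the batching loops mirror the issue."""

    def __init__(self, n_rows, n_cols=4):
        self.data_train = np.arange(n_rows * n_cols, dtype=float).reshape(n_rows, n_cols)
        self.len_train = n_rows

    def _next_window(self, i, seq_len, normalise):
        # Simplified window: `normalise` is ignored; x = first seq_len - 1 rows, y = column 0 of the last row
        window = self.data_train[i:i + seq_len]
        return window[:-1], window[-1, [0]]

    def single_pass_batches(self, seq_len, batch_size, normalise):
        # Simplified model of the original: one pass, smaller final batch, then the generator is exhausted
        i = 0
        while i < (self.len_train - seq_len):
            x_batch, y_batch = [], []
            for _ in range(batch_size):
                if i >= (self.len_train - seq_len):
                    break
                x, y = self._next_window(i, seq_len, normalise)
                x_batch.append(x)
                y_batch.append(y)
                i += 1
            yield np.array(x_batch), np.array(y_batch)

    def generate_train_batch(self, seq_len, batch_size, normalise):
        # Same structure as the fix: endless passes, full batches only
        while True:
            i = 0
            while i < (self.len_train - seq_len):
                if i + batch_size > (self.len_train - seq_len):
                    break  # too few windows left for a full batch: start a new pass
                x_batch, y_batch = [], []
                for _ in range(batch_size):
                    x, y = self._next_window(i, seq_len, normalise)
                    x_batch.append(x)
                    y_batch.append(y)
                    i += 1
                yield np.array(x_batch), np.array(y_batch)


def run_epochs(gen, epochs, steps_per_epoch):
    """Pull steps_per_epoch batches per epoch, roughly like Keras; returns the x-batch shapes."""
    shapes = []
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            x, _ = next(gen)  # raises StopIteration if the generator is exhausted
            shapes.append(x.shape)
    return shapes


data = ToyData(n_rows=40)  # 35 windows of length 5
steps = math.ceil((data.len_train - 5) / 8)  # 5 steps per epoch with batch_size 8
try:
    run_epochs(data.single_pass_batches(5, 8, False), epochs=2, steps_per_epoch=steps)
except StopIteration:
    print("single-pass generator ran out during epoch 2")
shapes = run_epochs(data.generate_train_batch(5, 8, False), epochs=2, steps_per_epoch=steps)
print(len(shapes), set(shapes))  # 10 {(8, 4, 4)}
```

With 35 windows and batch_size 8, the looping generator yields 4 full batches per pass, so the fifth step of each epoch simply comes from the next pass.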
Postscript
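Although this code is somewhat dated (for example, the API it uses to fetch stock data no longer works), overall it is still a fairly complete project.
I saw that some people in the issues can't download new data, so here is my solution as well: use the Sina API.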
http://money.finance.sina.com.cn/quotes_service/api/json_v2.php/CN_MarketData.getKLineData?symbol=[market][code]&scale=[period]&ma=no&datalen=[length]
market: sh for Shanghai, sz for Shenzhen
period: 5, 10, 30, 60
length: at most 1023
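A small helper that fills in that URL template might look like the sketch below. build_sina_kline_url is a hypothetical name, 600000 is just an example stock code, and the checks only mirror the values listed above (sh/sz, scale 5/10/30/60, datalen at most 1023) rather than anything confirmed by Sina's own documentation.

```python
from urllib.parse import urlencode

BASE_URL = ("http://money.finance.sina.com.cn/quotes_service/api/json_v2.php/"
            "CN_MarketData.getKLineData")
ALLOWED_SCALES = {5, 10, 30, 60}  # periods listed in this issue
MAX_DATALEN = 1023                # maximum length listed in this issue


def build_sina_kline_url(market: str, code: str, scale: int, datalen: int) -> str:
    """Fill in symbol=[market][code]&scale=[period]&ma=no&datalen=[length]."""
    if market not in ("sh", "sz"):
        raise ValueError("market must be 'sh' (Shanghai) or 'sz' (Shenzhen)")
    if scale not in ALLOWED_SCALES:
        raise ValueError(f"scale must be one of {sorted(ALLOWED_SCALES)}")
    if not 1 <= datalen <= MAX_DATALEN:
        raise ValueError(f"datalen must be between 1 and {MAX_DATALEN}")
    # urlencode keeps the dict order and converts the ints to strings
    query = urlencode({"symbol": market + code, "scale": scale, "ma": "no", "datalen": datalen})
    return f"{BASE_URL}?{query}"


print(build_sina_kline_url("sh", "600000", 60, 1023))
```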
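The API returns JSON directly, so there is no need to convert text the way the original author did; the data can be saved straight to a CSV file. However, to stay consistent with the rest of the code, the keys of each record in the JSON have to be renamed to the column names the original author uses: Date, Code, Name, Open, Close, High, Low, Volume. Also, the JSON contains no Code or Name, so just add them yourself.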
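The renaming step can be sketched with the standard json and csv modules. This assumes each record in the response is an object with the keys day, open, high, low, close and volume; those key names are an assumption, so check them against a real response first. sina_json_to_csv and the sample record are hypothetical and only illustrate the shape of the conversion.

```python
import csv
import io
import json

# Column order used by the original project's CSV files (per this issue)
COLUMNS = ["Date", "Code", "Name", "Open", "Close", "High", "Low", "Volume"]
# ASSUMPTION: Sina's kline records use these lowercase keys; verify against a real response
KEY_MAP = {"day": "Date", "open": "Open", "close": "Close",
           "high": "High", "low": "Low", "volume": "Volume"}


def sina_json_to_csv(json_text: str, code: str, name: str, out) -> int:
    """Rename the JSON keys to the project's columns, add Code/Name by hand, write CSV rows."""
    records = json.loads(json_text)
    writer = csv.DictWriter(out, fieldnames=COLUMNS)
    writer.writeheader()
    for rec in records:
        row = {KEY_MAP[k]: v for k, v in rec.items() if k in KEY_MAP}  # unknown keys are dropped
        row["Code"] = code  # not present in the JSON, added manually
        row["Name"] = name  # not present in the JSON, added manually
        writer.writerow(row)
    return len(records)


sample = ('[{"day": "2025-03-07 15:00:00", "open": "10.00", "high": "10.20", '
          '"low": "9.90", "close": "10.10", "volume": "123400"}]')  # made-up record
buf = io.StringIO()
sina_json_to_csv(sample, code="600000", name="example", out=buf)
print(buf.getvalue())
```

When writing to a real file instead of a StringIO, open it with newline="" so that the csv module controls the line endings.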