Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 11 additions & 19 deletions src/ai/backend/manager/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,16 +664,20 @@ async def raft_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
raft_configs = root_ctx.local_config.get("raft")
raft_cluster_configs = root_ctx.raft_cluster_config

if raft_configs is not None:
assert raft_cluster_configs is not None
# Only the first process will be used as RaftNode
if process_index.get() == 0 and raft_configs is not None:
assert (
raft_cluster_configs is not None
), "Raft cluster config should be provided when raft feature is enabled."

other_peers = [{**peer, "myself": False} for peer in raft_cluster_configs["peers"]["other"]]
my_peers = [{**peer, "myself": True} for peer in raft_cluster_configs["peers"]["myself"]]
all_peers = sorted([*other_peers, *my_peers], key=lambda x: x["node-id"])
my_peer = [{**peer, "myself": True} for peer in raft_cluster_configs["peers"]["myself"]]
assert len(my_peer) == 1, '"peers.myself" should have only one entry!'
all_peers = sorted([*other_peers, *my_peer], key=lambda x: x["node-id"])

assert (
root_ctx.local_config["manager"]["num-proc"] >= len(my_peers)
), "The number of raft peers (myself), should be greater than or equal to the number of processes"
node_id_offset = next((idx for idx, item in enumerate(all_peers) if item["myself"]), None)
assert node_id_offset is not None
node_id = node_id_offset + 1

initial_peers = Peers({
int(peer_config["node-id"]): Peer(
Expand Down Expand Up @@ -709,10 +713,6 @@ async def raft_ctx(root_ctx: RootContext) -> AsyncIterator[None]:
raft_config=raft_core_config,
)

node_id_offset = next((idx for idx, item in enumerate(all_peers) if item["myself"]), None)
assert node_id_offset is not None, '"peers.myself" not found in initial_peers!'
node_id = node_id_offset + process_index.get() + 1

raft_addr = initial_peers.get(node_id).get_addr()

store = HashStore()
Expand Down Expand Up @@ -836,14 +836,6 @@ def init_lock_factory(root_ctx: RootContext) -> DistributedLockFactory:
root_ctx.shared_config.etcd,
lifetime=min(lifetime_hint * 2, lifetime_hint + 30),
)
case "etcetra":
from ai.backend.common.lock import EtcetraLock

return lambda lock_id, lifetime_hint: EtcetraLock(
str(lock_id),
root_ctx.shared_config.etcetra_etcd,
lifetime=min(lifetime_hint * 2, lifetime_hint + 30),
)
case other:
raise ValueError(f"Invalid lock backend: {other}")

Expand Down