Files
2025-04-03 18:28:09 +08:00

184 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import numpy as np
class State:
def __init__(self, state, directionFlag=None, parent=None, f=0):
"""八数码问题状态类
Args:
state: 3x3 numpy数组,表示当前状态
directionFlag: 禁止的移动方向(用于防止走回头路)
parent: 父状态节点
f: 评估函数值(f = g + h)
"""
self.state = state # 当前状态矩阵
self.direction = ['up', 'down', 'right', 'left'] # 可移动方向
if directionFlag: # 移除来源方向防止回退
self.direction.remove(directionFlag)
self.parent = parent # 父节点指针
self.f = f # 评估函数值
def getDirection(self):
"""获取当前允许的移动方向列表"""
return self.direction
def setF(self, f):
"""设置评估函数值"""
self.f = f
return
# 打印结果
def showInfo(self):
"""打印当前状态矩阵"""
for i in range(len(self.state)):
for j in range(len(self.state)):
print(self.state[i, j], end=' ')
print("\n")
print('->')
return
# 获取0点
def getZeroPos(self):
"""获取空白格(0)的位置坐标
Returns:
tuple: 包含空白格的行列坐标(x, y)
"""
postion = np.where(self.state == 0)
return postion
# 曼哈顿距离 f = g + h,g=1,如果用宽度优先的评估函数可以不调用该函数
def getFunctionValue(self):
"""计算曼哈顿距离评估值
Returns:
int: 当前状态到目标状态的评估值(f = 移动步数 + 启发函数值)
"""
cur_node = self.state.copy()
fin_node = self.answer.copy()
dist = 0 # 曼哈顿距离总和
N = len(cur_node) # 矩阵维度(3
# 遍历每个格子计算距离
for i in range(N):
for j in range(N):
if cur_node[i][j] != fin_node[i][j]:
index = np.argwhere(fin_node == cur_node[i][j]) # 找到对应数字在目标状态的位置
x = index[0][0] # 目标x坐标
y = index[0][1] # 目标y坐标
dist += (abs(x - i) + abs(y - j)) # 累加曼哈顿距离
return dist + 1 # 总距离+当前步数(g值固定为1
def nextStep(self):
"""生成所有可能的下一个合法状态
Returns:
State: 评估函数值最小的下一个状态
"""
if not self.direction: # 无可用移动方向
return []
subStates = [] # 子状态列表
boarder = len(self.state) - 1 # 矩阵边界索引
# 获取0点位置
x, y = self.getZeroPos()
x, y = x[0], y[0] # 转换为标量值
# 向左移动
if 'left' in self.direction and y > 0:
s = self.state.copy()
tmp = s[x, y - 1]
s[x, y - 1] = s[x, y]
s[x, y] = tmp
news = State(s, directionFlag='right', parent=self)
news.setF(news.getFunctionValue())
subStates.append(news)
# 向上移动
if 'up' in self.direction and x > 0:
s = self.state.copy()
tmp = s[x - 1, y]
s[x - 1, y] = s[x, y]
s[x, y] = tmp
news = State(s, directionFlag='down', parent=self)
news.setF(news.getFunctionValue())
subStates.append(news)
# 向下移动
if 'down' in self.direction and x < boarder:
s = self.state.copy()
tmp = s[x + 1, y]
s[x + 1, y] = s[x, y]
s[x, y] = tmp
news = State(s, directionFlag='up', parent=self)
news.setF(news.getFunctionValue())
subStates.append(news)
# 向右移动
if 'right' in self.direction and y < boarder:
s = self.state.copy()
tmp = s[x, y + 1]
s[x, y + 1] = s[x, y]
s[x, y] = tmp
news = State(s, directionFlag='left', parent=self)
news.setF(news.getFunctionValue())
subStates.append(news)
# 返回F值最小的下一个点
subStates.sort(key=compareNum)
return subStates[0]
# A* 迭代
def solve(self):
"""执行A*算法求解路径
Returns:
list: 从初始状态到目标状态的路径节点列表
"""
# openList(待访问节点列表)
openTable = []
# closeList(已访问节点列表)
closeTable = []
openTable.append(self) # 初始节点加入open表
while len(openTable) > 0: # 当存在待访问节点时
# 取出评估值最小的节点
n = openTable.pop(0)
# 加入已访问列表
closeTable.append(n)
# 生成子节点
subStates = n.nextStep()
path = []
# 判断是否到达目标状态
if (subStates.state == subStates.answer).all():
# 回溯路径
while subStates.parent and subStates.parent != originState:
path.append(subStates.parent)
subStates = subStates.parent
path.reverse() # 反转得到正序路径
return path
# 将子节点加入open表
openTable.append(subStates)
else:
return None, None
def compareNum(state):
"""状态比较函数(用于排序)"""
return state.f
if __name__ == '__main__':
# 初始化起始状态(0代表空格)
originState = State(np.array([[1, 5, 3], [2, 4, 6], [7, 0, 8]]))
# 设置类属性:目标状态
State.answer = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 0]])
# 创建初始状态对象
s1 = State(state=originState.state)
# 求解路径
path = s1.solve()
# 输出结果
if path:
print("解决方案路径:")
for node in path:
node.showInfo()
print("目标状态:")
print(State.answer)
print("总步数:%d" % len(path))