184 lines
6.1 KiB
Python
184 lines
6.1 KiB
Python
import numpy as np
|
||
|
||
class State:
|
||
def __init__(self, state, directionFlag=None, parent=None, f=0):
|
||
"""八数码问题状态类
|
||
Args:
|
||
state: 3x3 numpy数组,表示当前状态
|
||
directionFlag: 禁止的移动方向(用于防止走回头路)
|
||
parent: 父状态节点
|
||
f: 评估函数值(f = g + h)
|
||
"""
|
||
self.state = state # 当前状态矩阵
|
||
self.direction = ['up', 'down', 'right', 'left'] # 可移动方向
|
||
if directionFlag: # 移除来源方向防止回退
|
||
self.direction.remove(directionFlag)
|
||
self.parent = parent # 父节点指针
|
||
self.f = f # 评估函数值
|
||
|
||
def getDirection(self):
|
||
"""获取当前允许的移动方向列表"""
|
||
return self.direction
|
||
|
||
def setF(self, f):
|
||
"""设置评估函数值"""
|
||
self.f = f
|
||
return
|
||
|
||
# 打印结果
|
||
def showInfo(self):
|
||
"""打印当前状态矩阵"""
|
||
for i in range(len(self.state)):
|
||
for j in range(len(self.state)):
|
||
print(self.state[i, j], end=' ')
|
||
print("\n")
|
||
print('->')
|
||
return
|
||
|
||
# 获取0点
|
||
def getZeroPos(self):
|
||
"""获取空白格(0)的位置坐标
|
||
Returns:
|
||
tuple: 包含空白格的行列坐标(x, y)
|
||
"""
|
||
postion = np.where(self.state == 0)
|
||
return postion
|
||
|
||
# 曼哈顿距离 f = g + h,g=1,如果用宽度优先的评估函数可以不调用该函数
|
||
def getFunctionValue(self):
|
||
"""计算曼哈顿距离评估值
|
||
Returns:
|
||
int: 当前状态到目标状态的评估值(f = 移动步数 + 启发函数值)
|
||
"""
|
||
cur_node = self.state.copy()
|
||
fin_node = self.answer.copy()
|
||
dist = 0 # 曼哈顿距离总和
|
||
N = len(cur_node) # 矩阵维度(3)
|
||
|
||
# 遍历每个格子计算距离
|
||
for i in range(N):
|
||
for j in range(N):
|
||
if cur_node[i][j] != fin_node[i][j]:
|
||
index = np.argwhere(fin_node == cur_node[i][j]) # 找到对应数字在目标状态的位置
|
||
x = index[0][0] # 目标x坐标
|
||
y = index[0][1] # 目标y坐标
|
||
dist += (abs(x - i) + abs(y - j)) # 累加曼哈顿距离
|
||
return dist + 1 # 总距离+当前步数(g值固定为1)
|
||
|
||
def nextStep(self):
|
||
"""生成所有可能的下一个合法状态
|
||
Returns:
|
||
State: 评估函数值最小的下一个状态
|
||
"""
|
||
if not self.direction: # 无可用移动方向
|
||
return []
|
||
subStates = [] # 子状态列表
|
||
boarder = len(self.state) - 1 # 矩阵边界索引
|
||
|
||
# 获取0点位置
|
||
x, y = self.getZeroPos()
|
||
x, y = x[0], y[0] # 转换为标量值
|
||
|
||
# 向左移动
|
||
if 'left' in self.direction and y > 0:
|
||
s = self.state.copy()
|
||
tmp = s[x, y - 1]
|
||
s[x, y - 1] = s[x, y]
|
||
s[x, y] = tmp
|
||
news = State(s, directionFlag='right', parent=self)
|
||
news.setF(news.getFunctionValue())
|
||
subStates.append(news)
|
||
|
||
# 向上移动
|
||
if 'up' in self.direction and x > 0:
|
||
s = self.state.copy()
|
||
tmp = s[x - 1, y]
|
||
s[x - 1, y] = s[x, y]
|
||
s[x, y] = tmp
|
||
news = State(s, directionFlag='down', parent=self)
|
||
news.setF(news.getFunctionValue())
|
||
subStates.append(news)
|
||
|
||
# 向下移动
|
||
if 'down' in self.direction and x < boarder:
|
||
s = self.state.copy()
|
||
tmp = s[x + 1, y]
|
||
s[x + 1, y] = s[x, y]
|
||
s[x, y] = tmp
|
||
news = State(s, directionFlag='up', parent=self)
|
||
news.setF(news.getFunctionValue())
|
||
subStates.append(news)
|
||
|
||
# 向右移动
|
||
if 'right' in self.direction and y < boarder:
|
||
s = self.state.copy()
|
||
tmp = s[x, y + 1]
|
||
s[x, y + 1] = s[x, y]
|
||
s[x, y] = tmp
|
||
news = State(s, directionFlag='left', parent=self)
|
||
news.setF(news.getFunctionValue())
|
||
subStates.append(news)
|
||
|
||
# 返回F值最小的下一个点
|
||
subStates.sort(key=compareNum)
|
||
return subStates[0]
|
||
|
||
# A* 迭代
|
||
def solve(self):
|
||
"""执行A*算法求解路径
|
||
Returns:
|
||
list: 从初始状态到目标状态的路径节点列表
|
||
"""
|
||
# openList(待访问节点列表)
|
||
openTable = []
|
||
# closeList(已访问节点列表)
|
||
closeTable = []
|
||
openTable.append(self) # 初始节点加入open表
|
||
|
||
while len(openTable) > 0: # 当存在待访问节点时
|
||
# 取出评估值最小的节点
|
||
n = openTable.pop(0)
|
||
# 加入已访问列表
|
||
closeTable.append(n)
|
||
# 生成子节点
|
||
subStates = n.nextStep()
|
||
path = []
|
||
|
||
# 判断是否到达目标状态
|
||
if (subStates.state == subStates.answer).all():
|
||
# 回溯路径
|
||
while subStates.parent and subStates.parent != originState:
|
||
path.append(subStates.parent)
|
||
subStates = subStates.parent
|
||
path.reverse() # 反转得到正序路径
|
||
return path
|
||
|
||
# 将子节点加入open表
|
||
openTable.append(subStates)
|
||
else:
|
||
return None, None
|
||
|
||
def compareNum(state):
|
||
"""状态比较函数(用于排序)"""
|
||
return state.f
|
||
|
||
if __name__ == '__main__':
|
||
# 初始化起始状态(0代表空格)
|
||
originState = State(np.array([[1, 5, 3], [2, 4, 6], [7, 0, 8]]))
|
||
# 设置类属性:目标状态
|
||
State.answer = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 0]])
|
||
|
||
# 创建初始状态对象
|
||
s1 = State(state=originState.state)
|
||
# 求解路径
|
||
path = s1.solve()
|
||
|
||
# 输出结果
|
||
if path:
|
||
print("解决方案路径:")
|
||
for node in path:
|
||
node.showInfo()
|
||
print("目标状态:")
|
||
print(State.answer)
|
||
print("总步数:%d" % len(path))
|