import Data.List

-- | A pair of 'Int's. In this file it serves both as a @(row, column)@
-- position inside the grid and as a @(row delta, column delta)@ direction.
type Pr = (Int, Int)

-- | Safe two-dimensional lookup: the element at @(row, col)@ of a list of
-- rows, or 'Nothing' when either index is negative or past the end.
--
-- The column is bounds-checked against the row that is actually being
-- indexed, so a ragged grid (rows of different lengths) yields 'Nothing'
-- instead of crashing, and a lookup no longer measures the whole grid with
-- 'length' first.
--
-- NOTE(review): base >= 4.19 exports its own @(!?)@ from "Data.List". With
-- this file's unqualified @import Data.List@, uses of this operator become
-- ambiguous on those versions -- consider hiding it in the import.
(!?) :: [[a]] -> Pr -> Maybe a
grid !? (row, col) = nth row grid >>= nth col
  where
    -- Total replacement for '!!': 'Nothing' for any out-of-range index.
    nth :: Int -> [b] -> Maybe b
    nth i xs
      | i < 0     = Nothing
      | otherwise = case drop i xs of
          (x : _) -> Just x
          []      -> Nothing

-- | Scan a single row for occurrences of @ch@.
--
-- @j@ is the row number stamped on every hit and @i@ is the column number
-- assigned to the first character of the string. The hits are appended to
-- @acc@ in left-to-right order as @(j, column)@ pairs.
findChs1 :: Char -> String -> Int -> Int -> [Pr] -> [Pr]
findChs1 ch str j i acc =
  -- One pass and a single append, rather than re-walking the accumulator
  -- with @acc ++ [hit]@ once per match (which is quadratic in the hits).
  acc ++ [(j, col) | (col, c) <- zip [i ..] str, c == ch]

-- | Positions of every @ch@ in a list of rows, appended to @acc@.
--
-- @j@ is the number given to the first row; following rows count up from
-- there, and columns always start at 0. Results are in row-major order
-- (row by row, left to right within a row).
findChs :: Char -> [String] -> Int -> [Pr] -> [Pr]
findChs ch rows j acc =
  -- Gather the per-row hits once and append a single time; nesting
  -- @acc ++ rowHits@ for every row re-copies the accumulator each step.
  acc ++ concat [findChs1 ch row r 0 [] | (r, row) <- zip [j ..] rows]

-- | Look one step from @coo@ in direction @dir@.
--
-- Yields the singleton @[(coo, dir)]@ when that neighbouring cell holds
-- @ch@, and the empty list when it holds anything else or the lookup falls
-- outside the grid.
checkCh :: Char -> [String] -> Pr -> Pr -> [(Pr, Pr)]
checkCh ch grid coo@(row, col) dir@(dr, dc) =
  [(coo, dir) | grid !? target == Just ch]
  where
    target = (row + dr, col + dc)

-- | Diagonal directions from @coo@ whose adjacent cell is the letter M,
-- each paired with @coo@. Directions are probed in the order
-- @(-1,-1)@, @(-1,1)@, @(1,-1)@, @(1,1)@ and results keep that order.
findMCorners :: [String] -> Pr -> [(Pr, Pr)]
findMCorners grid coo = concatMap (checkCh 'M' grid coo) diagonals
  where
    diagonals = [(-1, -1), (-1, 1), (1, -1), (1, 1)]

-- | All eight directions from @coo@ whose adjacent cell is the letter M,
-- each paired with @coo@. The diagonal hits (via 'findMCorners') come
-- first, followed by the straight ones in the order
-- @(-1,0)@, @(0,-1)@, @(0,1)@, @(1,0)@.
findMNeighs :: [String] -> Pr -> [(Pr, Pr)]
findMNeighs grid coo = findMCorners grid coo ++ concatMap probe straights
  where
    probe     = checkCh 'M' grid coo
    straights = [(-1, 0), (0, -1), (0, 1), (1, 0)]

-- | Given a start position and a direction, check that the cell two steps
-- along that direction is @A@ and the cell three steps along is @S@.
--
-- The start cell and the cell one step away are not inspected here; in
-- 'part1' the pairs come from 'findMNeighs' applied to the X positions.
testXmas :: [String] -> (Pr, Pr) -> Bool
testXmas grid ((row, col), (dr, dc)) = holds 2 'A' && holds 3 'S'
  where
    -- Is the cell @steps@ moves away from the start equal to @letter@?
    holds steps letter =
      grid !? (row + steps * dr, col + steps * dc) == Just letter

-- | Given a centre position and a direction, check that the cell one step
-- in the /opposite/ direction is @S@. The centre and the cell in the
-- forward direction are not inspected here.
testMas :: [String] -> (Pr, Pr) -> Bool
testMas grid ((row, col), (dr, dc)) =
  case grid !? (row - dr, col - dc) of
    Just 'S' -> True
    _        -> False

-- | First component of a pair whose two components share a type.
car :: (a, a) -> a
car = fst

-- | Collect the elements of the second list that occur again later in that
-- same list, and put them in front of @acc@.
--
-- An element that appears @k@ times is collected @k - 1@ times. The
-- collected elements end up in reverse order of discovery, ahead of
-- whatever @acc@ already held.
findReps :: [Pr] -> [Pr] -> [Pr]
findReps acc xs = reverse repeated ++ acc
  where
    -- Each element paired with the remainder of the list after it; the
    -- final empty suffix from 'tails' fails the pattern and is skipped.
    repeated = [h | (h : t) <- tails xs, h `elem` t]

-- | Count the (start, direction) pairs that begin at an X, have an M as
-- the next cell in that direction (any of the eight), and pass 'testXmas'.
part1 :: [String] -> Int
part1 grid =
  length
    [ hit
    | start <- findChs 'X' grid 0 []
    , hit <- findMNeighs grid start
    , testXmas grid hit
    ]

-- | Count, via 'findReps', the A positions that are reported more than
-- once as the centre of a diagonal with an M on one side and an S on the
-- opposite side.
part2 :: [String] -> Int
part2 grid = length (findReps [] centres)
  where
    -- One entry per diagonal through an A that passes 'testMas'; only the
    -- A's position is kept.
    centres =
      [ car hit
      | centre <- findChs 'A' grid 0 []
      , hit <- findMCorners grid centre
      , testMas grid hit
      ]

-- | Read the puzzle grid from @day04.input@ in the working directory,
-- split it into lines, and print the results of both parts.
main :: IO ()
main = do
  lns <- lines <$> readFile "day04.input"
  putStrLn $ "Part 1: " ++ show (part1 lns)
  putStrLn $ "Part 2: " ++ show (part2 lns)