#!/opt/ghc/bin/runhaskell

-- Copyright 2007 Adam Peacock adpeac@gmail.com

import Prelude hiding ( catch )
import Control.Exception
	( Exception )

import System.IO.Error 
	( illegalOperationErrorType
	, mkIOError
	, catch )
	
import System
	( ExitCode (..)
	, getArgs 
	, system )

import XenConfigParser
	( Host (..) 
	, maybeHostFromFile )

import Maybe
	( fromJust )

-- | Entry point. Each command-line argument is a host config file (the usage
-- message names the first one as dom0's). All files are parsed and validated
-- before any iptables command runs; then the current rules are saved, the
-- chains are flushed and rebuilt, and any 'IOError' on the way is handed to
-- 'rollBack'.
main :: IO ()
main = do
    configFiles <- getArgs
    verifyArgs configFiles
    parsed <- mapM maybeHostFromFile configFiles
    verifyHosts parsed
    -- verifyHosts has already rejected any Nothing, so fromJust cannot fail.
    catch (rebuildChains (map fromJust parsed)) rollBack
  where
    rebuildChains hosts = do
        iptables_save
        iptables_flush
        iptables_INPUT hosts
        iptables_FORWARD hosts
        iptables_OUTPUT

-- | Abort with a usage message unless at least one config file was given.
verifyArgs :: [String] -> IO ()
verifyArgs args
    | null args = error "Usage: ./createXenFirewallRules <dom0ConfigFile> [<VMConfigFile>...]"
    | otherwise = return ()

-- | Abort unless every config file was parsed successfully, i.e. every entry
-- produced by 'maybeHostFromFile' is a @Just@.
verifyHosts :: [Maybe Host] -> IO ()
verifyHosts maybeHosts
    | all parsedOk maybeHosts = return ()
    | otherwise = error "Non-valid host file(s). See README for an example."
  where
    parsedOk = maybe False (const True)

-- | Handler for an 'IOError' raised while the new rules are being applied.
-- It reports the error, then rolls the rules back to the saved ones by running
-- @/etc/init.d/iptables reload@.
--
-- The handler does not rethrow, so the programme finishes normally after a
-- rollback. A failing reload is reported instead of being silently ignored,
-- because it may leave the firewall partially configured.
rollBack :: IOError -> IO ()
rollBack e = do
    print e
    putStrLn "\nError in rules, rolling back\n"
    reloadResult <- system "/etc/init.d/iptables reload"
    case reloadResult of
        ExitSuccess -> return ()
        failure ->
            putStrLn $ "\nRollback failed (" ++ show failure
                ++ "): firewall rules may be in an inconsistent state.\n"

-- | Generate and apply the rules for the INPUT chain.
--
-- Exactly one host must use dom0's interface (@vif0.0@). That host's open
-- ports are accepted on eth0. If no host, or more than one, matches, this
-- raises 'dom0fileError' before any INPUT rule is added.
iptables_INPUT :: [Host] -> IO ()
iptables_INPUT hosts =
    case filter isDom0 hosts of
        [dom0] -> do
            applyingRules input
            iptables_tcpBadFlags input
            iptables_enableRelatedEstablished input
            iptables_enablePortsOnEth0 dom0
            iptables_enableEverthingForLO
            iptables_allowICMP input
            iptables_dropRestWithLog input
            successfullyApplied input
        _ -> ioError dom0fileError
  where
    -- the dom0 host is recognised by its interface name
    isDom0 h = vifname h == dom0Physdev
    dom0Physdev = "vif0.0"

-- | Generate and apply the rules for the FORWARD chain. The steps are:
-- drop TCP packets with bad flag combinations, accept related/established
-- traffic and everything the VMs send out, accept each host's configured
-- TCP/UDP ports (rate-limited), accept selected ICMP types, then log and
-- drop the rest.
iptables_FORWARD :: [Host] -> IO ()
iptables_FORWARD hosts = do
    applyingRules forward
    iptables_tcpBadFlags forward
    iptables_enableRelatedEstablished forward
    iptables_enableAllVMOut hosts
    -- the per-host results are all (), so mapM_ rather than mapM
    mapM_ (iptables_enableHostPorts forward) hosts
    iptables_allowICMP forward
    iptables_dropRestWithLog forward
    successfullyApplied forward

-- | Generate and apply the rules for the OUTPUT chain: every outgoing
-- packet is accepted.
iptables_OUTPUT :: IO ()
iptables_OUTPUT =
    applyingRules output
        >> iptables_enableEverything output
        >> successfullyApplied output

-- | Names of the built-in iptables chains this programme rebuilds.
input, forward, output :: String
(input, forward, output) = ("INPUT", "FORWARD", "OUTPUT")

-- | 'iptables_save' saves the active rules through the init script, so that
-- they can be reloaded on rollback. 'iptables_flush' runs @iptables -F@.
-- Both raise an 'IOError' if their command fails.
iptables_save, iptables_flush :: IO ()
iptables_save = do
    putStrLn "Saving current rules:\n "
    systemThrowIfError "/etc/init.d/iptables save"
iptables_flush = do
    putStrLn "\nFlushing chains:\n"
    systemThrowIfError "iptables -F"

-- | Append DROP rules to @chain@ for TCP packets with invalid flag
-- combinations. Each entry is "<mask> <set>" as given to @--tcp-flags@:
-- of the flags in <mask>, exactly those in <set> must be set to match.
iptables_tcpBadFlags :: String -> IO ()
iptables_tcpBadFlags chain =
    mapM_ (systemThrowIfError . dropRule) badFlagCombos
  where
    dropRule flags =
        iptables_A chain ("-p tcp -m tcp --tcp-flags " ++ flags ++ " -j DROP")
    badFlagCombos =
        [ "FIN,SYN,RST,PSH,ACK,URG FIN,SYN,RST,PSH,ACK,URG" -- every flag set
        , "FIN,SYN,RST,PSH,ACK,URG FIN,SYN,RST,ACK,URG"     -- every flag but PSH
        , "FIN,SYN,RST,PSH,ACK,URG FIN,PSH,URG"             -- only FIN, PSH, URG
        , "FIN,SYN,RST,PSH,ACK,URG NONE"                    -- no flag set
        , "SYN,RST SYN,RST"                                 -- SYN together with RST
        , "FIN,SYN FIN,SYN"                                 -- SYN together with FIN
        ]

-- | Append two rules to @chain@ in a single shell command. The first LOGs
-- every packet reaching this point, with the chain name as log prefix. The
-- second DROPs it. Because of @&&@, the DROP rule is only appended if
-- appending the LOG rule succeeded.
iptables_dropRestWithLog :: String -> IO ()
iptables_dropRestWithLog chain = systemThrowIfError command
  where
    logPart  = iptables_A chain ("-j LOG --log-prefix \"" ++ chain ++ " \"")
    dropPart = iptables_A chain "-j DROP"
    command  = logPart ++ " && " ++ dropPart

-- | Append an unconditional DROP rule to @chain@, without logging.
-- NOTE(review): not called anywhere in the visible source.
iptables_dropRest :: String -> IO ()
iptables_dropRest chain = systemThrowIfError $ iptables_A chain "-j DROP"

-- | Append rate-limited ACCEPT rules to @chain@ for these ICMP types:
--
--   *  0  Echo Reply
--   *  3  Destination Unreachable
--   *  8  Echo
--   * 11  Time Exceeded
iptables_allowICMP :: String -> IO ()
iptables_allowICMP chain =
    mapM_ (\icmpType -> systemThrowIfError (icmpRule icmpType)) allowedTypes
  where
    allowedTypes = [0, 3, 8, 11] :: [Integer]
    icmpRule t =
        iptables_A chain ("-p icmp -m icmp --icmp-type " ++ show t ++ limitAndACCEPTStr)

-- | On @chain@, accept (rate-limited) each of the host's TCP ports and then
-- each of its UDP ports for packets leaving through the host's device
-- (matched with @--physdev-out@).
iptables_enableHostPorts :: String -> Host -> IO ()
iptables_enableHostPorts chain (Host _ deviceName tcpPorts udpPorts) = do
    mapM_ (enableTCPPort chain deviceName) tcpPorts
    mapM_ (enableUDPPort chain deviceName) udpPorts

-- | Append an unconditional ACCEPT rule to @chain@.
iptables_enableEverything :: String -> IO ()
iptables_enableEverything chain = systemThrowIfError $ iptables_A chain "-j ACCEPT"

-- | Rate-limited ACCEPT for one TCP (resp. UDP) destination port. The
-- arguments are the chain, the physdev device name and the port.
enableTCPPort, enableUDPPort :: String -> String -> Int -> IO ()
enableTCPPort chain deviceName port = enablePort "tcp" chain deviceName port
enableUDPPort chain deviceName port = enablePort "udp" chain deviceName port

-- | Append a rate-limited ACCEPT rule to @chain@ for @protocol@ traffic to
-- destination @port@ leaving through bridge port @deviceName@.
enablePort :: String -> String -> String -> Int -> IO ()
enablePort protocol chain deviceName port =
    systemThrowIfError $ iptables_A chain $ concat
        [ "-p ", protocol
        , " --dport ", show port
        , " -m physdev --physdev-out ", deviceName
        , limitAndACCEPTStr
        ]

-- | Accept packets that belong to, or are related to, an existing
-- connection (state match RELATED,ESTABLISHED) on @chain@.
iptables_enableRelatedEstablished :: String -> IO ()
iptables_enableRelatedEstablished chain =
    systemThrowIfError . iptables_A chain $ "-m state --state RELATED,ESTABLISHED -j ACCEPT"

-- | On the INPUT chain, accept (rate-limited) the host's incoming TCP ports
-- and then its UDP ports when they arrive on eth0.
iptables_enablePortsOnEth0 :: Host -> IO ()
iptables_enablePortsOnEth0 host = do
    mapM_ (eth0Rule "tcp") (portsIncomingOpenTCP host)
    mapM_ (eth0Rule "udp") (portsIncomingOpenUDP host)
  where
    eth0Rule proto port = systemThrowIfError . iptables_A input $
        "-i eth0 -p " ++ proto ++ " --dport " ++ show port ++ limitAndACCEPTStr

-- | On the FORWARD chain, accept everything each host's VM sends out.
--
-- With iptables, think of yourself as sitting on the bridge (xenbr0).
-- Traffic a VM sends out is therefore coming *in* on its physdev, hence
-- @--physdev-in@.
iptables_enableAllVMOut :: [Host] -> IO ()
iptables_enableAllVMOut = mapM_ acceptFrom
  where
    acceptFrom h = systemThrowIfError $
        iptables_A forward ("-m physdev --physdev-in " ++ vifname h ++ " -j ACCEPT")

-- | Shell command that appends rule @acl@ to @chain@ (@iptables -A@).
iptables_A :: String -> String -> String
iptables_A chain acl = unwords ["iptables", "-A", chain, acl]

-- | On the INPUT chain, accept all traffic arriving on the loopback
-- interface. The rule carries an iptables comment so it can be recognised
-- in rule listings.
-- (The acl no longer starts with a space, so the command no longer contains
-- a doubled space; this matches every other rule.)
iptables_enableEverthingForLO :: IO ()
iptables_enableEverthingForLO =
    systemThrowIfError $ iptables_A input
        "-i lo -m comment --comment 'Accept everything on loop back (lo)' -j ACCEPT"

-- | Echo a command, run it with 'system', and raise 'iptablesError' if it
-- exits with anything other than 'ExitSuccess'.
--
-- NOTE(review): the command goes to the shell unquoted. Parts of it, such as
-- device names and ports, come from the host config files, so those files
-- must be trusted.
systemThrowIfError :: String -> IO ()
systemThrowIfError command = do
    putStrLn command
    exitCode <- system command
    case exitCode of
        ExitSuccess -> return ()
        failure     -> ioError (iptablesError failure)
	
-- | 'IOError' for a shell command that exited unsuccessfully. The message
-- includes the exit code.
iptablesError :: ExitCode -> IOError
iptablesError code = mkIOError illegalOperationErrorType message Nothing Nothing
  where
    message = show code ++ " - Error in iptables command"

-- | 'IOError' raised when no single dom0 host is among the parsed configs.
dom0fileError :: IOError
dom0fileError = mkIOError illegalOperationErrorType description Nothing Nothing
  where
    description = "dom0 host isn't provided"

-- | Rule suffix used by the rate-limited ACCEPT rules. It adds an iptables
-- @limit@ match at 'acceptRate' per second, then jumps to ACCEPT.
limitAndACCEPTStr :: String
limitAndACCEPTStr = concat [" -m limit --limit ", show acceptRate, "/second -j ACCEPT"]

-- | The number of connections you wish to accept per second. It is used as
-- the @--limit@ rate in 'limitAndACCEPTStr'.
acceptRate :: Int
acceptRate = 3

-- | Progress messages printed before and after a chain's rules are applied.
applyingRules, successfullyApplied :: String -> IO ()
applyingRules chain = putStrLn ("\nApplying " ++ chain ++ " chain rules:\n")
successfullyApplied chain = putStrLn ("\n" ++ chain ++ " chain rules successfully applied.")