diff --git a/CHANGELOG.md b/CHANGELOG.md index 0289a144e..4a93c7eb9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +-1.0.7 (September 2023): +- CPU and memory profilers thanks to @danieltherealyang +- Check dns queries and answers for whitelisted IPs and domains +- Add AID flow hash to all conn.log flows, which is a combination of community_id and the flow's timestamp +- Sqlite database improvements and better error handling +- Add support for exporting Slips alerts to a sqlite db + + -1.0.6 (June 2023): - Store flows in SQLite database in the output directory instead of redis. - 55% RAM usage decrease. diff --git a/README.md b/README.md index d788ffa6e..cad9b0f14 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@

-Slips v1.0.6 +Slips v1.0.7

[Documentation](https://stratospherelinuxips.readthedocs.io/en/develop/) — [Features](https://stratospherelinuxips.readthedocs.io/en/develop/features.html) — [Installation](#installation) — [Authors](#people-involved) — [Contributions](#contribute-to-slips) @@ -32,7 +32,7 @@ Slips v1.0.6 # Slips: Behavioral Machine Learning-Based Intrusion Prevention System -Slips is a behavioral intrusion prevention system that uses machine learning to detect malicious behaviors in network traffic. Slips focus on targeted attacks, detection of command and control channels, and providing a good visualization for the analyst. It can analyze network traffic in real-time, network captures such as pcap files, and network flows produced by Suricata, Zeek/Bro, and Argus. Slips processes the input data, analyzes it, and highlights suspicious behavior that needs the analyst's attention. +Slips is a powerful endpoint behavioral intrusion prevention and detection system that uses machine learning to detect malicious behaviors in network traffic. Slips can work with network traffic in real-time, pcap files, and network flows from popular tools like Suricata, Zeek/Bro, and Argus. Slips threat detection is based on a combination of machine learning models trained to detect malicious behaviors, 40+ threat intelligence feeds and expert heuristics. Slips gathers evidence of malicious behavior and uses extensively trained thresholds to trigger alerts when enough evidence is accumulated. diff --git a/VERSION b/VERSION index ece61c601..f9cbc01ad 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.0.6 \ No newline at end of file +1.0.7 \ No newline at end of file diff --git a/config/slips.conf b/config/slips.conf index a5ef14ee1..61c94988c 100644 --- a/config/slips.conf +++ b/config/slips.conf @@ -164,8 +164,9 @@ popup_alerts = no # [3] Generic Confs for the modules or to process the modules [modules] # List of modules to ignore. By default we always ignore the template! do not remove it from the list -disable = [template, ensembling] -# Names of other modules that you can disable: ensembling, threatintelligence, blocking, +disable = [template, ensembling, rnnccdetection] +# Names of other modules that you can disable (they all should be lowercase with no special characters): +# ensembling, threatintelligence, blocking, # networkdiscovery, timeline, virustotal, rnnccdetection, flowmldetection, updatemanager # For each line in timeline file there is a timestamp. @@ -376,7 +377,45 @@ UID = 0 GID = 0 #################### -# [11] enable or disable p2p for slips + +[Profiling] + +# [11] CPU profiling + +# enable cpu profiling [yes,no] +cpu_profiler_enable = no + +# Available options are [dev,live] +# dev for deterministic profiling. this will give precise information about the CPU usage +# throughout the program runtime. This module cannot give live updates +# live mode is for sampling data stream. To track the function stack in real time. it is accessible from web interface +cpu_profiler_mode = dev + +# profile all subprocesses in dev mode [yes,no]. 
+cpu_profiler_multiprocess = yes + +# set number of tracer entries (dev mode only) +cpu_profiler_dev_mode_entries = 1000000 + +# set maximum output lines (live mode only) +cpu_profiler_output_limit = 20 + +# set the wait time between sampling sequences in seconds (live mode only) +cpu_profiler_sampling_interval = 20 + +# [12] Memory Profiling + +# enable memory profiling [yes,no] +memory_profiler_enable = no + +# set profiling mode [dev,live] +memory_profiler_mode = live + +# profile all subprocesses [yes,no] +memory_profiler_multiprocess = yes + +#################### +# [13] enable or disable p2p for slips [P2P] # create p2p.log with additional info about peer communications? yes or no diff --git a/conftest.py b/conftest.py index 2982ca2be..dca27f257 100644 --- a/conftest.py +++ b/conftest.py @@ -19,7 +19,7 @@ @pytest.fixture -def mock_db(): +def mock_rdb(): # Create a mock version of the database object with patch('slips_files.core.database.database_manager.DBManager') as mock: yield mock.return_value diff --git a/docker/macosm1-image/requirements-macos-m1-docker.txt b/docker/macosm1-image/requirements-macos-m1-docker.txt index 03bce0a46..be68e41d3 100644 --- a/docker/macosm1-image/requirements-macos-m1-docker.txt +++ b/docker/macosm1-image/requirements-macos-m1-docker.txt @@ -36,4 +36,8 @@ wheel flask tld tqdm -termcolor \ No newline at end of file +communityid +termcolor +memray +viztracer +yappi \ No newline at end of file diff --git a/docs/features.md b/docs/features.md index c790b1348..f861d3455 100644 --- a/docs/features.md +++ b/docs/features.md @@ -1017,8 +1017,47 @@ and ICMP-AddressMaskScan based on the icmp type We detect a scan every threshold. So we generate an evidence when there is 5,10,15, .. etc. ICMP established connections to different IPs. +### CPU Profiling +Slips is shipped with its own tool for CPU profiling, which can be found in ```slips_files/common/cpu_profiler.py``` +CPU profiling supports two modes: live and development (dev). + +#### Live mode: +The main purpose of this mode is to show live CPU stats in the web interface. +"live" mode publishes updates during the runtime of the program to the redis channel 'cpu_profile' so that the web interface can use them. + +#### Development mode: + + Setting the mode to "dev" outputs a JSON file of the CPU usage at the end of the program run. + It is recommended to only use dev mode for static file inputs (pcaps, suricata files, binetflows, etc.) instead of interfaces and growing Zeek directories, because longer runs result in profiling data loss and not everything will get recorded. +The JSON file created in this mode is placed in the output dir of the current run and can be viewed by running the following command + +```vizviewer results.json``` + +then going to http://127.0.0.1:9001/ in your browser to see the visualizations of the CPU usage. + + +Options to enable CPU profiling can be found under the [Profiling] section of the ```slips.conf``` file. +```cpu_profiler_enable``` set to "yes" enables CPU profiling; set to "no" it disables it. +```cpu_profiler_mode``` can be set to "live" or "dev": "dev" profiles the whole run and writes the results at the end, while "live" publishes periodic updates to the redis channel. +```cpu_profiler_multiprocess``` can be set to "yes" or "no" and only affects the dev mode profiling. If set to "yes" then all processes will be profiled. If set to "no" then only the main process (slips.py) will be profiled. +```cpu_profiler_output_limit``` is set to an integer value and only affects the live mode profiling. This option sets the limit on the number of lines output in each live mode profiling update.
+```cpu_profiler_sampling_interval``` is set to an integer value and only affects the live mode profiling. This option sets the duration in seconds of the live mode sampling intervals. It is recommended to set this option to more than 10 seconds; otherwise, not much useful information will be captured during sampling. + +### Memory Profiling +Memory profiling can be found in ```slips_files/common/memory_profiler.py``` + +Just like CPU profiling, it supports live and development modes. +Set ```memory_profiler_enable``` to ```yes``` to enable this feature. +Set ```memory_profiler_mode``` to ```live``` to use live mode or ```dev``` to use development mode profiling. + +#### Live Mode +This mode shows memory usage stats during the runtime of the program. +```memory_profiler_multiprocess``` controls whether live mode tracks all processes or only the main process. If set to no, the program will wait for you to connect from a different terminal using the command ```memray live ```, where port_number is 5000 by default. After connection, the program will continue with its run and the connected terminal will receive a feed of the memory statistics. If set to yes, the redis channel "memory_profile" can be used to set the PID of the process to be tracked (see the sketch at the end of this section). Only a single process can be tracked at a time. The interface is cumbersome to use from the command line, so multiprocess live profiling is intended to be used primarily from the web interface. + +#### Development Mode +When enabled, the profiler will write the profiling data to the ```memoryprofile``` directory inside the output directory of the run. Each process during the run of the program will have an associated binary file. Each of the generated binaries will automatically be converted to viewable HTML files, with each process converted to a flamegraph and a table format. All generated files are denoted by their PID.
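+
+For reference, switching which process the live memory profiler tracks is just a message published to that channel. Below is a minimal redis-py sketch; the channel name comes from this section, while the connection parameters and the PID are placeholder assumptions, not values taken from Slips.
+
+```python
+import redis
+
+# Assumed local Redis instance; adjust host/port/db to match your Slips run.
+r = redis.Redis(host='localhost', port=6379, db=0)
+
+# Ask the live memory profiler to start tracking the process with this PID
+# (example value). Only one process is tracked at a time, so publishing a new
+# PID also detaches the profiler from the previously tracked process.
+r.publish('memory_profile', 12345)
+```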
--- diff --git a/docs/images/alternate_mem_profiler_testing.png b/docs/images/alternate_mem_profiler_testing.png new file mode 100644 index 000000000..037133834 Binary files /dev/null and b/docs/images/alternate_mem_profiler_testing.png differ diff --git a/docs/images/cpu-profiler-config.png b/docs/images/cpu-profiler-config.png new file mode 100644 index 000000000..136a0f0db Binary files /dev/null and b/docs/images/cpu-profiler-config.png differ diff --git a/docs/images/cpu-profiler-live-mode.png b/docs/images/cpu-profiler-live-mode.png new file mode 100644 index 000000000..5edddba46 Binary files /dev/null and b/docs/images/cpu-profiler-live-mode.png differ diff --git a/docs/images/cpu-profiler-live-results.png b/docs/images/cpu-profiler-live-results.png new file mode 100644 index 000000000..386493df3 Binary files /dev/null and b/docs/images/cpu-profiler-live-results.png differ diff --git a/docs/images/cpu-profiler-starting.png b/docs/images/cpu-profiler-starting.png new file mode 100644 index 000000000..db3c1b88e Binary files /dev/null and b/docs/images/cpu-profiler-starting.png differ diff --git a/docs/images/cpu-profiler-termination.png b/docs/images/cpu-profiler-termination.png new file mode 100644 index 000000000..3de6fe240 Binary files /dev/null and b/docs/images/cpu-profiler-termination.png differ diff --git a/docs/images/cpu=profiler-results.png b/docs/images/cpu=profiler-results.png new file mode 100644 index 000000000..e3068967e Binary files /dev/null and b/docs/images/cpu=profiler-results.png differ diff --git a/docs/images/flamegraph.png b/docs/images/flamegraph.png new file mode 100644 index 000000000..afef2512c Binary files /dev/null and b/docs/images/flamegraph.png differ diff --git a/docs/images/live-mem-profiler.png b/docs/images/live-mem-profiler.png new file mode 100644 index 000000000..f9300d87d Binary files /dev/null and b/docs/images/live-mem-profiler.png differ diff --git a/docs/images/mem-profiler-ending.png b/docs/images/mem-profiler-ending.png new file mode 100644 index 000000000..9f281e2ab Binary files /dev/null and b/docs/images/mem-profiler-ending.png differ diff --git a/docs/images/mem-profiler-running.png b/docs/images/mem-profiler-running.png new file mode 100644 index 000000000..d59319652 Binary files /dev/null and b/docs/images/mem-profiler-running.png differ diff --git a/docs/images/mem-profiler-starting.png b/docs/images/mem-profiler-starting.png new file mode 100644 index 000000000..6f364e1e6 Binary files /dev/null and b/docs/images/mem-profiler-starting.png differ diff --git a/docs/images/mem-profiler-table-view.png b/docs/images/mem-profiler-table-view.png new file mode 100644 index 000000000..7faf4fd10 Binary files /dev/null and b/docs/images/mem-profiler-table-view.png differ diff --git a/docs/images/memory_multi_process_profiler_structure.png b/docs/images/memory_multi_process_profiler_structure.png new file mode 100644 index 000000000..3f5bafc91 Binary files /dev/null and b/docs/images/memory_multi_process_profiler_structure.png differ diff --git a/docs/images/memory_profiler_interface.png b/docs/images/memory_profiler_interface.png new file mode 100644 index 000000000..e3dcd1d0e Binary files /dev/null and b/docs/images/memory_profiler_interface.png differ diff --git a/docs/images/memory_profiler_structure.png b/docs/images/memory_profiler_structure.png new file mode 100644 index 000000000..8d08a06b0 Binary files /dev/null and b/docs/images/memory_profiler_structure.png differ diff --git a/docs/images/running-vizviewer.png 
b/docs/images/running-vizviewer.png new file mode 100644 index 000000000..0b7b4e51e Binary files /dev/null and b/docs/images/running-vizviewer.png differ diff --git a/docs/images/slips.gif b/docs/images/slips.gif index 505b63c03..4d13821c7 100644 Binary files a/docs/images/slips.gif and b/docs/images/slips.gif differ diff --git a/docs/images/testing_live_cpu_profiler.png b/docs/images/testing_live_cpu_profiler.png new file mode 100644 index 000000000..7ec3788fc Binary files /dev/null and b/docs/images/testing_live_cpu_profiler.png differ diff --git a/docs/images/testinig_mem_profiler_live_mode.png b/docs/images/testinig_mem_profiler_live_mode.png new file mode 100644 index 000000000..834f8cda0 Binary files /dev/null and b/docs/images/testinig_mem_profiler_live_mode.png differ diff --git a/docs/images/testinig_mem_profiler_live_mode_step3_1.png b/docs/images/testinig_mem_profiler_live_mode_step3_1.png new file mode 100644 index 000000000..f0ec497f9 Binary files /dev/null and b/docs/images/testinig_mem_profiler_live_mode_step3_1.png differ diff --git a/docs/images/testinig_mem_profiler_live_mode_step3_2.png b/docs/images/testinig_mem_profiler_live_mode_step3_2.png new file mode 100644 index 000000000..fc108e282 Binary files /dev/null and b/docs/images/testinig_mem_profiler_live_mode_step3_2.png differ diff --git a/docs/profiling_slips.md b/docs/profiling_slips.md new file mode 100644 index 000000000..84b448975 --- /dev/null +++ b/docs/profiling_slips.md @@ -0,0 +1,473 @@ +# CPU and Memory Profiling +GSoC 2023 SLIPS Final Report + +### Author +Daniel Yang (danielyangkang@gmail.com) +### Github +https://github.com/danieltherealyang/StratosphereLinuxIPS +### Repo Link +CPU Profiling: +https://github.com/stratosphereips/StratosphereLinuxIPS/pull/362 +Memory Profiling: +https://github.com/stratosphereips/StratosphereLinuxIPS/pull/388 + +## Overview +### Goal +This project aims to provide a robust and easy to use profiling framework for slips. The features created should have a simple interface and minimal involvement needed to get usable interactive data to find areas to improve the program. + +### Current State +The following features are provided for CPU and memory profiling. +Both will have live profiling mode and a development (dev) profiling mode. Live mode allows for real time data passing, primarily done through a redis channel. Dev mode aggregates the data for a single run and outputs profiling results. +CPU dev mode integrates viztracer to collect profiling data and create a visual output for the data. +CPU live mode uses yappi to collect data over a sampling interval of 20 seconds periodically outputting the data into the “cpu_profile” redis channel. +Memory dev and live mode both integrate memray to collect data. Dev mode collects all process data as bin files and outputs as html files. Each process’ data is in a separate file because there is currently no viable way, supported or unsupported by memray, to combine separate process data together. +Memory live mode is intended to be integrated into the web interface. Since only one process can be profiled at a time with live mode, the multiprocess.Process class is extended with additional functionality when the feature is enabled. The web interface should send the PID into the redis channel “memory_profile” to switch which process to profile. + +***Note***: Not recommended to run both profilers at the same time as of now. It is not tested thoroughly and could also cause issues with speed. + +### What’s left to do +1. 
The CPU and Memory live modes need to be integrated into the web interface so that there can be a panel that polls the respective redis channels and displays accordingly. The ideal way is for the CPU live profiler to have its own panel either just displaying the text in the redis channel or the table can get parsed into proper UI elements. The Memory live profiler should have a panel displaying a web terminal, where the output of the terminal command should be displayed in the browser. As proof of concept, think of VSCode’s integrated terminal using xterm.js. There also needs to be a way to see PIDs of currently running processes and switching which process is currently profiled. +2. The memory profiler dev mode outputs data in a really cumbersome way, putting each process into a separate file. Having tried to work around this myself, I find that there is currently no good way to aggregate separate files together and compare interprocess data. To fix this, we might need to consider contacting the developers of memray, the tool which was integrated into the memory profiling feature or create a private fork of the tool for customization to the needs of slips. +3. The live memory profiler is currently only one way communication, where the pid to get profiled gets sent into a redis channel for IPC and the profiler class manages everything internally. Since all of this is to get around the fact that memray does not have multiprocess live profiling built in, the web interface is going to need a way to receive updates on the status of the profiler so that it knows when to reconnect the web terminal client using “memray live ”. + + +## Usage +### CPU Profiler Dev Mode +#### Step 1 +To start with, go to slips.conf and make sure the feature is enabled. + + +and set number of tracer entries (dev mode only) to cpu_profiler_dev_mode_entries = 1000000 + +The first two settings cpu_profiler_enable and cpu_profiler_mode should be self explanatory. The cpu_profiler_multiprocess setting decides whether the profiler tracks all processes or only one. +If this setting is set to: “no” then only the main process that initiates slips gets tracked. + +The setting ```cpu_profiler_dev_mode_entries``` sets the size of the circular buffer for the profiler. If the runtime of the profiler is too long and the data at the start of the run gets lost due to the circular buffer filling up, increase this number. + + +### Step 2 + +Next, just run slips as normal. For an example command, run: +``` ./slips.py -e 1 -f dataset/test12-icmp-portscan.pcap ``` + +The program should run pretty similarly to how it normally runs with some small differences. At the beginning of the run, you should see + + + +this means that the profiler starts a recursive process to run the same command used to run slips but with the tracer context active. This is so that all the data can be captured. + +Then, at the end of the run, you should see: + +```Loading trace data from processes 0/17``` + +After this completes, you should see this at the bottom of the terminal + + + +### Step 3 + +Now we need to view the recorded data. +Just run the command at the bottom + +```vizviewer .json``` + + + +### Step 4 + +From there, go to your browser and open up the path http://localhost:9001 + + + +Use WASD to zoom in and move left/right. + + +### CPU Profiler Live Mode +#### Step 1 +Go to slips.conf and make sure the settings are set correctly. 
+```cpu_profiler_enable = yes``` +```cpu_profiler_mode = live``` + +You can also set the maximum output lines (live mode only) to adjust profiler behavior. +```cpu_profiler_output_limit = 20``` + +and set the wait time between sampling sequences in seconds (live mode only) +```cpu_profiler_sampling_interval = 20``` + +The ```cpu_profiler_output_limit``` sets the number of lines for the updates, and ```cpu_profiler_sampling_interval``` sets the number of seconds between profiler updates to the cpu_profile redis channel. + +#### Step 2 +Run slips.py as normal. The output shouldn’t be any different from when this feature is disabled. The only difference should be + + + +the **CPU Profiler Started** line in the output. + +If we print out the data getting sent to the “cpu_profile” redis channel, it should be in this format: + + + +### Memory Profiler Dev Mode +#### Step 1 + +The memory profiler settings are much simpler. +In slips.conf, first enable memory profiling +```memory_profiler_enable = yes``` + + +and set the profiling mode: + +```memory_profiler_mode = dev``` + +now, profile all subprocesses +```memory_profiler_multiprocess = yes``` + +#### Step 2 + +Now just run slips.py. You should see this at the start + + +After the run is done, you should see + + + +When all of the HTML files are finished generating, you should see + + + +#### Step 3 +Now navigate to the output directory under your run. There are two directories, ```flamegraph/``` and ```table/```, where the HTML files are located. +Under **flamegraph**, opening the files should result in something like this + + + +At the top, there is a graph showing the allocated memory size over time. The slider can be used to adjust the time frame. The graph shows memory that was allocated after tracking started and not deallocated by the time tracking ended. Below that, there is a stack trace where the x-axis is the proportion of memory allocated to each function. + +**Notes for reading the flamegraph**: the snapshots generated by memray flamegraph normally show you only a single point in time. +By default, that’s the point when the process’s heap memory usage was highest. + +Under the table directory, the files are much simpler. They just show a table of the memory allocations. + + + + + +### Memory Profiler Live Mode +#### Step 1 + +Go to ```slips.conf``` and use the following settings: + +``` +memory_profiler_enable = yes +memory_profiler_mode = live +memory_profiler_multiprocess = yes +``` + +#### Step 2 +Running slips.py should result in the following output + + + +The output is almost the same as normal except for the “Memory Profiler Started” message and the “Child process started - PID: ”. The child process messages are an indication that the Process class was successfully patched and the signal listeners are functioning correctly. + +#### Step 3 +Now, in order to interact with the live profiler, send the PID into the “memory_profile” redis channel. After the message is sent, run “memray live 1234” to receive the profiling data. + +--- + +### Testing + +For dev mode, it is pretty simple to see if everything is working correctly: just run Slips and check that the output matches the examples given above. This section is mainly about testing live mode so that the state of the profilers can be verified over the runtime of the program. There is no formal testing framework since the state of the profiler is different across runs. + +#### Testing CPU Profiler Live Mode + +##### Step 1 +Make sure the feature is enabled and the memory profiler is disabled.
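+
+Before editing any code, a quick sanity check is to subscribe to the “cpu_profile” channel from a separate terminal and watch for updates; the print-based check in Step 2 achieves the same thing from inside the profiler. Below is a minimal redis-py sketch; the channel name comes from this feature, while the connection parameters are placeholder assumptions.
+
+```python
+import redis
+
+# Assumed local Redis instance; adjust host/port/db to match your Slips run.
+r = redis.Redis(host='localhost', port=6379, db=0)
+sub = r.pubsub()
+sub.subscribe('cpu_profile')
+
+# Print every profiler update as it arrives
+# (roughly one per cpu_profiler_sampling_interval seconds).
+for msg in sub.listen():
+    if msg['type'] == 'message':
+        print(msg['data'].decode())
+```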
+ +##### Step 2 +In cpu_profiler.py, add the print(stringio.getvalue()) in the following function. + +```python +def _sampling_loop(self): + stringio = io.StringIO() + while self.is_running: + # code omitted . . . + self.print() #prints to stringio, not stdout + print(stringio.getvalue()) #Add this line before self.db.publish() + self.db.publish('cpu_profile', stringio.getvalue()) +``` + + +##### Step 3 +Verify that you get profiler updates at regular intervals mixed into your regular terminal output. + + + + +#### Testing Memory Profiler Live Mode +##### Step 1 +Enable live memory profiling. + +##### Step 2 +Make sure the last two lines shown below are uncommented in the LiveMultiprocessProfiler class located in memory_profiler.py. + + + +##### Step 3 +Run slips.py like normal. +You should see some red text, indicating that the signal has been received. The text is output by the MultiprocessPatchMeta class located in memory_profiler.py. The red print statements can be used to debug the signal processing steps at each point in time. + + + +After a few seconds, the profiling data should pop up. + + + +#### Alternate Testing Memory Profiler Live Mode +##### Step 1 +As always, make sure the live memory profiler is enabled in slips.conf. + +##### Step 2 +In slips.py, uncomment the following line in start() for the Main class. +``` +self.cpu_profiler_init() +self.memory_profiler_init() +# uncomment line to see that memory profiler works correctly +# Should print out red text if working properly +self.memory_profiler_multiproc_test() # <- uncomment this line +``` + +If you want to see the profiling data and not just the red text indicating that IPC signals were received, uncomment the following line in memory_profiler_multiproc_test() also in the Main class. +``` +# Message passing +self.db.publish("memory_profile", processes[1].pid) # successful +subprocess.Popen(["memray", "live", "1234"]) # <- uncomment this line +time.sleep(5) # target_function will timeout and tracker will be cleared +self.db.publish("memory_profile", processes[0].pid) # end but maybe don't start +time.sleep(5) # mem_function will get tracker started +self.db.publish("memory_profile", processes[0].pid) # start successfully +input() +``` + +##### Step 3 +Just run slips.py and check if output is similar to below. + + + +After the test finishes running, it will not terminate and just wait for input. Press [Ctrl-C] to exit and [Enter] to continue with the run. + +### Implementation Specification +#### Profiler Interface + + + + +The ProfilerInterface is an abstract interface that defines all the methods that a profiler class should have. The methods defined are: _create_profiler, start, stop, and print. + +The CPUProfiler and MemoryProfiler classes are meant to be called by the main slips.py. Since the behaviors are different depending on the settings in slips.conf, these top level profiling classes act as Factory Method class, which takes on the identity of either DevProfiler or LiveProfiler depending on the class constructor parameters. +CPUProfiler Class +The CPUProfiler class is a profiler utility that allows you to profile CPU usage in different modes and output formats. It utilizes the ProfilerInterface, DevProfiler, and LiveProfiler classes to manage profiling behavior. + +##### Constructor +```python +def __init__(self, db, output, mode="dev", limit=20, interval=20): +``` + +db: A database or communication object used for publishing profiling results. This must be a redis db object. 
+output: The directory path where profiling results will be saved or communicated. +mode: The profiling mode. It can be "dev" (development) or "live" (live). +limit: The maximum number of function stats to display. +interval: The time interval (in seconds) between each sampling loop iteration. + +##### Methods +**_create_profiler(self)** +Creates and returns an instance of the appropriate profiler library based on the selected mode. Returns a profiler instance. + +**start(self)** +Starts the CPU profiling + +**stop(self)** +Stops the CPU profiling. + +**print(self)** +Prints the CPU profiling results. + +#### MemoryProfiler Class +The MemoryProfiler class is a profiler utility designed to profile memory usage in different modes and output formats. It utilizes the ProfilerInterface, DevProfiler, and LiveProfiler classes to manage memory profiling behavior. + +##### Class Attributes + +**profiler** +A class-level attribute representing the current profiler instance. This attribute can be an instance of DevProfiler or LiveProfiler depending on the mode chosen. + +##### Constructor + +```python +def __init__(self, output, db=None, mode="dev", multiprocess=True): +``` + +* output: The directory path where memory profiling results will be saved or communicated. +* db: A database or communication object used for publishing memory profiling results (only applicable in live mode). +* mode: The profiling mode. It can be "dev" (development) or "live" (live). +* multiprocess: A boolean indicating whether the profiler should support multi process memory profiling. + +##### Methods + +**_create_profiler(self)** +Creates and initializes an instance of the appropriate memory profiler based on the selected mode. This method is intended for internal use. + +**start(self)** +Starts the memory profiling. Depending on the mode, this method may initiate memory tracking. + +**stop(self)** +Stops the memory profiling. Depending on the mode, this method may stop memory tracking. + +**print(self)** +This method is currently a placeholder and doesn't implement memory profiling output. It can be overridden in the subclasses for specific functionality. + +#### Memory Profiler Structure + + + +The LiveProfiler class uses the Factory Method to change its behavior based on whether multiprocess profiling is desired. The control logic for this design pattern is almost identical to the MemoryProfiler and CPUProfiler classes. + +#### Memory Multi Process Profiler Structure + + + + +#### LiveMultiprocessProfiler Class + +The LiveMultiprocessProfiler class is a specialized profiler utility designed for memory profiling in a multiprocess environment. It leverages multiprocessing and threading to manage memory profiling for multiple processes simultaneously. +##### Class Attributes + +**original_process_class** +A class-level attribute that stores the original multiprocessing.Process class. It is restored once the profiler is stopped. + +**signal_handler_thread** +A thread that continuously checks a Redis database for signals to start and stop memory profiling for specific processes. + +**db** +A reference to a redis database object used for sending and receiving signals related to memory profiling. + +##### Constructor +```python +def __init__(self, db=None): +``` +db: A database or communication object used for sending and receiving memory profiling signals. + +##### Methods +**_create_profiler(self)** +This method is currently a placeholder and doesn't implement memory profiling. 
It can be overridden in subclasses for specific memory profiling implementations. + +**_handle_signal(self)** +Function that runs in signal_handler_thread. A continuous loop that checks a Redis channel for signals indicating which processes to start and stop memory profiling for. It manages the memory profiling for different processes based on the received signals. + +**_test_thread(self)** +This method is currently a placeholder for testing purposes. It sends test signals to start memory profiling for processes. + +**start(self)** +Overrides the start method of ProfilerInterface. It prepares the profiler to handle memory profiling in a multiprocess environment. + +**stop(self)** +Overrides the stop method of ProfilerInterface. It restores the original multiprocessing.Process class and cleans up resources. + +**print(self)** +This method is currently a placeholder and doesn't implement memory profiling output. It can be overridden in subclasses for specific output functionality. + +#### MultiprocessPatchMeta Class +A metaclass that provides patches and enhancements for the multiprocessing.Process class to support memory profiling in a multiprocess environment. + +##### Class attributes +**tracker** +Stores the memray Tracker object. Must be set to None when tracking is not active and set to memray.Tracker(destination=) when active. tracker.__enter__() starts tracking by entering the tracker object context. To stop tracking, tracker.__exit__(None, None, None) must be run. + +**signal_interval** +Sets the interval in seconds for the start_signal_thread and end_signal_thread to check start and end signals. + +**poll_interval** +Sets the interval in seconds which the main thread polls the start and end signals when they are set so that the event signal is polled until the signal is unset. This results in synchronous behavior for set_start_signal() and set_end_signal(). + +**port** +Sets the port for the tracker to listen on. + +##### Methods +Several methods have been added and overridden in the MultiprocessPatchMeta class to support memory profiling: + +**set_start_signal(self, block=False)** +Sets the start signal for memory profiling. If the start signal is set, this method triggers the memory profiling process to start. +- block (bool, optional): If True, the method will block until the start signal is processed. + +**set_end_signal(self, block=False)** +Sets the end signal for memory profiling. If the end signal is set, this method triggers the memory profiling process to end. +- block (bool, optional): If True, the method will block until the end signal is processed. + +**execute_tracker(self, destination)** +Initializes and executes the memory tracker for the current process. The tracker captures memory usage and relevant data. +- destination: The destination for memory profiling data. Typically, a socket destination. + +**start_tracker(self)** +Starts the memory tracker for the current process. Acquires the necessary locks and begins memory profiling. + +**end_tracker(self)** +Ends the memory tracker for the current process. Releases locks and stops memory profiling. + +**_check_start_signal(self)** +A background thread that continuously checks for the start signal. When the start signal is received, this method triggers the memory tracker to start. + +**_check_end_signal(self)** +A background thread that continuously checks for the end signal. When the end signal is received, this method triggers the memory tracker to end. 
+ +**start(self)** +Overrides the start method of multiprocessing.Process. Extends the behavior of the original start method to include starting the background signal threads and memory tracker. + +**_pop_map(self)** +Removes the current process from the proc_map_global dictionary, which tracks active processes. + +**_release_lock(self)** +Releases the global lock (tracker_lock_global) that is acquired when profiling memory. Ensures that the lock is released when memory profiling ends. + +**_cleanup(self)** +Cleans up resources associated with the current process. Removes it from the proc_map_global dictionary and releases the lock. + +**patched_run(self, *args, \*\*\kwargs)** +Overrides the run method of the multiprocessing.Process class. Enhances the run method by adding memory profiling logic and signal handling. It starts and ends memory profiling based on the signals received. + +#### Notes +This section is going to focus on the multiprocess live memory profiling implementation since it is the most complex and every other feature is simple enough to understand by reading through. + +The following variables are used to keep track of the global state of the multiprocess live profiler. In order to function, the variables need to be accessible from the class LiveMultiprocessProfiler and from every instance of MultiprocessPatchMeta. +``` +mp_manager: SyncManager = None +tracker_lock_global: Lock = None +tracker_lock_holder_pid: SynchronizedBase = None +proc_map_global: Dict[int, multiprocessing.Process] = None +proc_map_lock_global: Lock = None +``` +These variables are initialized by LiveMultiprocessProfiler on initialization. +``` +mp_manager = multiprocessing.Manager() +tracker_lock_global = mp_manager.Lock() +tracker_lock_holder_pid = multiprocessing.Value("i", 0) +proc_map_global = {} +proc_map_lock_global = mp_manager.Lock() +``` + +**mp_manager** +Creates a multiprocessing.Manager() object which facilitates the sharing of resources between processes, allowing for all of the Process classes to access the global state. + +**tracker_lock_global** +A mutex lock which is acquired when the memray Tracker object is active in any process. The lock is then freed once the process possessing the lock stops tracking + +**tracker_lock_holder_pid** +Stores the PID of the process currently getting profiled. + +**proc_map_global** +Dictionary mapping PID to process object. Necessary for LiveMultiprocessProfiler to send signals to the other processes. + +**proc_map_lock_global** +A mutex lock which is acquired when modifying proc_map_global. + +#### Additional Notes +MultiprocessPatchMeta inherits ABCMeta because all modules inherit from class Module and multiprocess.Process. Normally metaclasses inherit from type but since Module inherits from ABC, MultiprocessPatchMeta must inherit from ABCMeta instead of type so prevent a metaclass conflict. +set_start_signal and set_end_signal in MultiprocessPatchMeta are supposed to behave synchronously when block=True is set in the parameter. This currently does not work because the line in MultiprocessPatchMeta.start_tracker: dest = memray.SocketDestination(server_port=self.port, address='127.0.0.1') blocks if the socket is not connected to with “memray live ”. 
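+
+To make the metaclass constraint above concrete, here is a small, self-contained sketch of the conflict it avoids. The class names are hypothetical stand-ins (only `Module` inheriting from `ABC` and the `metaclass(name, bases, dict)` patching style are taken from Slips); this is an illustration, not code from the profiler.
+
+```python
+from abc import ABC, ABCMeta
+import multiprocessing
+
+
+class PlainMeta(type):
+    """Hypothetical patch metaclass deriving from plain `type`."""
+
+
+class AbcAwareMeta(ABCMeta):
+    """Patch metaclass deriving from ABCMeta, as MultiprocessPatchMeta does."""
+
+
+class Module(ABC):
+    """Stand-in for the Slips Module base class, which inherits from ABC."""
+
+
+# Patch multiprocessing.Process the same way the profiler does: metaclass(name, bases, dict).
+PlainProcess = PlainMeta('Process', (multiprocessing.Process,), {})
+PatchedProcess = AbcAwareMeta('Process', (multiprocessing.Process,), {})
+
+# The next class would raise "TypeError: metaclass conflict": the metaclass of a
+# derived class must be a subclass of the metaclasses of all its bases, and
+# neither ABCMeta (from Module) nor PlainMeta (from PlainProcess) is a subclass
+# of the other.
+# class Broken(Module, PlainProcess):
+#     pass
+
+
+class Works(Module, PatchedProcess):
+    """Fine: AbcAwareMeta is a subclass of ABCMeta, so Python picks it unambiguously."""
+
+
+if __name__ == "__main__":
+    print(type(Works))  # <class '__main__.AbcAwareMeta'>
+```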
diff --git a/metadata_manager.py b/metadata_manager.py index f208288ab..7f2f14ac8 100644 --- a/metadata_manager.py +++ b/metadata_manager.py @@ -148,7 +148,7 @@ def set_input_metadata(self): def check_if_port_is_in_use(self, port): if port == 6379: - # even if it's already in use, slips will override it + # even if it's already in use, slips should override it return False try: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) diff --git a/modules/flowalerts/flowalerts.py b/modules/flowalerts/flowalerts.py index a084cde3d..3c46be60f 100644 --- a/modules/flowalerts/flowalerts.py +++ b/modules/flowalerts/flowalerts.py @@ -290,10 +290,9 @@ def check_data_upload(self, sbytes, daddr, uid, profileid, twid): """ Set evidence when 1 flow is sending >= the flow_upload_threshold bytes """ - - if ( - self.is_ignored_ip_data_upload(daddr) + not daddr + or self.is_ignored_ip_data_upload(daddr) or not sbytes ): return False diff --git a/process_manager.py b/process_manager.py index a36db49a3..ace548f6f 100644 --- a/process_manager.py +++ b/process_manager.py @@ -29,9 +29,9 @@ def __init__(self, main): def start_output_process(self, current_stdout, stderr, slips_logfile): output_process = OutputProcess( - self.main.db, self.main.output_queue, self.main.args.output, + self.main.redis_port, self.termination_event, verbose=self.main.args.verbose, debug=self.main.args.debug, @@ -46,9 +46,9 @@ def start_output_process(self, current_stdout, stderr, slips_logfile): def start_profiler_process(self): profiler_process = ProfilerProcess( - self.main.db, self.main.output_queue, self.main.args.output, + self.main.redis_port, self.termination_event, profiler_queue=self.profiler_queue, ) @@ -64,9 +64,9 @@ def start_profiler_process(self): def start_evidence_process(self): evidence_process = EvidenceProcess( - self.main.db, self.main.output_queue, self.main.args.output, + self.main.redis_port, self.termination_event, ) evidence_process.start() @@ -81,9 +81,9 @@ def start_evidence_process(self): def start_input_process(self): input_process = InputProcess( - self.main.db, self.main.output_queue, self.main.args.output, + self.main.redis_port, self.termination_event, profiler_queue=self.profiler_queue, input_type=self.main.input_type, @@ -223,17 +223,14 @@ def load_modules(self): modules_to_call = self.get_modules(to_ignore)[0] loaded_modules = [] for module_name in modules_to_call: - # delete later - # if module_name != 'CPU Profiler': - # continue - # end if module_name in to_ignore: continue module_class = modules_to_call[module_name]["obj"] module = module_class( self.main.output_queue, - self.main.db, + self.main.args.output, + self.main.redis_port, self.termination_event, ) module.start() @@ -505,6 +502,5 @@ def shutdown_gracefully(self): else: f.write(f"[Process Manager] Slips didn't shutdown gracefully - {reason}\n") - exit() except KeyboardInterrupt: return False diff --git a/redis_manager.py b/redis_manager.py index 9455f460b..0ef94dbf0 100644 --- a/redis_manager.py +++ b/redis_manager.py @@ -6,6 +6,7 @@ import os import time import socket +import subprocess class RedisManager: def __init__(self, main): @@ -197,9 +198,10 @@ def get_pid_of_redis_server(self, port: int) -> str: Returns str(port) or false if there's no redis-server running on this port """ cmd = 'ps aux | grep redis-server' - cmd_output = os.popen(cmd).read() + process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + cmd_output, _ = process.communicate() for line in cmd_output.splitlines(): - if str(port) in line: + if 
str(port).encode() in line: pid = line.split()[1] return pid return False @@ -269,9 +271,10 @@ def get_port_of_redis_server(self, pid: str): returns the port of the redis running on this pid """ cmd = 'ps aux | grep redis-server' - cmd_output = os.popen(cmd).read() + process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE) + cmd_output, _ = process.communicate() for line in cmd_output.splitlines(): - if str(pid) in line: + if str(pid).encode() in line: port = line.split(':')[-1] return port return False diff --git a/slips.py b/slips.py index 0a51fca45..2ab46fd34 100755 --- a/slips.py +++ b/slips.py @@ -20,6 +20,8 @@ import contextlib import multiprocessing from slips_files.common.imports import * +from slips_files.common.cpu_profiler import CPUProfiler +from slips_files.common.memory_profiler import MemoryProfiler from exclusiveprocess import Lock, CannotAcquireLock from redis_manager import RedisManager from metadata_manager import MetadataManager @@ -28,6 +30,14 @@ from checker import Checker from style import green + +from slips_files.core.inputProcess import InputProcess +from slips_files.core.outputProcess import OutputProcess +from slips_files.core.profilerProcess import ProfilerProcess +from slips_files.core.evidenceProcess import EvidenceProcess +from slips_files.core.database.database_manager import DBManager + + import signal import sys import os @@ -82,6 +92,86 @@ def __init__(self, testing=False): self.prepare_zeek_output_dir() self.twid_width = self.conf.get_tw_width() + def cpu_profiler_init(self): + self.cpuProfilerEnabled = slips.conf.get_cpu_profiler_enable() == 'yes' + self.cpuProfilerMode = slips.conf.get_cpu_profiler_mode() + self.cpuProfilerMultiprocess = slips.conf.get_cpu_profiler_multiprocess() == 'yes' + if self.cpuProfilerEnabled: + try: + if (self.cpuProfilerMultiprocess and self.cpuProfilerMode == "dev"): + args = sys.argv + if (args[-1] != "--no-recurse"): + tracer_entries = str(slips.conf.get_cpu_profiler_dev_mode_entries()) + viz_args = ['viztracer', '--tracer_entries', tracer_entries, '--max_stack_depth', '10', '-o', str(os.path.join(self.args.output, 'cpu_profiling_result.json'))] + viz_args.extend(args) + viz_args.append("--no-recurse") + print("Starting multiprocess profiling recursive subprocess") + subprocess.run(viz_args) + exit(0) + else: + self.cpuProfiler = CPUProfiler( + db=self.db, + output=self.args.output, + mode=slips.conf.get_cpu_profiler_mode(), + limit=slips.conf.get_cpu_profiler_output_limit(), + interval=slips.conf.get_cpu_profiler_sampling_interval() + ) + self.cpuProfiler.start() + except Exception as e: + print(e) + self.cpuProfilerEnabled = False + + def cpu_profiler_release(self): + if hasattr(self, 'cpuProfilerEnabled' ): + if self.cpuProfilerEnabled and not self.cpuProfilerMultiprocess: + self.cpuProfiler.stop() + self.cpuProfiler.print() + + def memory_profiler_init(self): + self.memoryProfilerEnabled = slips.conf.get_memory_profiler_enable() == "yes" + memoryProfilerMode = slips.conf.get_memory_profiler_mode() + memoryProfilerMultiprocess = slips.conf.get_memory_profiler_multiprocess() == "yes" + if self.memoryProfilerEnabled: + output_dir = os.path.join(slips.args.output,'memoryprofile/') + if not os.path.exists(output_dir): + os.makedirs(output_dir) + output_file = os.path.join(output_dir, 'memory_profile.bin') + self.memoryProfiler = MemoryProfiler(output_file, db=self.db, mode=memoryProfilerMode, multiprocess=memoryProfilerMultiprocess) + self.memoryProfiler.start() + + def memory_profiler_release(self): + if 
self.memoryProfilerEnabled: + self.memoryProfiler.stop() + + def memory_profiler_multiproc_test(self): + def target_function(): + print("Target function started") + time.sleep(5) + + def mem_function(): + print("Mem function started") + while True: + time.sleep(1) + array = [] + for i in range(1000000): + array.append(i) + processes = [] + num_processes = 3 + + for _ in range(num_processes): + process = multiprocessing.Process(target=target_function if _%2 else mem_function) + process.start() + processes.append(process) + + # Message passing + self.db.publish("memory_profile", processes[1].pid) # successful + # subprocess.Popen(["memray", "live", "1234"]) + time.sleep(5) # target_function will timeout and tracker will be cleared + self.db.publish("memory_profile", processes[0].pid) # end but maybe don't start + time.sleep(5) # mem_function will get tracker started + self.db.publish("memory_profile", processes[0].pid) # start successfully + input() + def get_slips_version(self): version_file = 'VERSION' with open(version_file, 'r') as f: @@ -122,7 +212,8 @@ def terminate_slips(self): """ if self.mode == 'daemonized': self.daemon.stop() - sys.exit(0) + if self.conf.get_cpu_profiler_enable() != "yes": + sys.exit(0) def update_local_TI_files(self): from modules.update_manager.update_manager import UpdateManager @@ -131,7 +222,10 @@ def update_local_TI_files(self): # so this function will only be allowed to run from 1 slips instance. with Lock(name="slips_ports_and_orgs"): # pass a dummy termination event for update manager to update orgs and ports info - update_manager = UpdateManager(self.output_queue, self.db, multiprocessing.Event()) + update_manager = UpdateManager(self.output_queue, + self.args.output, + self.redis_port, + multiprocessing.Event()) update_manager.update_ports_info() update_manager.update_org_files() except CannotAcquireLock: @@ -455,6 +549,12 @@ def start(self): 'commit': self.commit, 'branch': self.branch, }) + + self.cpu_profiler_init() + self.memory_profiler_init() + # uncomment line to see that memory profiler works correctly + # Should print out red text if working properly + # self.memory_profiler_multiproc_test() # if stdout is redirected to a file, # tell outputProcess.py to redirect it's output as well @@ -659,4 +759,7 @@ def sig_handler(sig, frame): daemon.start() else: # interactive mode + pass slips.start() + slips.cpu_profiler_release() + slips.memory_profiler_release() diff --git a/slips_files/common/_memory_profiler_example_no_import.py b/slips_files/common/_memory_profiler_example_no_import.py new file mode 100644 index 000000000..51ce80c54 --- /dev/null +++ b/slips_files/common/_memory_profiler_example_no_import.py @@ -0,0 +1,432 @@ +import memray +import glob +import os +import subprocess +from termcolor import colored +import time +import multiprocessing +from multiprocessing.managers import SyncManager +from multiprocessing.synchronize import Lock, Event +from multiprocessing.sharedctypes import SynchronizedBase +import threading +from typing import Dict, List +import psutil +import random +from abc import ABC, ABCMeta + +def proc_is_running(pid): + try: + process = psutil.Process(pid) + # Check if the process exists by accessing any attribute of the Process object + process.name() + return True + except psutil.NoSuchProcess: + return False + +class LiveMultiprocessProfiler: + original_process_class: multiprocessing.Process + signal_handler_thread: threading.Thread + tracker_possessor: int + db = None + + def __init__(self, db=None): + 
self.original_process_class = multiprocessing.Process + global mp_manager + global tracker_lock_global + global tracker_lock_holder_pid + global proc_map_global + global proc_map_lock_global + mp_manager = multiprocessing.Manager() + tracker_lock_global = mp_manager.Lock() # process holds when running profiling + tracker_lock_holder_pid = multiprocessing.Value("i", 0) # process that holds the lock + proc_map_global = {} # port to process object mapping + proc_map_lock_global = mp_manager.Lock() # hold when modifying proc_map_global + self.db = db + self.pid_channel = self.db + + def _create_profiler(self): + pass + + def _handle_signal(self): + global proc_map_global + global tracker_lock_holder_pid + while True: + # check redis channel + # poll for signal + timeout = 0.01 + msg: str = None + pid_to_profile: int = None + while not self.pid_channel.empty(): + msg = self.pid_channel.get() + print(f"Msg {msg}") + pid: int = None + try: + pid = int(msg) + except TypeError: + continue + if pid in proc_map_global.keys(): + if proc_is_running(pid): + pid_to_profile = pid + else: + try: + proc_map_global.pop(pid) + except KeyError: + pass + if pid_to_profile: + print(colored(f"Sending end signal {tracker_lock_holder_pid.value}", "red")) + if tracker_lock_holder_pid.value in proc_map_global.keys(): + print(proc_map_global[tracker_lock_holder_pid.value]) + proc_map_global[tracker_lock_holder_pid.value].set_end_signal() + print(colored(f"Sending start signal {pid_to_profile}", "red")) + proc_map_global[pid_to_profile].set_start_signal() + #send stop first, send start new process + + time.sleep(1) + + def _test_thread(self): + global proc_map_global + while True: + print("Test thread:", proc_map_global) + if len(proc_map_global): + pid = random.choice(list(proc_map_global.keys())) + self.db.put(pid) + print(colored(f"Published {pid}", "red")) + break + time.sleep(1) + + def start(self): + multiprocessing.Process = MultiprocessPatchMeta('Process', (multiprocessing.Process,), {}) + self.signal_handler_thread = threading.Thread(target=self._handle_signal, daemon=True) + self.signal_handler_thread.start() + #Remove Later + # self.test_thread = threading.Thread(target=self._test_thread, daemon=True) + # self.test_thread.start() + + def stop(self): + multiprocessing.Process = self.original_process_class + + def print(self): + pass + +class MultiprocessPatch(multiprocessing.Process): + tracker: memray.Tracker = None + tracker_start: Event = None + tracker_end: Event = None + signal_interval: int = 1 # sleep time for checking start and end signals to process + poll_interval: int = 1 # sleep time for checking if signal has finished processing + port = 1234 + def __init__(self, *args, **kwargs): + super(MultiprocessPatch, self).__init__(*args, **kwargs) + self.tracker_start = multiprocessing.Event() + self.tracker_end = multiprocessing.Event() + + def set_start_signal(self, block=False): + print(f"set start signal {self.pid}") + if self.tracker_start: + self.tracker_start.set() + while block and self.tracker_start.is_set(): + time.sleep(self.poll_interval) + + def set_end_signal(self, block=False): + print(f"set end signal {self.pid}") + if self.tracker_end: + self.tracker_end.set() + while block and self.tracker_start.is_set(): + time.sleep(self.poll_interval) + + def execute_tracker(self, destination): + self.tracker = memray.Tracker(destination=destination) + + def start_tracker(self): + global tracker_lock_global + global tracker_lock_holder_pid + print(colored(f"start_tracker lock {self.pid}", "red")) + if 
not self.tracker and tracker_lock_global.acquire(blocking=False): + print(colored(f"start_tracker memray at PID {self.pid} started {self.port}", "red")) + tracker_lock_holder_pid.value = self.pid + print(colored(f"start_tracker lock holder pid {tracker_lock_holder_pid.value}", "red")) + dest = memray.SocketDestination(server_port=self.port, address='127.0.0.1') + self.tracker = memray.Tracker(destination=dest) + self.tracker.__enter__() + + def end_tracker(self): + global tracker_lock_global + global tracker_lock_holder_pid + print(f"end_tracker Lock Holder {tracker_lock_holder_pid.value}, {self.tracker}") + if self.tracker: + print(colored(f"end_tracker memray at PID {self.pid} ended", "red")) + self.tracker.__exit__(None, None, None) + self.tracker = None + tracker_lock_holder_pid.value = 0 + tracker_lock_global.release() + + def _check_start_signal(self): + while True: + while not self.tracker_start.is_set(): + time.sleep(self.signal_interval) + continue + self.start_tracker() + self.tracker_start.clear() + + def _check_end_signal(self): + while True: + while not self.tracker_end.is_set(): + time.sleep(self.signal_interval) + continue + self.end_tracker() + self.tracker_end.clear() + + def start(self) -> None: + super().start() + global proc_map_global + global proc_map_lock_global + proc_map_lock_global.acquire() + proc_map_global[self.pid] = self + proc_map_lock_global.release() + + def _pop_map(self): + global proc_map_global + global proc_map_lock_global + proc_map_lock_global.acquire() + try: + proc_map_global.pop(self.pid) + except KeyError: + print(f"_pop_map {self.pid} no longer in memory profile map, continuing...") + proc_map_lock_global.release() + + def _release_lock(self): + global tracker_lock_global + global tracker_lock_holder_pid + if tracker_lock_holder_pid.value == self.pid: + tracker_lock_global.release() + + def _cleanup(self): + self._pop_map() + self._release_lock() + + def run(self): + print(f"Child process started - PID: {self.pid}") + start_signal_thread = threading.Thread(target=self._check_start_signal, daemon=True) + start_signal_thread.start() + end_signal_thread = threading.Thread(target=self._check_end_signal, daemon=True) + end_signal_thread.start() + super().run() + self._cleanup() + +mp_manager: SyncManager = multiprocessing.Manager() +tracker_lock_global: Lock = mp_manager.Lock() # process holds when running profiling +tracker_lock_holder_pid: SynchronizedBase = multiprocessing.Value("i", 0) # process that holds the lock +proc_map_global: Dict[int, MultiprocessPatch] = {} # port to process object mapping +proc_map_lock_global: Lock = mp_manager.Lock() # hold when modifying proc_map_global + +def target_function(): + print("Target function started") + time.sleep(5) + +def mem_function(): + print("Mem function started") + while True: + time.sleep(1) + array = [] + for i in range(1000000): + array.append(i) + +class MultiprocessPatchMeta(ABCMeta): + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + new_cls.tracker: memray.Tracker = None + new_cls.tracker_start: Event = None + new_cls.signal_interval: int = 1 + new_cls.poll_interval: int = 1 + new_cls.port = 1234 + return new_cls + + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + def __init__(self, *args, **kwargs): + super(cls, self).__init__(*args, **kwargs) + self.tracker_start = multiprocessing.Event() + self.tracker_end = multiprocessing.Event() + cls.__init__ = __init__ + + def set_start_signal(self, block=False): + print(f"set 
start signal {self.pid}") + if self.tracker_start: + self.tracker_start.set() + while block and self.tracker_start.is_set(): + time.sleep(self.poll_interval) + cls.set_start_signal = set_start_signal + + def set_end_signal(self, block=False): + print(f"set end signal {self.pid}") + if self.tracker_end: + self.tracker_end.set() + while block and self.tracker_start.is_set(): + time.sleep(self.poll_interval) + cls.set_end_signal = set_end_signal + + def execute_tracker(self, destination): + self.tracker = memray.Tracker(destination=destination) + cls.execute_tracker = execute_tracker + + def start_tracker(self): + global tracker_lock_global + global tracker_lock_holder_pid + print(colored(f"start_tracker lock {self.pid}", "red")) + if not self.tracker and tracker_lock_global.acquire(blocking=False): + print(colored(f"start_tracker memray at PID {self.pid} started {self.port}", "red")) + tracker_lock_holder_pid.value = self.pid + print(colored(f"start_tracker lock holder pid {tracker_lock_holder_pid.value}", "red")) + dest = memray.SocketDestination(server_port=self.port, address='127.0.0.1') + self.tracker = memray.Tracker(destination=dest) + self.tracker.__enter__() + cls.start_tracker = start_tracker + + def end_tracker(self): + global tracker_lock_global + global tracker_lock_holder_pid + print(f"end_tracker Lock Holder {tracker_lock_holder_pid.value}, {self.tracker}") + if self.tracker: + print(colored(f"end_tracker memray at PID {self.pid} ended", "red")) + self.tracker.__exit__(None, None, None) + self.tracker = None + tracker_lock_holder_pid.value = 0 + tracker_lock_global.release() + cls.end_tracker = end_tracker + + def _check_start_signal(self): + while True: + while not self.tracker_start.is_set(): + time.sleep(self.signal_interval) + continue + self.start_tracker() + self.tracker_start.clear() + cls._check_start_signal = _check_start_signal + + def _check_end_signal(self): + while True: + while not self.tracker_end.is_set(): + time.sleep(self.signal_interval) + continue + self.end_tracker() + self.tracker_end.clear() + cls._check_end_signal = _check_end_signal + + def start(self) -> None: + super(cls, self).start() + global proc_map_global + global proc_map_lock_global + proc_map_lock_global.acquire() + proc_map_global[self.pid] = self + proc_map_lock_global.release() + cls.start = start + + def _pop_map(self): + global proc_map_global + global proc_map_lock_global + proc_map_lock_global.acquire() + try: + proc_map_global.pop(self.pid) + except KeyError: + print(f"_pop_map {self.pid} no longer in memory profile map, continuing...") + proc_map_lock_global.release() + cls._pop_map = _pop_map + + def _release_lock(self): + global tracker_lock_global + global tracker_lock_holder_pid + if tracker_lock_holder_pid.value == self.pid: + tracker_lock_global.release() + cls._release_lock = _release_lock + + def _cleanup(self): + self._pop_map() + self._release_lock() + cls._cleanup = _cleanup + + # Preserve the original run method + original_run = cls.run + + # # Define a new run method that adds the print statements + def patched_run(self, *args, **kwargs): + print(f"Child process started - PID: {self.pid}") + start_signal_thread = threading.Thread(target=self._check_start_signal, daemon=True) + start_signal_thread.start() + end_signal_thread = threading.Thread(target=self._check_end_signal, daemon=True) + end_signal_thread.start() + original_run(self, *args, **kwargs) + self._cleanup() + + # Replace the original run method with the new one + cls.run = patched_run + +# Apply the metaclass 
to multiprocessing.Process +multiprocessing.Process = MultiprocessPatchMeta('Process', (multiprocessing.Process,), {}) + +class Module(ABC): + def __init__(self): + multiprocessing.Process.__init__(self) + def main(self): + print("Module main") + def run(self): + print("Module run") + self.main() + +class A(Module, multiprocessing.Process): + def main(self): + print("Target function started") + time.sleep(5) + +class B(Module, multiprocessing.Process): + def main(self): + print("Mem function started") + while True: + time.sleep(1) + array = [] + for i in range(1000000): + array.append(i) + +if __name__ == "__main__": + # Notes + # set signal start and end are non-blocking, only sends the signal but doesn't guarantee success + # start_tracker will block until lock is acquired and memray is connected, can change later + # end_tracker works even if memray client quits + + + db = multiprocessing.Queue() + # profiler = LiveMultiprocessProfiler(db=db) + # profiler.start() + + p = A() + p.start() + pp = B() + pp.start() + p.join() + pp.join() + exit() + processes: List[MultiprocessPatch] = [] + num_processes = 3 + + for _ in range(num_processes): + process = multiprocessing.Process(target=target_function if _%2 else mem_function) + process.start() + processes.append(process) + + # Message passing + db.put(processes[1].pid) # successful + time.sleep(5) # target_function will timeout and tracker will be cleared + db.put(processes[0].pid) # end but maybe don't start + time.sleep(5) # mem_function will get tracker started + db.put(processes[0].pid) # start successfully + + # Direct process access + # processes[0].set_start_signal() + # time.sleep(5) + # processes[0].set_end_signal() + # time.sleep(2) + # processes[1].set_start_signal() + # time.sleep(10) + # processes[0].set_end_signal() + + for process in processes: + process.join() \ No newline at end of file diff --git a/slips_files/common/abstracts.py b/slips_files/common/abstracts.py index d2d670c75..a3ac46c1f 100644 --- a/slips_files/common/abstracts.py +++ b/slips_files/common/abstracts.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod # common imports for all modules +from slips_files.core.database.database_manager import DBManager from multiprocessing import Event from slips_files.common.slips_utils import utils from multiprocessing import Process @@ -11,10 +12,15 @@ class Module(ABC): name = '' description = 'Template module' authors = ['Template Author'] - def __init__(self, output_queue, db, termination_event, **kwargs): + def __init__(self, + output_queue, + output_dir, + redis_port, + termination_event, + **kwargs): Process.__init__(self) self.output_queue = output_queue - self.db = db + self.db = DBManager(output_dir, output_queue, redis_port) self.msg_received = False # used to tell all slips.py children to stop self.termination_event: Event = termination_event @@ -127,6 +133,10 @@ def run(self): return True + def __del__(self): + self.db.close() + + class Core(Module, Process): """ Interface for all Core files placed in slips_files/core/ @@ -136,7 +146,12 @@ class Core(Module, Process): authors = ['Name of the author creating the class'] def __init__( - self, db, output_queue, output_dir, termination_event, **kwargs + self, + output_queue, + output_dir, + redis_port, + termination_event, + **kwargs ): """ contains common initializations in all core files in slips_files/core/ @@ -144,12 +159,11 @@ def __init__( in this file """ Process.__init__(self) - self.output_queue = output_queue self.output_dir = output_dir # used to tell 
all slips.py children to stop self.termination_event: Event = termination_event - self.db = db + self.db = DBManager(output_dir, output_queue, redis_port) self.msg_received = False self.init(**kwargs) @@ -175,3 +189,24 @@ def run(self): self.print(traceback.format_exc(), 0, 1) return True + + def __del__(self): + self.db.close() + + +class ProfilerInterface(ABC): + @abstractmethod + def _create_profiler(self): + pass + + @abstractmethod + def start(self): + pass + + @abstractmethod + def stop(self): + pass + + @abstractmethod + def print(self): + pass \ No newline at end of file diff --git a/slips_files/common/argparse.py b/slips_files/common/argparse.py index 808674ad8..b454c3426 100644 --- a/slips_files/common/argparse.py +++ b/slips_files/common/argparse.py @@ -283,6 +283,11 @@ def parse_arguments(self): required=False, help='Read flows from a module other than input process.', ) + self.add_argument( + '--no-recurse', + action='store_true', + help='Internal use only, prevents infinite recursion for cpu profiler dev mode multiprocess tracking' + ) try: self.add_argument( '-h', '--help', action='store_true', help='command line help', diff --git a/slips_files/common/config_parser.py b/slips_files/common/config_parser.py index a64378303..7fe40d9f1 100644 --- a/slips_files/common/config_parser.py +++ b/slips_files/common/config_parser.py @@ -743,7 +743,30 @@ def get_disabled_modules(self, input_type) -> list: to_ignore.append('CYST') return to_ignore - - - - + + def get_cpu_profiler_enable(self): + return self.read_configuration('Profiling', 'cpu_profiler_enable', 'no') + + def get_cpu_profiler_mode(self): + return self.read_configuration('Profiling', 'cpu_profiler_mode', 'dev') + + def get_cpu_profiler_multiprocess(self): + return self.read_configuration('Profiling', 'cpu_profiler_multiprocess', 'yes') + + def get_cpu_profiler_output_limit(self) -> int: + return int(self.read_configuration('Profiling', 'cpu_profiler_output_limit', 20)) + + def get_cpu_profiler_sampling_interval(self) -> int: + return int(self.read_configuration('Profiling', 'cpu_profiler_sampling_interval', 5)) + + def get_cpu_profiler_dev_mode_entries(self) -> int: + return int(self.read_configuration('Profiling', 'cpu_profiler_dev_mode_entries', 1000000)) + + def get_memory_profiler_enable(self): + return self.read_configuration('Profiling', 'memory_profiler_enable', 'no') + + def get_memory_profiler_mode(self): + return self.read_configuration('Profiling', 'memory_profiler_mode', 'dev') + + def get_memory_profiler_multiprocess(self): + return self.read_configuration('Profiling', 'memory_profiler_multiprocess', 'yes') \ No newline at end of file diff --git a/slips_files/common/cpu_profiler.py b/slips_files/common/cpu_profiler.py new file mode 100644 index 000000000..7fc2f7cfe --- /dev/null +++ b/slips_files/common/cpu_profiler.py @@ -0,0 +1,95 @@ +import viztracer +import time +import threading +import yappi +import io +import pstats +import os + +from slips_files.common.abstracts import ProfilerInterface + +class CPUProfiler(ProfilerInterface): + def __init__(self, db, output, mode="dev", limit=20, interval=20): + valid_modes = ["dev", "live"] + if mode not in valid_modes: + print("cpu_profiler_mode = " + mode + " is invalid, must be one of " + + str(valid_modes) + ", CPU Profiling will be disabled") + if mode == "dev": + self.profiler = DevProfiler(output) + if mode == "live": + self.profiler = LiveProfiler(db, limit, interval) + + def _create_profiler(self): + self.profiler._create_profiler() + + def start(self): + 
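# delegates to whichever backend was selected in __init__: DevProfiler (viztracer) for dev mode, LiveProfiler (yappi) for live mode +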
print("CPU Profiler Started") + self.profiler.start() + + def stop(self): + self.profiler.stop() + print("CPU Profiler Ended") + + def print(self): + self.profiler.print() + +class DevProfiler(ProfilerInterface): + def __init__(self, output): + self.profiler = self._create_profiler() + self.output = output + + def _create_profiler(self): + return viztracer.VizTracer() + + def start(self): + self.profiler.start() + + def stop(self): + self.profiler.stop() + + def print(self): + result_path = os.path.join(self.output, 'cpu_profiling_result.json' ) + self.profiler.save(result_path) + +class LiveProfiler(ProfilerInterface): + def __init__(self, db, limit=20, interval=20): + self.profiler = self._create_profiler() + self.limit = limit + self.interval = interval + self.is_running = False + self.timer_thread = threading.Thread(target=self._sampling_loop) + self.db = db + self.stats = None + + def _create_profiler(self): + return yappi + + def start(self): + if not self.is_running: + self.is_running = True + self.profiler.start() + self.timer_thread.start() + + def stop(self): + if self.is_running: + self.is_running = False + self.profiler.stop() + + def print(self): + self.stats.print_stats(self.limit) + + def _sampling_loop(self): + stringio = io.StringIO() + while self.is_running: + # replace the print with a redis update + self.profiler.clear_stats() + + time.sleep(self.interval) + + self.stats = pstats.Stats(stream=stringio) + self.stats.add(self.profiler.convert2pstats(self.profiler.get_func_stats())) + self.stats.sort_stats('cumulative') + self.print() #prints to stringio, not stdout + + self.db.publish('cpu_profile', stringio.getvalue()) + \ No newline at end of file diff --git a/slips_files/common/memory_profiler.py b/slips_files/common/memory_profiler.py new file mode 100644 index 000000000..18ce34bd2 --- /dev/null +++ b/slips_files/common/memory_profiler.py @@ -0,0 +1,366 @@ +import memray +import glob +import os +import subprocess +from termcolor import colored +from slips_files.common.abstracts import ProfilerInterface +import time +import multiprocessing +from multiprocessing.managers import SyncManager +from multiprocessing.synchronize import Lock, Event +from multiprocessing.sharedctypes import SynchronizedBase +import threading +from typing import Dict +import psutil +import random +from abc import ABCMeta + +class MemoryProfiler(ProfilerInterface): + profiler = None + def __init__(self, output, db=None, mode="dev", multiprocess=True): + valid_modes = ["dev", "live"] + if mode not in valid_modes: + print("memory_profiler_mode = " + mode + " is invalid, must be one of " + + str(valid_modes) + ", Memory Profiling will be disabled") + if mode == "dev": + self.profiler = DevProfiler(output, multiprocess) + elif mode == "live": + self.profiler = LiveProfiler(multiprocess, db=db) + + def _create_profiler(self): + self.profiler._create_profiler() + + def start(self): + print(colored("Memory Profiler Started", 'green')) + self.profiler.start() + + def stop(self): + self.profiler.stop() + print(colored("Memory Profiler Ended", 'green')) + + def print(self): + pass + +class DevProfiler(ProfilerInterface): + output = None + profiler = None + multiprocess = None + def __init__(self, output, multiprocess): + self.output = output + self.multiprocess = multiprocess + self.profiler = self._create_profiler() + + def _create_profiler(self): + return memray.Tracker(file_name=self.output, follow_fork=self.multiprocess) + + def start(self): + self.profiler.__enter__() + + def stop(self): + 
self.profiler.__exit__(None, None, None) + print(colored("Converting memory profile bin files to html...", 'green')) + output_files = glob.glob(self.output + '*') + directory = os.path.dirname(self.output) + flamegraph_dir = directory + '/flamegraph/' + if not os.path.exists(flamegraph_dir): + os.makedirs(flamegraph_dir) + table_dir = directory + '/table/' + if not os.path.exists(table_dir): + os.makedirs(table_dir) + for file in output_files: + filename = os.path.basename(file) + flame_output = flamegraph_dir + filename + '.html' + subprocess.run(['memray', 'flamegraph', '--temporal', '--leaks', '--split-threads', '--output', flame_output, file]) + table_output = table_dir + filename + '.html' + subprocess.run(['memray', 'table', '--output', table_output, file]) + + def print(self): + pass + +class LiveProfiler(ProfilerInterface): + multiprocess = None + profiler = None + def __init__(self, multiprocess=False, db=None): + self.multiprocess = multiprocess + if multiprocess: + self.profiler=LiveMultiprocessProfiler(db=db) + else: + self.profiler=LiveSingleProcessProfiler() + def _create_profiler(self): + self.profiler._create_profiler() + + def start(self): + self.profiler.start() + + def stop(self): + self.profiler.stop() + + def print(self): + self.profiler.print() + +class LiveSingleProcessProfiler(ProfilerInterface): + profiler = None + port = 5000 + def __init__(self): + self.profiler = self._create_profiler() + def _create_profiler(self): + print("Memory profiling running on port " + str(self.port)) + print("Connect to continue") + with open(os.devnull, 'w') as devnull: + subprocess.Popen(["memray", "live", str(self.port)], stdout=devnull) + dest = memray.SocketDestination(server_port=self.port, address='127.0.0.1') + return memray.Tracker(destination=dest) + + def start(self): + self.profiler.__enter__() + + def stop(self): + self.profiler.__exit__(None, None, None) + + def print(self): + pass + +def proc_is_running(pid): + try: + process = psutil.Process(pid) + # Check if the process exists by accessing any attribute of the Process object + process.name() + return True + except psutil.NoSuchProcess: + return False + +class LiveMultiprocessProfiler(ProfilerInterface): + # restores the original process behavior once profiler is stopped + original_process_class: multiprocessing.Process + # thread checks redis db for which process to start profiling + signal_handler_thread: threading.Thread + db = None + def __init__(self, db=None): + self.original_process_class = multiprocessing.Process + global mp_manager + global tracker_lock_global + global tracker_lock_holder_pid + global proc_map_global + global proc_map_lock_global + mp_manager = multiprocessing.Manager() + tracker_lock_global = mp_manager.Lock() # process holds when running profiling + tracker_lock_holder_pid = multiprocessing.Value("i", 0) # process that holds the lock + proc_map_global = {} # port to process object mapping + proc_map_lock_global = mp_manager.Lock() # hold when modifying proc_map_global + self.db = db + self.pid_channel = self.db.subscribe("memory_profile") + + def _create_profiler(self): + pass + # on signal received, check if pid is valid and stop currently profiled process. Then start new pid profiling. 
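+ # The switching protocol, roughly: any process with a db handle publishes the target PID
+ # to the "memory_profile" redis channel (illustrative, mirroring _test_thread below):
+ #     db.publish("memory_profile", pid)
+ # _handle_signal() drains that channel, keeps the last valid running PID, asks the current
+ # lock holder to end its tracker, then signals the chosen PID to start profiling.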
+ def _handle_signal(self): + global proc_map_global + global tracker_lock_holder_pid + while True: + # check redis channel + # poll for signal + timeout = 0.01 + msg: str = self.pid_channel.get_message(timeout=timeout) + pid_to_profile: int = None + while msg: + # print(f"Msg {msg}") + pid: int = None + try: + pid = int(msg['data']) + except ValueError: + msg = self.pid_channel.get_message(timeout=timeout) + continue + if pid in proc_map_global.keys(): + if proc_is_running(pid): + pid_to_profile = pid + else: + try: + proc_map_global.pop(pid) + except KeyError: + pass + msg = self.pid_channel.get_message(timeout=timeout) + + if pid_to_profile: + print(colored(f"Sending end signal {tracker_lock_holder_pid.value}", "red")) + if tracker_lock_holder_pid.value in proc_map_global.keys(): + print(proc_map_global[tracker_lock_holder_pid.value]) + proc_map_global[tracker_lock_holder_pid.value].set_end_signal() + print(colored(f"Sending start signal {pid_to_profile}", "red")) + proc_map_global[pid_to_profile].set_start_signal() + #send stop first, send start new process + + time.sleep(1) + # set pid in redis channel for testing + def _test_thread(self): + global proc_map_global + while True: + if len(proc_map_global): + pid = random.choice(list(proc_map_global.keys())) + self.db.publish("memory_profile", pid) + print(colored(f"Published {pid}", "red")) + time.sleep(5) + subprocess.Popen(["memray", "live", "1234"]) + break + time.sleep(1) + + def start(self): + multiprocessing.Process = MultiprocessPatchMeta('Process', (multiprocessing.Process,), {}) + self.signal_handler_thread = threading.Thread(target=self._handle_signal, daemon=True) + self.signal_handler_thread.start() + #Remove Later + # self.test_thread = threading.Thread(target=self._test_thread, daemon=True) + # self.test_thread.start() + + def stop(self): + multiprocessing.Process = self.original_process_class + + def print(self): + pass + +class MultiprocessPatchMeta(ABCMeta): + def __new__(cls, name, bases, dct): + new_cls = super().__new__(cls, name, bases, dct) + new_cls.tracker: memray.Tracker = None + new_cls.signal_interval: int = 1 # sleep time in sec for checking start and end signals to process + new_cls.poll_interval: int = 1 # sleep time in sec for checking if signal has finished processing + new_cls.port = 1234 + return new_cls + + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + def __init__(self, *args, **kwargs): + super(cls, self).__init__(*args, **kwargs) + self.tracker_start = multiprocessing.Event() + self.tracker_end = multiprocessing.Event() + cls.__init__ = __init__ + + # synchonous signal processing, block until event is processed. Then returns. + def set_start_signal(self, block=False): + print(f"set start signal {self.pid}") + if self.tracker_start: + self.tracker_start.set() + while block and self.tracker_start.is_set(): + time.sleep(self.poll_interval) + cls.set_start_signal = set_start_signal + + # synchonous signal as well. + def set_end_signal(self, block=False): + print(f"set end signal {self.pid}") + if self.tracker_end: + self.tracker_end.set() + while block and self.tracker_start.is_set(): + time.sleep(self.poll_interval) + cls.set_end_signal = set_end_signal + + # start profiling current process. Profiles current process context. 
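+ # memray.Tracker is meant to be used as a context manager; start_tracker/end_tracker below
+ # call __enter__/__exit__ explicitly because profiling is started and stopped from separate
+ # methods, driven by the start/end events rather than a single with-block.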
+ def execute_tracker(self, destination): + self.tracker = memray.Tracker(destination=destination) + cls.execute_tracker = execute_tracker + + def start_tracker(self): + global tracker_lock_global + global tracker_lock_holder_pid + print(colored(f"start_tracker lock {self.pid}", "red")) + if not self.tracker and tracker_lock_global.acquire(blocking=False): + print(colored(f"start_tracker memray at PID {self.pid} started {self.port}", "red")) + tracker_lock_holder_pid.value = self.pid + print(colored(f"start_tracker lock holder pid {tracker_lock_holder_pid.value}", "red")) + dest = memray.SocketDestination(server_port=self.port, address='127.0.0.1') + self.tracker = memray.Tracker(destination=dest) + self.tracker.__enter__() + cls.start_tracker = start_tracker + + def end_tracker(self): + global tracker_lock_global + global tracker_lock_holder_pid + print(f"end_tracker Lock Holder {tracker_lock_holder_pid.value}, {self.tracker}") + if self.tracker: + print(colored(f"end_tracker memray at PID {self.pid} ended", "red")) + self.tracker.__exit__(None, None, None) + self.tracker = None + tracker_lock_holder_pid.value = 0 + tracker_lock_global.release() + cls.end_tracker = end_tracker + + # checks if the start signal is set. Runs in a different thread. + def _check_start_signal(self): + while True: + while not self.tracker_start.is_set(): + time.sleep(self.signal_interval) + continue + self.start_tracker() + self.tracker_start.clear() + cls._check_start_signal = _check_start_signal + + # checks if the end signal is set. Runs in a different thread. + def _check_end_signal(self): + while True: + while not self.tracker_end.is_set(): + time.sleep(self.signal_interval) + continue + self.end_tracker() + self.tracker_end.clear() + cls._check_end_signal = _check_end_signal + + # Sets up data before running. super() starts first to set initialize pid. Then adds itself to proc_map_global + def start(self) -> None: + super(cls, self).start() + global proc_map_global + global proc_map_lock_global + proc_map_lock_global.acquire() + proc_map_global[self.pid] = self + proc_map_lock_global.release() + cls.start = start + + # Removes itself from the proc_map_global. Intended to run when process stops. + def _pop_map(self): + global proc_map_global + global proc_map_lock_global + proc_map_lock_global.acquire() + try: + proc_map_global.pop(self.pid) + except KeyError: + print(f"_pop_map {self.pid} no longer in memory profile map, continuing...") + proc_map_lock_global.release() + cls._pop_map = _pop_map + + # release tracker_lock_global, which must be acquired while profiling. 
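+ # Only the process whose PID matches tracker_lock_holder_pid releases the lock, so a process
+ # that exits without ever holding the tracker cannot release a lock owned by the process
+ # currently being profiled.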
+ def _release_lock(self): + global tracker_lock_global + global tracker_lock_holder_pid + if tracker_lock_holder_pid.value == self.pid: + tracker_lock_global.release() + cls._release_lock = _release_lock + + def _cleanup(self): + self._pop_map() + self._release_lock() + cls._cleanup = _cleanup + + # Preserve the original run method + original_run = cls.run + + # # Define a new run method that adds the print statements + def patched_run(self, *args, **kwargs): + print(f"Child process started - PID: {self.pid}") + start_signal_thread = threading.Thread(target=self._check_start_signal, daemon=True) + start_signal_thread.start() + end_signal_thread = threading.Thread(target=self._check_end_signal, daemon=True) + end_signal_thread.start() + original_run(self, *args, **kwargs) + self._cleanup() + + # Replace the original run method with the new one + cls.run = patched_run + +# All the following should have a shared state between all instances of MultiprocessPatch +# and be accessible from the separate processes. +# proc_map_global only needs to be accessible from the main process. This is because +# adding a processes pid as a key creates a pickle error. +# tracker_lock_global must be possessed by the process currently being profiled and released +# when profiling is completed so another process can be profiled. +mp_manager: SyncManager = None +tracker_lock_global: Lock = None # process holds when running profiling +tracker_lock_holder_pid: SynchronizedBase = None # process that holds the lock +proc_map_global: Dict[int, multiprocessing.Process] = None # port to process object mapping +proc_map_lock_global: Lock = None # hold when modifying proc_map_global diff --git a/slips_files/common/slips_utils.py b/slips_files/common/slips_utils.py index e83a2c9fa..b5b069362 100644 --- a/slips_files/common/slips_utils.py +++ b/slips_files/common/slips_utils.py @@ -1,5 +1,6 @@ import hashlib from datetime import datetime, timedelta +from re import findall import validators from git import Repo import socket @@ -9,6 +10,10 @@ import os import sys import ipaddress +import communityid +from hashlib import sha1 +from base64 import b64encode + IS_IN_A_DOCKER_CONTAINER = os.environ.get('IS_IN_A_DOCKER_CONTAINER', False) @@ -59,6 +64,7 @@ def __init__(self): # this format will be used accross all modules and logfiles of slips self.alerts_format = '%Y/%m/%d %H:%M:%S.%f%z' self.local_tz = self.get_local_timezone() + self.community_id = communityid.CommunityID() def get_cidr_of_ip(self, ip): """ @@ -424,6 +430,72 @@ def get_time_diff(self, start_time: float, end_time: float, return_type='seconds return units[return_type] + def remove_milliseconds_decimals(self, ts: str) -> str: + """ + remove the milliseconds from the given ts + :param ts: time in unix format + """ + ts = str(ts) + if '.' not in ts: + return ts + + return ts.split('.')[0] + + + + def assert_microseconds(self, ts: str): + """ + adds microseconds to the given ts if not present + :param ts: unix ts + :return: ts + """ + ts = self.convert_format(ts, 'unixtimestamp') + + ts = str(ts) + # pattern of unix ts with microseconds + pattern = r'\b\d+\.\d{6}\b' + matches = findall(pattern, ts) + + if not matches: + # fill the missing microseconds and milliseconds with 0 + # 6 is the decimals we need after the . 
in the unix ts + ts = ts + "0" * (6 - len(ts.split('.')[-1])) + return ts + + def get_aid(self, flow): + """ + calculates the flow SHA1(cid+ts) aka All-ID of the flow + because we need the flow ids to be unique to be able to compare them + """ + #TODO document this + community_id = self.get_community_id(flow) + ts = flow.starttime + ts: str = self.assert_microseconds(ts) + + aid = f"{community_id}-{ts}" + + # convert the input string to bytes (since hashlib works with bytes) + aid: str = sha1(aid.encode('utf-8')).hexdigest() + aid: str = b64encode(aid.encode()).decode() + return aid + + + def get_community_id(self, flow): + """ + calculates the flow community id based of the protocol + """ + proto = flow.proto.lower() + cases = { + 'tcp': communityid.FlowTuple.make_tcp, + 'udp': communityid.FlowTuple.make_udp, + 'icmp': communityid.FlowTuple.make_icmp, + } + try: + tpl = cases[proto](flow.saddr, flow.daddr, flow.sport, flow.dport) + return self.community_id.calc(tpl) + except KeyError: + # proto doesn't have a community_id.FlowTuple method + return '' def IDEA_format( self, diff --git a/slips_files/core/database/database_manager.py b/slips_files/core/database/database_manager.py index 8bd5fb536..97f0e825d 100644 --- a/slips_files/core/database/database_manager.py +++ b/slips_files/core/database/database_manager.py @@ -1,6 +1,6 @@ from slips_files.core.database.redis_db.database import RedisDB from slips_files.core.database.sqlite_db.database import SQLiteDB - +from slips_files.common.config_parser import ConfigParser class DBManager: """ @@ -8,39 +8,35 @@ class DBManager: each method added to any of the dbs should have a handler in here """ - _obj = None - # Stores instances per port - # this class is a singelton per redis port.meaning that each redis port will only be allowed to use - # exactly 1 instance - _instances = {} - - def __new__( - cls, + def __init__( + self, output_dir, output_queue, redis_port, start_sqlite=True, **kwargs ): - cls.output_dir = output_dir - cls.output_queue = output_queue - cls.redis_port = redis_port - if cls.redis_port not in cls._instances: - cls._instances[redis_port] = super().__new__(cls) - # these args will only be passed by slips.py - # the rest of the modules can create an obj of this class without these args, - # and will get the same obj instatiated by slips.py - # cls._instances[redis_port] = super().__new__(cls) - cls._obj = super().__new__(DBManager) - # in some rare cases we don't wanna start sqlite, - # like when using -S - # we just want to connect to redis to get the PIDs - cls.sqlite = None - if start_sqlite: - cls.sqlite = SQLiteDB(output_dir) - - cls.rdb = RedisDB(redis_port, output_queue, **kwargs) - return cls._instances[redis_port] + self.output_dir = output_dir + self.output_queue = output_queue + self.redis_port = redis_port + + self.rdb = RedisDB(redis_port, output_queue, **kwargs) + + # in some rare cases we don't wanna start sqlite, + # like when using -S + # we just want to connect to redis to get the PIDs + self.sqlite = None + if start_sqlite: + self.sqlite = self.create_sqlite_db(output_dir, output_queue) + + + def create_sqlite_db(self, output_dir, output_queue): + return SQLiteDB(output_dir, output_queue) + + @classmethod + def read_configuration(cls): + conf = ConfigParser() + cls.width = conf.get_tw_width_as_float() def get_sqlite_db_path(self) -> str: return self.sqlite.get_db_path() @@ -857,6 +853,12 @@ def get_commit(self, *args, **kwargs): def get_branch(self, *args, **kwargs): return self.rdb.get_branch(*args, 
**kwargs) + def add_alert(self, alert: dict): + twid_starttime: float = self.rdb.getTimeTW(alert['profileid'], alert['twid']) + twid_endtime: float = twid_starttime + RedisDB.width + alert.update({'tw_start': twid_starttime, 'tw_end': twid_endtime}) + return self.sqlite.add_alert(alert) + def close(self, *args, **kwargs): self.rdb.r.close() self.rdb.rcache.close() diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py index fed59dc32..e8de48237 100644 --- a/slips_files/core/database/redis_db/database.py +++ b/slips_files/core/database/redis_db/database.py @@ -70,6 +70,9 @@ class RedisDB(IoCHandler, AlertHandler, ProfileHandler): 'check_jarm_hash', 'control_channel', 'new_module_flow' + 'control_module', + 'cpu_profile', + 'memory_profile' } # The name is used to print in the outputprocess name = 'DB' diff --git a/slips_files/core/database/sqlite_db/database.py b/slips_files/core/database/sqlite_db/database.py index 8ac340f3f..72a13b49e 100644 --- a/slips_files/core/database/sqlite_db/database.py +++ b/slips_files/core/database/sqlite_db/database.py @@ -8,38 +8,83 @@ class SQLiteDB(): """Stores all the flows slips reads and handles labeling them""" - _obj = None + name = "SQLiteDB" # used to lock each call to commit() cursor_lock = Lock() + trial = 0 - def __new__(cls, output_dir): - # To treat the db as a singelton - if cls._obj is None or not isinstance(cls._obj, cls): - cls._obj = super(SQLiteDB, cls).__new__(SQLiteDB) - cls._flows_db = os.path.join(output_dir, 'flows.sqlite') - cls._init_db() - cls.conn = sqlite3.connect(cls._flows_db, check_same_thread=False) - cls.cursor = cls.conn.cursor() - cls.init_tables() - return cls._obj + def __init__(self, output_dir, output_queue): + self.output_queue = output_queue + self._flows_db = os.path.join(output_dir, 'flows.sqlite') + self.connect() + def connect(self): + """ + Creates the db if it doesn't exist and connects to it + """ + db_newly_created = False + if not os.path.exists(self._flows_db): + # db not created, mark it as first time accessing it so we can init tables once we connect + db_newly_created = True + self._init_db() + + self.conn = sqlite3.connect(self._flows_db, check_same_thread=False, timeout=20) - @classmethod - def init_tables(cls): + self.cursor = self.conn.cursor() + if db_newly_created: + # only init tables if the db is newly created + self.init_tables() + + def get_number_of_tables(self): + """ + returns the number of tables in the current db + """ + query = f"SELECT count(*) FROM sqlite_master WHERE type='table';" + self.execute(query) + x = self.fetchone() + return x[0] + + def init_tables(self): """creates the tables we're gonna use""" table_schema = { - 'flows': "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, profileid TEXT, twid TEXT", - 'altflows': "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, profileid TEXT, twid TEXT, flow_type TEXT" + 'flows': "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, profileid TEXT, twid TEXT, aid TEXT", + 'altflows': "uid TEXT PRIMARY KEY, flow TEXT, label TEXT, profileid TEXT, twid TEXT, flow_type TEXT", + 'alerts': 'alert_id TEXT PRIMARY KEY, alert_time TEXT, ip_alerted TEXT, timewindow TEXT, tw_start TEXT, tw_end TEXT, label TEXT' } for table_name, schema in table_schema.items(): - cls.create_table(table_name, schema) + self.create_table(table_name, schema) - @classmethod - def _init_db(cls): + def _init_db(self): """ creates the db if it doesn't exist and clears it if it exists """ - open(cls._flows_db,'w').close() + 
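# mode 'w' creates the file if missing and truncates it if present, matching the docstring above +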
open(self._flows_db,'w').close() + + def create_table(self, table_name, schema): + query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})" + self.execute(query) + + def print(self, text, verbose=1, debug=0): + """ + Function to use to print text using the outputqueue of slips. + Slips then decides how, when and where to print this text by taking all the processes into account + :param verbose: + 0 - don't print + 1 - basic operation/proof of work + 2 - log I/O operations and filenames + 3 - log database/profile/timewindow changes + :param debug: + 0 - don't print + 1 - print exceptions + 2 - unsupported and unhandled types (cases that may cause errors) + 3 - red warnings that needs examination - developer warnings + :param text: text to print. Can include format like 'Test {}'.format('here') + """ + levels = f'{verbose}{debug}' + try: + self.output_queue.put(f'{levels}|{self.name}|{text}') + except AttributeError: + pass def get_db_path(self) -> str: """ @@ -47,11 +92,6 @@ def get_db_path(self) -> str: """ return self._flows_db - @classmethod - def create_table(cls, table_name, schema): - query = f"CREATE TABLE IF NOT EXISTS {table_name} ({schema})" - cls.cursor.execute(query) - cls.conn.commit() def get_altflow_from_uid(self, profileid, twid, uid) -> dict: """ Given a uid, get the alternative flow associated with it """ @@ -207,14 +247,21 @@ def get_flow(self, uid: str, twid=False) -> dict: def add_flow( self, flow, profileid: str, twid:str, label='benign' ): - - parameters = (profileid, twid, flow.uid, json.dumps(asdict(flow)), label) - - self.execute( - 'INSERT OR REPLACE INTO flows (profileid, twid, uid, flow, label) ' - 'VALUES (?, ?, ?, ?, ?);', - parameters, - ) + if hasattr(flow, 'aid'): + parameters = (profileid, twid, flow.uid, json.dumps(asdict(flow)), label, flow.aid) + self.execute( + 'INSERT OR REPLACE INTO flows (profileid, twid, uid, flow, label, aid) ' + 'VALUES (?, ?, ?, ?, ?, ?);', + parameters, + ) + else: + parameters = (profileid, twid, flow.uid, json.dumps(asdict(flow)), label) + + self.execute( + 'INSERT OR REPLACE INTO flows (profileid, twid, uid, flow, label) ' + 'VALUES (?, ?, ?, ?, ?);', + parameters, + ) def get_flows_count(self, profileid, twid) -> int: """ @@ -237,6 +284,24 @@ def add_altflow( parameters, ) + def add_alert(self, alert: dict): + """ + adds an alert to the alerts table + alert param should contain alert_id, alert_ts, ip_alerted, twid, tw_start, tw_end, label + """ + # 'alerts': 'alert_id TEXT PRIMARY KEY, alert_time TEXT, ip_alerted TEXT, timewindow TEXT, tw_start TEXT, tw_end TEXT, label TEXT' + self.execute( + 'INSERT OR REPLACE INTO alerts (alert_id, ip_alerted, timewindow, tw_start, tw_end, label, alert_time) ' + 'VALUES (?, ?, ?, ?, ?, ?, ?);', + (alert['alert_ID'], + alert['profileid'].split()[-1], + alert['twid'], + alert['tw_start'], + alert['tw_end'], + alert['label'], + alert['time_detected']) + ) + def insert(self, table_name, values): @@ -305,26 +370,46 @@ def execute(self, query, params=None): since sqlite is terrible with multi-process applications this should be used instead of all calls to commit() and execute() """ - try: self.cursor_lock.acquire(True) + #start a transaction + self.cursor.execute('BEGIN') if not params: self.cursor.execute(query) else: self.cursor.execute(query, params) + self.conn.commit() self.cursor_lock.release() + # counter for the number of times we tried executing a tx and failed + self.trial = 0 + except sqlite3.Error as e: - if "database is locked" in str(e): - self.cursor_lock.release() + 
self.cursor_lock.release() + self.conn.rollback() + if self.trial >= 2: + # tried 2 times to exec a query and it's still failing + self.trial = 0 + # discard query + self.print(f"Error executing query: {query} - {e}. Query discarded", 0, 1) + + elif "database is locked" in str(e): + # keep track of failed trials + self.trial += 1 + # Retry after a short delay - sleep(0.1) + sleep(5) self.execute(query, params=params) else: # An error occurred during execution - print(f"Error executing query ({query}): {e}") + self.conn.rollback() + # print(f"Re-trying to execute query ({query}). reason: {e}") + # keep track of failed trials + self.trial += 1 + self.execute(query, params=params) + \ No newline at end of file diff --git a/slips_files/core/evidenceProcess.py b/slips_files/core/evidenceProcess.py index 93a412443..94e600de5 100644 --- a/slips_files/core/evidenceProcess.py +++ b/slips_files/core/evidenceProcess.py @@ -536,6 +536,33 @@ def label_flows_causing_alert(self): uids: list = self.db.get_flows_causing_evidence(evidence_id) self.db.set_flow_label(uids, 'malicious') + def handle_new_alert(self, alert_ID: str, tw_evidence: dict): + """ + saves alert details in the db and informs exporting modules about it + """ + profile, srcip, twid, _ = alert_ID.split('_') + profileid = f'{profile}_{srcip}' + self.db.set_evidence_causing_alert( + profileid, + twid, + alert_ID, + self.IDs_causing_an_alert + ) + alert_details = { + 'alert_ID': alert_ID, + 'profileid': profileid, + 'twid': twid, + } + self.db.publish('new_alert', json.dumps(alert_details)) + #store the alerts in the alerts table + alert_details.update( + {'time_detected': utils.convert_format(datetime.now(), 'unixtimestamp'), + 'label': 'malicious'}) + self.db.add_alert(alert_details) + self.label_flows_causing_alert() + self.send_to_exporting_module(tw_evidence) + + def main(self): while not self.should_stop(): if msg := self.get_msg('evidence_added'): @@ -646,20 +673,8 @@ def main(self): # store the alert in our database # the alert ID is profileid_twid + the ID of the last evidence causing this alert alert_ID = f'{profileid}_{twid}_{ID}' - self.db.set_evidence_causing_alert( - profileid, - twid, - alert_ID, - self.IDs_causing_an_alert - ) - to_send = { - 'alert_ID': alert_ID, - 'profileid': profileid, - 'twid': twid, - } - self.db.publish('new_alert', json.dumps(to_send)) - self.label_flows_causing_alert() - self.send_to_exporting_module(tw_evidence) + + self.handle_new_alert(alert_ID, tw_evidence) # print the alert alert_to_print = ( diff --git a/slips_files/core/flows/zeek.py b/slips_files/core/flows/zeek.py index f9b16cb28..20681addb 100644 --- a/slips_files/core/flows/zeek.py +++ b/slips_files/core/flows/zeek.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from typing import List from datetime import datetime, timedelta +from slips_files.common.slips_utils import utils import json @dataclass @@ -32,7 +33,6 @@ class Conn: state: str history: str - type_: str = 'conn' dir_: str = '->' @@ -42,6 +42,8 @@ def __post_init__(self) -> None: self.pkts: int = self.spkts + self.dpkts self.bytes: int = self.sbytes + self.dbytes self.state_hist: str = self.history or self.state + # community IDs are for conn.log flows only + self.aid = utils.get_aid(self) @dataclass class DNS: diff --git a/slips_files/core/helpers/whitelist.py b/slips_files/core/helpers/whitelist.py index 4a2f009ad..db99622c3 100644 --- a/slips_files/core/helpers/whitelist.py +++ b/slips_files/core/helpers/whitelist.py @@ -176,12 +176,16 @@ def is_whitelisted_flow(self, flow) 
-> bool: domains_to_check.append(flow.subject.replace( 'CN=', '' )) + elif flow_type == 'dns': + domains_to_check.append(flow.query) + for domain in domains_to_check: if self.is_whitelisted_domain(domain, saddr, daddr, 'flows'): return True + if whitelisted_IPs := self.db.get_whitelist('IPs'): # self.print('Check the IPs') # Check if the IPs are whitelisted @@ -207,6 +211,18 @@ def is_whitelisted_flow(self, flow) -> bool: # self.print(f"Whitelisting the dst IP {column_values['daddr']}") return True + if flow_type == 'dns': + # check all answers + for answer in flow.answers: + if answer in ips_to_whitelist: + # #TODO the direction doesn't matter here right? + # direction = whitelisted_IPs[daddr]['from'] + what_to_ignore = whitelisted_IPs[answer]['what_to_ignore'] + if self.should_ignore_flows(what_to_ignore): + # self.print(f"Whitelisting the IP {answer} due to its presence in a dns answer") + return True + + if whitelisted_macs := self.db.get_whitelist('mac'): # try to get the mac address of the current flow src_mac = flow.smac if hasattr(flow, 'smac') else False diff --git a/slips_files/core/inputProcess.py b/slips_files/core/inputProcess.py index 5f93b5d2c..7f2601503 100644 --- a/slips_files/core/inputProcess.py +++ b/slips_files/core/inputProcess.py @@ -804,7 +804,7 @@ def detach_child(): stderr=subprocess.PIPE, stdin=subprocess.PIPE, cwd=self.zeek_dir, - preexec_fn=detach_child, + start_new_session=True ) # you have to get the pid before communicate() self.zeek_pid = zeek.pid diff --git a/tests/integration_tests/test.conf b/tests/integration_tests/test.conf index 14b960346..04205d5f2 100644 --- a/tests/integration_tests/test.conf +++ b/tests/integration_tests/test.conf @@ -337,6 +337,33 @@ disabled_detections = [ConnectionWithoutDNS] # the purpose of using them is to change the ownership of the docker created files to be able to rwx the files from outside docker too, for example the files in the output/ dir UID = 0 GID = 0 + +#################### + +[Profiling] + +# [11] CPU profiling + +# enable cpu profiling [yes,no] +cpu_profiler_enable = no + +# Available options are [dev,live] +# dev for deterministic profiling. this will give precise information about the CPU usage +# throughout the program runtime. This module cannot give live updates +# live mode is for sampling data stream. To track the function stack in real time. it is accessible from web interface +cpu_profiler_mode = dev + +# profile all subprocesses in dev mode [yes,no]. +cpu_profiler_multiprocess = yes + +# set number of tracer entries (dev mode only) +cpu_profiler_dev_mode_entries = 1000000 + +# set maximum output lines (live mode only) +cpu_profiler_output_limit = 20 + +# set the wait time between sampling sequences in seconds (live mode only) +cpu_profiler_sampling_interval = 20 #################### # [10] enable or disable p2p for slips [P2P] diff --git a/tests/integration_tests/test2.conf b/tests/integration_tests/test2.conf index 0c237003b..38ef7e6dd 100644 --- a/tests/integration_tests/test2.conf +++ b/tests/integration_tests/test2.conf @@ -342,6 +342,34 @@ disabled_detections = [] # the purpose of using them is to change the ownership of the docker created files to be able to rwx the files from outside docker too, for example the files in the output/ dir UID = 0 GID = 0 + +#################### + +[Profiling] + +# [11] CPU profiling + +# enable cpu profiling [yes,no] +cpu_profiler_enable = no + +# Available options are [dev,live] +# dev for deterministic profiling. 
this will give precise information about the CPU usage +# throughout the program runtime. This module cannot give live updates +# live mode is for sampling data stream. To track the function stack in real time. it is accessible from web interface +cpu_profiler_mode = dev + +# profile all subprocesses in dev mode [yes,no]. +cpu_profiler_multiprocess = yes + +# set number of tracer entries (dev mode only) +cpu_profiler_dev_mode_entries = 1000000 + +# set maximum output lines (live mode only) +cpu_profiler_output_limit = 20 + +# set the wait time between sampling sequences in seconds (live mode only) +cpu_profiler_sampling_interval = 20 + #################### # [11] enable or disable p2p for slips [P2P] diff --git a/tests/module_factory.py b/tests/module_factory.py index 2237d9c79..f1e15267f 100644 --- a/tests/module_factory.py +++ b/tests/module_factory.py @@ -21,9 +21,11 @@ from multiprocessing import Queue, Event from modules.arp.arp import ARP import shutil +from unittest.mock import patch, Mock, MagicMock import os + def read_configuration(): return @@ -67,14 +69,20 @@ def create_main_obj(self, input_information): return main - def create_http_analyzer_obj(self, mock_db): - http_analyzer = HTTPAnalyzer(self.output_queue, mock_db, self.dummy_termination_event) + def create_http_analyzer_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + http_analyzer = HTTPAnalyzer(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + http_analyzer.db.rdb = mock_rdb + # override the self.print function to avoid broken pipes http_analyzer.print = do_nothing return http_analyzer - def create_virustotal_obj(self, mock_db): - virustotal = VT(self.output_queue, mock_db, self.dummy_termination_event) + def create_virustotal_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + virustotal = VT(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + virustotal.db.rdb = mock_rdb + # override the self.print function to avoid broken pipes virustotal.print = do_nothing virustotal.__read_configuration = read_configuration @@ -83,43 +91,52 @@ def create_virustotal_obj(self, mock_db): ) return virustotal - def create_arp_obj(self, mock_db): - arp = ARP(self.output_queue, mock_db, self.dummy_termination_event) + def create_arp_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + arp = ARP(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + arp.db.rdb = mock_rdb # override the self.print function to avoid broken pipes arp.print = do_nothing return arp - def create_blocking_obj(self, mock_db): - blocking = Blocking(self.output_queue, mock_db, self.dummy_termination_event) + def create_blocking_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + blocking = Blocking(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + blocking.db.rdb = mock_rdb + # override the print function to avoid broken pipes blocking.print = do_nothing return blocking - def create_flowalerts_obj(self, mock_db): - flowalerts = FlowAlerts(self.output_queue, mock_db, self.dummy_termination_event) + def create_flowalerts_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + flowalerts = FlowAlerts(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + flowalerts.db.rdb = mock_rdb + # override the self.print function to avoid broken 
pipes flowalerts.print = do_nothing return flowalerts def create_inputProcess_obj( - self, input_information, input_type, mock_db + self, input_information, input_type, mock_rdb ): zeek_tmp_dir = os.path.join(os.getcwd(), 'zeek_dir_for_testing' ) - - inputProcess = InputProcess( - mock_db, - self.output_queue, - 'output/', - self.dummy_termination_event, - profiler_queue=self.profiler_queue, - input_type=input_type, - input_information=input_information, - cli_packet_filter= None, - zeek_or_bro=check_zeek_or_bro(), - zeek_dir=zeek_tmp_dir, - line_type=False, - ) + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + inputProcess = InputProcess( + self.output_queue, + 'dummy_output_dir', 6379, + # 'output/', + self.dummy_termination_event, + profiler_queue=self.profiler_queue, + input_type=input_type, + input_information=input_information, + cli_packet_filter= None, + zeek_or_bro=check_zeek_or_bro(), + zeek_dir=zeek_tmp_dir, + line_type=False, + ) + inputProcess.db.rdb = mock_rdb inputProcess.bro_timeout = 1 # override the print function to avoid broken pipes @@ -130,8 +147,10 @@ def create_inputProcess_obj( return inputProcess - def create_ip_info_obj(self, db): - ip_info = IPInfo(self.output_queue, db, self.dummy_termination_event) + def create_ip_info_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + ip_info = IPInfo(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + ip_info.db.rdb = mock_rdb # override the self.print function to avoid broken pipes ip_info.print = do_nothing return ip_info @@ -139,13 +158,15 @@ def create_ip_info_obj(self, db): def create_asn_obj(self, db): return ASN(db) - def create_leak_detector_obj(self, mock_db): + def create_leak_detector_obj(self, mock_rdb): # this file will be used for storing the module output # and deleted when the tests are done test_pcap = 'dataset/test7-malicious.pcap' yara_rules_path = 'tests/yara_rules_for_testing/rules/' compiled_yara_rules_path = 'tests/yara_rules_for_testing/compiled/' - leak_detector = LeakDetector(self.output_queue, mock_db, self.dummy_termination_event) + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + leak_detector = LeakDetector(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + leak_detector.db.rdb = mock_rdb # override the self.print function to avoid broken pipes leak_detector.print = do_nothing # this is the path containing 1 yara rule for testing, it matches every pcap @@ -155,11 +176,10 @@ def create_leak_detector_obj(self, mock_db): return leak_detector - def create_profilerProcess_obj(self, mock_db): + def create_profilerProcess_obj(self): profilerProcess = ProfilerProcess( - mock_db, self.output_queue, - 'output/', + 'output/', 6377, self.dummy_termination_event, profiler_queue=self.input_queue, ) @@ -178,20 +198,29 @@ def create_process_manager_obj(self): def create_utils_obj(self): return utils - def create_threatintel_obj(self, mock_db): - threatintel = ThreatIntel(self.output_queue, mock_db, self.dummy_termination_event) + def create_threatintel_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + threatintel = ThreatIntel(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + threatintel.db.rdb = mock_rdb + # override the self.print function to avoid broken pipes threatintel.print = do_nothing return threatintel - def create_update_manager_obj(self, mock_db): - update_manager = 
UpdateManager(self.output_queue, mock_db, self.dummy_termination_event) + def create_update_manager_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + update_manager = UpdateManager(self.output_queue, 'dummy_output_dir', 6379, self.dummy_termination_event) + update_manager.db.rdb = mock_rdb + # override the self.print function to avoid broken pipes update_manager.print = do_nothing return update_manager - def create_whitelist_obj(self, mock_db): - whitelist = Whitelist(self.output_queue, mock_db) + def create_whitelist_obj(self, mock_rdb): + with patch.object(DBManager, 'create_sqlite_db', return_value=Mock()): + whitelist = Whitelist(self.output_queue, mock_rdb) + whitelist.db.rdb = mock_rdb + # override the self.print function to avoid broken pipes whitelist.print = do_nothing whitelist.whitelist_path = 'tests/test_whitelist.conf' diff --git a/tests/test_arp.py b/tests/test_arp.py index 70de1abe6..eb85c66ec 100644 --- a/tests/test_arp.py +++ b/tests/test_arp.py @@ -9,8 +9,8 @@ # check_arp_scan is tested in test_dataset.py, check arp-only unit test -def test_check_dstip_outside_localnet(mock_db): - ARP = ModuleFactory().create_arp_obj(mock_db) +def test_check_dstip_outside_localnet(mock_rdb): + ARP = ModuleFactory().create_arp_obj(mock_rdb) daddr = '1.1.1.1' uid = '1234' saddr = '192.168.1.1' @@ -20,8 +20,8 @@ def test_check_dstip_outside_localnet(mock_db): ) -def test_detect_unsolicited_arp(mock_db): - ARP = ModuleFactory().create_arp_obj(mock_db) +def test_detect_unsolicited_arp(mock_rdb): + ARP = ModuleFactory().create_arp_obj(mock_rdb) uid = '1234' ts = '1632214645.783595' dst_mac = 'ff:ff:ff:ff:ff:ff' @@ -33,8 +33,8 @@ def test_detect_unsolicited_arp(mock_db): ) -def test_detect_MITM_ARP_attack(mock_db): - ARP = ModuleFactory().create_arp_obj(mock_db) +def test_detect_MITM_ARP_attack(mock_rdb): + ARP = ModuleFactory().create_arp_obj(mock_rdb) # add a mac addr to this profile src_mac = '2e:a4:18:f8:3d:02' @@ -42,7 +42,7 @@ def test_detect_MITM_ARP_attack(mock_db): uid = '1234' ts = '1636305825.755132' saddr = '192.168.1.3' - mock_db.get_ip_of_mac.return_value = json.dumps([profileid]) + mock_rdb.get_ip_of_mac.return_value = json.dumps([profileid]) assert ( ARP.detect_MITM_ARP_attack( profileid, diff --git a/tests/test_flowalerts.py b/tests/test_flowalerts.py index 617b0f402..ace67e857 100644 --- a/tests/test_flowalerts.py +++ b/tests/test_flowalerts.py @@ -14,8 +14,8 @@ dst_profileid = f'profile_{daddr}' -def test_port_belongs_to_an_org(mock_db): - flowalerts = ModuleFactory().create_flowalerts_obj(mock_db) +def test_port_belongs_to_an_org(mock_rdb): + flowalerts = ModuleFactory().create_flowalerts_obj(mock_rdb) # belongs to apple portproto = '65509/tcp' @@ -23,24 +23,24 @@ def test_port_belongs_to_an_org(mock_db): # mock the db response to say that the org of this port # is apple and the mac vendor of the # given profile is also apple - mock_db.get_organization_of_port.return_value = json.dumps( + mock_rdb.get_organization_of_port.return_value = json.dumps( {'ip':[], 'org_name':'apple'} ) - mock_db.get_mac_vendor_from_profile.return_value = 'apple' + mock_rdb.get_mac_vendor_from_profile.return_value = 'apple' assert flowalerts.port_belongs_to_an_org(daddr, portproto, profileid) is True # doesn't belong to any org portproto = '78965/tcp' # expectations - mock_db.get_organization_of_port.return_value = None + mock_rdb.get_organization_of_port.return_value = None assert flowalerts.port_belongs_to_an_org(daddr, portproto, profileid) is False 
-def test_check_unknown_port(mocker, mock_db):
-    flowalerts = ModuleFactory().create_flowalerts_obj(mock_db)
+def test_check_unknown_port(mocker, mock_rdb):
+    flowalerts = ModuleFactory().create_flowalerts_obj(mock_rdb)
     # database.set_port_info('23/udp', 'telnet')
-    mock_db.get_port_info.return_value = 'telnet'
+    mock_rdb.get_port_info.return_value = 'telnet'
     # now we have info 23 udp
     assert flowalerts.check_unknown_port(
         '23',
@@ -54,8 +54,8 @@ def test_check_unknown_port(mocker, mock_db):
     ) is False
     # test when the port is unknown
-    mock_db.get_port_info.return_value = None
-    mock_db.is_ftp_port.return_value = False
+    mock_rdb.get_port_info.return_value = None
+    mock_rdb.is_ftp_port.return_value = False
     # mock the flowalerts call to port_belongs_to_an_org
     flowalerts_mock = mocker.patch("modules.flowalerts.flowalerts.FlowAlerts.port_belongs_to_an_org")
     flowalerts_mock.return_value = False
@@ -74,18 +74,18 @@ def test_check_unknown_port(mocker, mock_db):

 def test_check_if_resolution_was_made_by_different_version(
-    mock_db
+    mock_rdb
 ):
-    flowalerts = ModuleFactory().create_flowalerts_obj(mock_db)
+    flowalerts = ModuleFactory().create_flowalerts_obj(mock_rdb)
     # now this ipv6 belongs to the same profileid, is supposed to be
     # the other version of the ipv4 of the used profileid
-    mock_db.get_the_other_ip_version.return_value = json.dumps(
+    mock_rdb.get_the_other_ip_version.return_value = json.dumps(
         '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
     )
     # now the daddr given to check_if_resolution_was_made_by_different_version()
     # is supposed to be resolved by the ipv6 of the profile, not th eipv4
-    mock_db.get_dns_resolution.return_value = {
+    mock_rdb.get_dns_resolution.return_value = {
         'resolved-by': '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
     }
@@ -96,10 +96,10 @@ def test_check_if_resolution_was_made_by_different_version(
     ) is True

     # check the case when the resolution wasn't done by another IP
-    mock_db.get_the_other_ip_version.return_value = json.dumps(
+    mock_rdb.get_the_other_ip_version.return_value = json.dumps(
         '2001:0db8:85a3:0000:0000:8a2e:0370:7334'
     )
-    mock_db.get_dns_resolution.return_value = {'resolved-by': []}
+    mock_rdb.get_dns_resolution.return_value = {'resolved-by': []}
     assert flowalerts.check_if_resolution_was_made_by_different_version(
         profileid, '2.3.4.5'
     )
@@ -107,8 +107,8 @@

-def test_check_dns_arpa_scan(mock_db):
-    flowalerts = ModuleFactory().create_flowalerts_obj(mock_db)
+def test_check_dns_arpa_scan(mock_rdb):
+    flowalerts = ModuleFactory().create_flowalerts_obj(mock_rdb)
     # make 10 different arpa scans
     for ts in arange(0, 1, 1 / 10):
         is_arpa_scan = flowalerts.check_dns_arpa_scan(
@@ -118,18 +118,18 @@ def test_check_dns_arpa_scan(mock_db):
     assert is_arpa_scan is True

-def test_check_multiple_ssh_versions(mock_db):
-    flowalerts = ModuleFactory().create_flowalerts_obj(mock_db)
+def test_check_multiple_ssh_versions(mock_rdb):
+    flowalerts = ModuleFactory().create_flowalerts_obj(mock_rdb)
     # in the first flow, we only have 1 use ssh client so no version incompatibility
-    mock_db.get_software_from_profile.return_value = {'SSH::CLIENT': {'version-major': 8, 'version-minor': 1, 'uid': 'YTYwNjBiMjIxZDkzOWYyYTc4'}}
+    mock_rdb.get_software_from_profile.return_value = {'SSH::CLIENT': {'version-major': 8, 'version-minor': 1, 'uid': 'YTYwNjBiMjIxZDkzOWYyYTc4'}}
     flow2 = {'starttime': 1632302619.444328, 'uid': 'M2VhNTA3ZmZiYjU3OGMxMzJk', 'saddr': '192.168.1.247', 'daddr': '', 'software': 'SSH::CLIENT', 'unparsed_version': 'OpenSSH_9.1', 'version_major': 9, 'version_minor': 1, 'type_': 'software'}
     # in flow 2 slips should detect a client version change
     assert flowalerts.check_multiple_ssh_versions(flow2, 'timewindow1') is True

-def test_detect_DGA(mock_db):
-    flowalerts = ModuleFactory().create_flowalerts_obj(mock_db)
+def test_detect_DGA(mock_rdb):
+    flowalerts = ModuleFactory().create_flowalerts_obj(mock_rdb)
     rcode_name = 'NXDOMAIN'
     # arbitrary ip to be able to call detect_DGA
     daddr = '10.0.0.1'
@@ -140,18 +140,18 @@ def test_detect_DGA(mock_db):
     assert dga_detected is True

-def test_detect_young_domains(mock_db):
-    flowalerts = ModuleFactory().create_flowalerts_obj(mock_db)
+def test_detect_young_domains(mock_rdb):
+    flowalerts = ModuleFactory().create_flowalerts_obj(mock_rdb)
     domain = 'example.com'
     # age in days
-    mock_db.getDomainData.return_value = {'Age': 50}
+    mock_rdb.getDomainData.return_value = {'Age': 50}
     assert (
         flowalerts.detect_young_domains(domain, timestamp, profileid, twid, uid) is True
     )
     # more than the age threshold
-    mock_db.getDomainData.return_value = {'Age': 1000}
+    mock_rdb.getDomainData.return_value = {'Age': 1000}
     assert (
         flowalerts.detect_young_domains(domain, timestamp, profileid, twid, uid) is False
     )
diff --git a/tests/test_http_analyzer.py b/tests/test_http_analyzer.py
index 7eacf11f3..4296867b0 100644
--- a/tests/test_http_analyzer.py
+++ b/tests/test_http_analyzer.py
@@ -20,8 +20,8 @@ def get_random_MAC():

-def test_check_suspicious_user_agents(mock_db):
-    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_db)
+def test_check_suspicious_user_agents(mock_rdb):
+    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_rdb)
     # create a flow with suspicious user agent
     host = '147.32.80.7'
     uri = '/wpad.dat'
@@ -31,8 +31,8 @@ def test_check_suspicious_user_agents(mock_db):
     )

-def test_check_multiple_google_connections(mock_db):
-    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_db)
+def test_check_multiple_google_connections(mock_rdb):
+    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_rdb)
     # {"ts":1635765765.435485,"uid":"C7mv0u4M1zqJBHydgj",
     # "id.orig_h":"192.168.1.28","id.orig_p":52102,"id.resp_h":"216.58.198.78",
     # "id.resp_p":80,"trans_depth":1,"method":"GET","host":"google.com","uri":"/",
@@ -48,16 +48,16 @@ def test_check_multiple_google_connections(mock_db):
     )
     assert found_detection is True

-def test_parsing_online_ua_info(mock_db, mocker):
+def test_parsing_online_ua_info(mock_rdb, mocker):
     """
     tests the parsing and processing the ua found by the online query
     """
-    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_db)
+    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_rdb)
     # use a different profile for this unit test to make sure we don't already have info about
     # it in the db
     profileid = 'profile_192.168.99.99'
-    mock_db.get_user_agent_from_profile.return_value = None
+    mock_rdb.get_user_agent_from_profile.return_value = None
     # mock the function that gets info about the given ua from an online db
     mock_requests = mocker.patch("requests.get")
     mock_requests.return_value.status_code = 200
@@ -73,8 +73,8 @@ def test_parsing_online_ua_info(mock_db, mocker):
     assert ua_info['browser'] == 'Safari'

-def test_get_user_agent_info(mock_db, mocker):
-    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_db)
+def test_get_user_agent_info(mock_rdb, mocker):
+    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_rdb)
     # mock the function that gets info about the
     # given ua from an online db: get_ua_info_online()
     mock_requests = mocker.patch("requests.get")
@@ -85,8 +85,8 @@
         "os_name":"OS X"
     }"""
-    mock_db.add_all_user_agent_to_profile.return_value = True
-    mock_db.get_user_agent_from_profile.return_value = None
+    mock_rdb.add_all_user_agent_to_profile.return_value = True
+    mock_rdb.get_user_agent_from_profile.return_value = None
     expected_ret_value = {'browser': 'Safari',
                           'os_name': 'OS X',
@@ -98,27 +98,27 @@ def test_get_user_agent_info(mock_db, mocker):
     # assert ua_added_to_db is not None, 'Error getting UA info online'
     # assert ua_added_to_db is not False, 'We already have UA info about this profile in the db'

-def test_check_incompatible_user_agent(mock_db):
+def test_check_incompatible_user_agent(mock_rdb):

-    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_db)
+    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_rdb)
     # use a different profile for this unit test to make sure we don't already have info about
     # it in the db. it has to be a private IP for its' MAC to not be marked as the gw MAC
     profileid = 'profile_192.168.77.254'
     # Mimic an intel mac vendor using safari
-    mock_db.get_mac_vendor_from_profile.return_value = 'Intel Corp'
-    mock_db.get_user_agent_from_profile.return_value = {'browser': 'safari'}
+    mock_rdb.get_mac_vendor_from_profile.return_value = 'Intel Corp'
+    mock_rdb.get_user_agent_from_profile.return_value = {'browser': 'safari'}
     assert (
         http_analyzer.check_incompatible_user_agent('google.com', '/images', timestamp, profileid, twid, uid) is True
     )

-def test_extract_info_from_UA(mock_db):
-    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_db)
+def test_extract_info_from_UA(mock_rdb):
+    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_rdb)
     # use another profile, because the default
     # one already has a ua in the db
-    mock_db.get_user_agent_from_profile.return_value = None
+    mock_rdb.get_user_agent_from_profile.return_value = None
     profileid = 'profile_192.168.1.2'
     server_bag_ua = 'server-bag[macOS,11.5.1,20G80,MacBookAir10,1]'
     assert (
@@ -127,8 +127,8 @@ def test_extract_info_from_UA(mock_db):
     )

-def test_check_multiple_UAs(mock_db):
-    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_db)
+def test_check_multiple_UAs(mock_rdb):
+    http_analyzer = ModuleFactory().create_http_analyzer_obj(mock_rdb)
     mozilla_ua = 'Mozilla/5.0 (X11; Fedora;Linux x86; rv:60.0) Gecko/20100101 Firefox/60.0'
     # old ua
     cached_ua = {'os_type': 'Fedora', 'os_name': 'Linux'}
diff --git a/tests/test_inputProc.py b/tests/test_inputProc.py
index fa20c8193..b83671069 100644
--- a/tests/test_inputProc.py
+++ b/tests/test_inputProc.py
@@ -10,10 +10,10 @@
     [('pcap', 'dataset/test12-icmp-portscan.pcap')],
 )
 def test_handle_pcap_and_interface(
-    input_type, input_information, mock_db
+    input_type, input_information, mock_rdb
 ):
     # no need to test interfaces because in that case read_zeek_files runs in a loop and never returns
-    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, input_type, mock_db)
+    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, input_type, mock_rdb)
     inputProcess.zeek_pid = 'False'
     inputProcess.is_zeek_tabs = True
     assert inputProcess.handle_pcap_and_interface() is True
@@ -29,12 +29,12 @@ def test_handle_pcap_and_interface(
     ],
 )
 def test_read_zeek_folder(
-    input_information, mock_db
+    input_information, mock_rdb
 ):
-    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, 'zeek_folder', mock_db)
+    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, 'zeek_folder', mock_rdb)
     # no need to get the total flows in this test, skip this part
-    mock_db.is_growing_zeek_dir.return_value = True
-    mock_db.get_all_zeek_file.return_value = [os.path.join(input_information, 'conn.log')]
+    mock_rdb.is_growing_zeek_dir.return_value = True
+    mock_rdb.get_all_zeek_file.return_value = [os.path.join(input_information, 'conn.log')]
     assert inputProcess.read_zeek_folder() is True
@@ -48,9 +48,9 @@ def test_read_zeek_folder(
     ],
 )
 def test_handle_zeek_log_file(
-    input_information, mock_db, expected_output
+    input_information, mock_rdb, expected_output
 ):
-    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, 'zeek_log_file', mock_db)
+    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, 'zeek_log_file', mock_rdb)
     assert inputProcess.handle_zeek_log_file() == expected_output
@@ -61,9 +61,9 @@ def test_handle_zeek_log_file(
     'input_information', [('dataset/test1-normal.nfdump')]
 )
 def test_handle_nfdump(
-    input_information, mock_db
+    input_information, mock_rdb
 ):
-    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, 'nfdump', mock_db)
+    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, 'nfdump', mock_rdb)
     assert inputProcess.handle_nfdump() is True
@@ -78,9 +78,9 @@ def test_handle_nfdump(
     # ('binetflow','dataset/test3-mixed.binetflow'),
     # ('binetflow','dataset/test4-malicious.binetflow'),
 def test_handle_binetflow(
-    input_type, input_information, mock_db
+    input_type, input_information, mock_rdb
 ):
-    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, input_type, mock_db)
+    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, input_type, mock_rdb)
     assert inputProcess.handle_binetflow() is True
@@ -89,7 +89,7 @@ def test_handle_binetflow(
     [('suricata', 'dataset/test6-malicious.suricata.json')],
 )
 def test_handle_suricata(
-    input_type, input_information, mock_db
+    input_type, input_information, mock_rdb
 ):
-    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, input_type, mock_db)
+    inputProcess = ModuleFactory().create_inputProcess_obj(input_information, input_type, mock_rdb)
     assert inputProcess.handle_suricata() is True
diff --git a/tests/test_ip_info.py b/tests/test_ip_info.py
index 846ee2d6a..b96ae058c 100644
--- a/tests/test_ip_info.py
+++ b/tests/test_ip_info.py
@@ -7,25 +7,25 @@
 # ASN unit tests
-def test_get_asn_info_from_geolite(mock_db):
+def test_get_asn_info_from_geolite(mock_rdb):
     """
     geolite is an offline db
     """
-    ASN_info = ModuleFactory().create_asn_obj(mock_db)
+    ASN_info = ModuleFactory().create_asn_obj(mock_rdb)
     # check an ip that we know is in the db
     expected_asn_info = {'asn': {'number': 'AS7018', 'org': 'ATT-INTERNET4'}}
     assert ASN_info.get_asn_info_from_geolite('108.200.116.255') == expected_asn_info
     # test asn info not found in geolite
     assert ASN_info.get_asn_info_from_geolite('0.0.0.0') == {}

-def test_cache_ip_range(mock_db):
+def test_cache_ip_range(mock_rdb):
     # Patch the database object creation before it is instantiated
-    ASN_info = ModuleFactory().create_asn_obj(mock_db)
+    ASN_info = ModuleFactory().create_asn_obj(mock_rdb)
     assert ASN_info.cache_ip_range('8.8.8.8') == {'asn': {'number': 'AS15169', 'org': 'GOOGLE, US'}}

 # GEOIP unit tests
-def test_get_geocountry(mock_db):
-    ip_info = ModuleFactory().create_ip_info_obj(mock_db)
+def test_get_geocountry(mock_rdb):
+    ip_info = ModuleFactory().create_ip_info_obj(mock_rdb)
     #open the db we'll be using for this test
     # ip_info.wait_for_dbs()
@@ -40,9 +40,9 @@ def test_get_geocountry(mock_db):
         'geocountry': 'Unknown'
     }

-def test_get_vendor(mocker, mock_db):
+def test_get_vendor(mocker, mock_rdb):
     # make sure the mac db is download so that wai_for_dbs doesn't wait forever :'D
-    ip_info = ModuleFactory().create_ip_info_obj(mock_db)
+    ip_info = ModuleFactory().create_ip_info_obj(mock_rdb)
     profileid = 'profile_10.0.2.15'
     mac_addr = '08:00:27:7f:09:e1'
@@ -51,7 +51,7 @@ def test_get_vendor(mocker, mock_db):
     mock_requests = mocker.patch("requests.get")
     mock_requests.return_value.status_code = 200
     mock_requests.return_value.text = 'PCS Systemtechnik GmbH'
-    mock_db.get_mac_vendor_from_profile.return_value = False
+    mock_rdb.get_mac_vendor_from_profile.return_value = False
     # tries to get vendor either online or from our offline db
     mac_info = ip_info.get_vendor(mac_addr, profileid)
diff --git a/tests/test_leak_detector.py b/tests/test_leak_detector.py
index 798f3ebe1..35dff18e8 100644
--- a/tests/test_leak_detector.py
+++ b/tests/test_leak_detector.py
@@ -3,8 +3,8 @@
 import os

-def test_compile_and_save_rules(mock_db):
-    leak_detector = ModuleFactory().create_leak_detector_obj(mock_db)
+def test_compile_and_save_rules(mock_rdb):
+    leak_detector = ModuleFactory().create_leak_detector_obj(mock_rdb)
     leak_detector.compile_and_save_rules()
     compiled_rules = os.listdir(leak_detector.compiled_yara_rules_path)
     assert 'test_rule.yara_compiled' in compiled_rules
diff --git a/tests/test_profilerProcess.py b/tests/test_profilerProcess.py
index 602589862..3d8b71c66 100644
--- a/tests/test_profilerProcess.py
+++ b/tests/test_profilerProcess.py
@@ -10,8 +10,8 @@
 @pytest.mark.parametrize(
     'file,expected_value', [('dataset/test6-malicious.suricata.json', 'suricata')]
 )
-def test_define_type_suricata(file, expected_value, mock_db):
-    profilerProcess = ModuleFactory().create_profilerProcess_obj(mock_db)
+def test_define_type_suricata(file, expected_value, mock_rdb):
+    profilerProcess = ModuleFactory().create_profilerProcess_obj()
     with open(file) as f:
         while True:
             sample_flow = f.readline().replace('\n', '')
@@ -29,8 +29,8 @@ def test_define_type_suricata(file, expected_value, mock_db):
     'file,expected_value',
     [('dataset/test10-mixed-zeek-dir/conn.log', 'zeek-tabs')],
 )
-def test_define_type_zeek_tab(file, expected_value, mock_db):
-    profilerProcess = ModuleFactory().create_profilerProcess_obj(mock_db)
+def test_define_type_zeek_tab(file, expected_value, mock_rdb):
+    profilerProcess = ModuleFactory().create_profilerProcess_obj()
     with open(file) as f:
         while True:
             sample_flow = f.readline().replace('\n', '')
@@ -44,8 +44,8 @@ def test_define_type_zeek_tab(file, expected_value, mock_db):
 @pytest.mark.parametrize(
     'file,expected_value', [('dataset/test9-mixed-zeek-dir/conn.log', 'zeek')]
 )
-def test_define_type_zeek_dict(file, expected_value, mock_db):
-    profilerProcess = ModuleFactory().create_profilerProcess_obj(mock_db)
+def test_define_type_zeek_dict(file, expected_value, mock_rdb):
+    profilerProcess = ModuleFactory().create_profilerProcess_obj()
     with open(file) as f:
         sample_flow = f.readline().replace('\n', '')
@@ -58,7 +58,7 @@ def test_define_type_zeek_dict(file, expected_value, mock_db):

 @pytest.mark.parametrize('nfdump_file', [('dataset/test1-normal.nfdump')])
-def test_define_type_nfdump(nfdump_file, mock_db):
+def test_define_type_nfdump(nfdump_file, mock_rdb):
     # nfdump files aren't text files so we need to process them first
     command = f'nfdump -b -N -o csv -q -r {nfdump_file}'
     # Execute command
@@ -74,8 +74,7 @@ def test_define_type_nfdump(nfdump_file, mock_db):
             continue
         line['data'] = nfdump_line
         break
-    profilerProcess = ModuleFactory().create_profilerProcess_obj(mock_db)
-    profilerProcess = ModuleFactory().create_profilerProcess_obj(mock_db)
+    profilerProcess = ModuleFactory().create_profilerProcess_obj()
     assert profilerProcess.define_type(line) == 'nfdump'
@@ -90,7 +89,7 @@ def test_define_type_nfdump(nfdump_file, mock_db):
     ],
 )
 def test_define_columns(
-    file, separator, expected_value, mock_db
+    file, separator, expected_value, mock_rdb
 ):
     # define_columns is called on header lines
     # line = '#fields ts uid id.orig_h id.orig_p
@@ -104,7 +103,7 @@ def test_define_columns(
         line = f.readline()
         if line.startswith('#fields'):
             break
-    profilerProcess = ModuleFactory().create_profilerProcess_obj(mock_db)
+    profilerProcess = ModuleFactory().create_profilerProcess_obj()
     line = {'data': line}
     profilerProcess.separator = separator
     assert profilerProcess.define_columns(line) == expected_value
@@ -134,17 +133,17 @@ def test_define_columns(
         ('dataset/test9-mixed-zeek-dir/http.log', 'http'),
         ('dataset/test9-mixed-zeek-dir/ssl.log', 'ssl'),
         ('dataset/test9-mixed-zeek-dir/notice.log', 'notice'),
-        ('dataset/test9-mixed-zeek-dir/files.log', 'files.log'),
+        # ('dataset/test9-mixed-zeek-dir/files.log', 'files.log'),
     ],
 )
 def test_add_flow_to_profile(file, type_):
-    db = ModuleFactory().create_db_manager_obj(6379)
-    profilerProcess = ModuleFactory().create_profilerProcess_obj(db)
+    profilerProcess = ModuleFactory().create_profilerProcess_obj()
     # we're testing another functionality here
     profilerProcess.whitelist.is_whitelisted_flow = do_nothing
     # get zeek flow
     with open(file) as f:
         sample_flow = f.readline().replace('\n', '')
+        sample_flow = json.loads(sample_flow)
     sample_flow = {
         'data': sample_flow,
@@ -163,9 +162,9 @@ def test_add_flow_to_profile(file, type_):
     # make sure it's added
     if type_ == 'conn':
-        added_flow = db.get_flow(uid, twid=twid)[uid]
+        added_flow = profilerProcess.db.get_flow(uid, twid=twid)[uid]
     else:
         added_flow = (
-            db.get_altflow_from_uid(profileid, twid, uid) is not None
+            profilerProcess.db.get_altflow_from_uid(profileid, twid, uid) is not None
         )
     assert added_flow is not None
diff --git a/tests/test_threat_intelligence.py b/tests/test_threat_intelligence.py
index 1f11e63ff..3bb9ce6bc 100644
--- a/tests/test_threat_intelligence.py
+++ b/tests/test_threat_intelligence.py
@@ -5,8 +5,8 @@

-def test_parse_local_ti_file(mock_db):
-    threatintel = ModuleFactory().create_threatintel_obj(mock_db)
+def test_parse_local_ti_file(mock_rdb):
+    threatintel = ModuleFactory().create_threatintel_obj(mock_rdb)
     local_ti_files_dir = threatintel.path_to_local_ti_files
     local_ti_file = os.path.join(local_ti_files_dir, 'own_malicious_iocs.csv')
     # this is an ip we know we have in own_maicious_iocs.csv
@@ -22,7 +22,7 @@
     ],
 )
 def test_check_local_ti_files_for_update(
-    current_hash, old_hash, expected_return, mocker, mock_db
+    current_hash, old_hash, expected_return, mocker, mock_rdb
 ):
     """
     first case the cur hash is diff from the old hash so slips should update
@@ -30,12 +30,12 @@
     third, cur hash is false meaning we cant get the file hash
     """
     # since this is a clear db, then we should update the local ti file
-    threatintel = ModuleFactory().create_threatintel_obj(mock_db)
+    threatintel = ModuleFactory().create_threatintel_obj(mock_rdb)
     own_malicious_iocs = os.path.join(threatintel.path_to_local_ti_files, 'own_malicious_iocs.csv')
     mock_hash = mocker.patch("slips_files.common.slips_utils.Utils.get_hash_from_file")
     mock_hash.return_value = current_hash
-    mock_db.get_TI_file_info.return_value = {'hash': old_hash}
+    mock_rdb.get_TI_file_info.return_value = {'hash': old_hash}
     assert threatintel.should_update_local_ti_file(own_malicious_iocs) == expected_return
diff --git a/tests/test_update_file_manager.py b/tests/test_update_file_manager.py
index bc14318cf..d36868208 100644
--- a/tests/test_update_file_manager.py
+++ b/tests/test_update_file_manager.py
@@ -2,8 +2,8 @@
 from tests.module_factory import ModuleFactory
 import json

-def test_getting_header_fields(mocker, mock_db):
-    update_manager = ModuleFactory().create_update_manager_obj(mock_db)
+def test_getting_header_fields(mocker, mock_rdb):
+    update_manager = ModuleFactory().create_update_manager_obj(mock_rdb)
     url = 'google.com/play'
     mock_requests = mocker.patch("requests.get")
     mock_requests.return_value.status_code = 200
@@ -13,20 +13,20 @@ def test_getting_header_fields(mocker, mock_db):
     assert update_manager.get_e_tag(response) == '1234'

-def test_check_if_update_based_on_update_period(mock_db):
-    mock_db.get_TI_file_info.return_value = {'time': float('inf')}
-    update_manager = ModuleFactory().create_update_manager_obj(mock_db)
+def test_check_if_update_based_on_update_period(mock_rdb):
+    mock_rdb.get_TI_file_info.return_value = {'time': float('inf')}
+    update_manager = ModuleFactory().create_update_manager_obj(mock_rdb)
     url = 'abc.com/x'
     # update period hasn't passed
     assert update_manager.check_if_update(url, float('inf')) is False

-def test_check_if_update_based_on_e_tag(mocker, mock_db):
-    update_manager = ModuleFactory().create_update_manager_obj(mock_db)
+def test_check_if_update_based_on_e_tag(mocker, mock_rdb):
+    update_manager = ModuleFactory().create_update_manager_obj(mock_rdb)
     # period passed, etag same
     etag = '1234'
     url = 'google.com/images'
-    mock_db.get_TI_file_info.return_value = {'e-tag': etag}
+    mock_rdb.get_TI_file_info.return_value = {'e-tag': etag}
     mock_requests = mocker.patch("requests.get")
     mock_requests.return_value.status_code = 200
@@ -38,20 +38,20 @@
     # period passed, etag different
     etag = '1111'
     url = 'google.com/images'
-    mock_db.get_TI_file_info.return_value = {'e-tag': etag}
+    mock_rdb.get_TI_file_info.return_value = {'e-tag': etag}
     mock_requests = mocker.patch("requests.get")
     mock_requests.return_value.status_code = 200
     mock_requests.return_value.headers = {'ETag': '2222'}
     mock_requests.return_value.text = ""
     assert update_manager.check_if_update(url, float('-inf')) is True

-def test_check_if_update_based_on_last_modified(database, mocker, mock_db):
-    update_manager = ModuleFactory().create_update_manager_obj(mock_db)
+def test_check_if_update_based_on_last_modified(database, mocker, mock_rdb):
+    update_manager = ModuleFactory().create_update_manager_obj(mock_rdb)
     # period passed, no etag, last modified the same
     url = 'google.com/photos'
-    mock_db.get_TI_file_info.return_value = {'Last-Modified': 10.0}
+    mock_rdb.get_TI_file_info.return_value = {'Last-Modified': 10.0}
     mock_requests = mocker.patch("requests.get")
     mock_requests.return_value.status_code = 200
     mock_requests.return_value.headers = {'Last-Modified': 10.0}
@@ -62,7 +62,7 @@
     # period passed, no etag, last modified changed
     url = 'google.com/photos'
-    mock_db.get_TI_file_info.return_value = {'Last-Modified': 10}
+    mock_rdb.get_TI_file_info.return_value = {'Last-Modified': 10}
     mock_requests = mocker.patch("requests.get")
     mock_requests.return_value.status_code = 200
     mock_requests.return_value.headers = {'Last-Modified': 11}
diff --git a/tests/test_virustotal.py b/tests/test_virustotal.py
index fd59dbd3d..4f86c04ae 100644
--- a/tests/test_virustotal.py
+++ b/tests/test_virustotal.py
@@ -81,16 +81,16 @@ def get_allowed(quota):
 @pytest.mark.dependency(name='sufficient_quota')
 @pytest.mark.parametrize('ip', ['8.8.8.8'])
 @valid_api_key
-def test_interpret_rsponse(ip, mock_db):
-    virustotal = ModuleFactory().create_virustotal_obj(mock_db)
+def test_interpret_rsponse(ip, mock_rdb):
+    virustotal = ModuleFactory().create_virustotal_obj(mock_rdb)
     response = virustotal.api_query_(ip)
     for ratio in virustotal.interpret_response(response):
         assert type(ratio) == float

 @pytest.mark.dependency(depends=["sufficient_quota"])
 @valid_api_key
-def test_get_domain_vt_data(mock_db):
-    virustotal = ModuleFactory().create_virustotal_obj(mock_db)
+def test_get_domain_vt_data(mock_rdb):
+    virustotal = ModuleFactory().create_virustotal_obj(mock_rdb)
     assert virustotal.get_domain_vt_data('google.com') is not False
diff --git a/tests/test_whitelist.py b/tests/test_whitelist.py
index f2f434209..43a8d83ba 100644
--- a/tests/test_whitelist.py
+++ b/tests/test_whitelist.py
@@ -3,13 +3,13 @@

-def test_read_whitelist(mock_db):
+def test_read_whitelist(mock_rdb):
     """
     make sure the content of whitelists is read and stored properly
     uses tests/test_whitelist.conf for testing
     """
-    whitelist = ModuleFactory().create_whitelist_obj(mock_db)
-    mock_db.get_whitelist.return_value = {}
+    whitelist = ModuleFactory().create_whitelist_obj(mock_rdb)
+    mock_rdb.get_whitelist.return_value = {}
     whitelisted_IPs, whitelisted_domains, whitelisted_orgs, whitelisted_mac = whitelist.read_whitelist()
     assert '91.121.83.118' in whitelisted_IPs
     assert 'apple.com' in whitelisted_domains
@@ -17,15 +17,15 @@

 @pytest.mark.parametrize('org,asn', [('google', 'AS6432')])
-def test_load_org_asn(org, asn, mock_db):
-    whitelist = ModuleFactory().create_whitelist_obj(mock_db)
+def test_load_org_asn(org, asn, mock_rdb):
+    whitelist = ModuleFactory().create_whitelist_obj(mock_rdb)
     assert whitelist.load_org_asn(org) is not False
     assert asn in whitelist.load_org_asn(org)

 @pytest.mark.parametrize('org,subnet', [('google', '216.73.80.0/20')])
-def test_load_org_IPs(org, subnet, mock_db):
-    whitelist = ModuleFactory().create_whitelist_obj(mock_db)
+def test_load_org_IPs(org, subnet, mock_rdb):
+    whitelist = ModuleFactory().create_whitelist_obj(mock_rdb)
     assert whitelist.load_org_IPs(org) is not False
     # we now store subnets in a dict sorted by the first octet
     first_octet = subnet.split('.')[0]
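The renamed tests above only require that the mock_rdb fixture behave like a unittest.mock object whose method return values can be preset per test (e.g. mock_rdb.get_port_info.return_value = 'telnet'). Below is a minimal sketch of such a fixture, assuming it lives in tests/conftest.py; the file name, docstring, and comments are illustrative assumptions and are not part of this diff.

    # Hypothetical sketch of the renamed fixture, assumed to live in tests/conftest.py.
    # The repo may define it differently; this only mirrors how the tests above use it.
    from unittest.mock import Mock

    import pytest


    @pytest.fixture
    def mock_rdb():
        """Stand-in for the Redis-backed database passed to module constructors."""
        rdb = Mock()
        # each test presets only the calls it needs, for example:
        #   rdb.get_port_info.return_value = 'telnet'
        #   rdb.is_ftp_port.return_value = False
        return rdb

With a fixture like this collected by pytest, the renamed tests run as before, for example: python3 -m pytest tests/test_flowalerts.py -k test_check_unknown_port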